-
Notifications
You must be signed in to change notification settings - Fork 88
/
Copy path train.py
31 lines (24 loc) · 1.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, DiceCELoss
import torch
from preporcess import prepare
from utilities import train
# Training entry script: prepares the dataset, builds a 3D MONAI UNet, and
# hands model, data, loss and optimizer to ``utilities.train``.

# Dataset location and output directory for training results. These are
# absolute paths from the original author's machine -- adjust before running.
data_dir = 'D:/Youtube/Organ and Tumor Segmentation/datasets/Task03_Liver/Data_Train_Test'
model_dir = 'D:/Youtube/Organ and Tumor Segmentation/results/results'

# NOTE(review): this runs at import time, before the ``__main__`` guard below,
# so any process that imports this module (e.g. spawned worker processes on
# Windows) would repeat the data preparation. Consider moving it under the guard.
data_in = prepare(data_dir, cache=True)

# Use the first GPU when CUDA is available; otherwise fall back to the CPU
# instead of failing when the model is moved to an unavailable device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 3D UNet with 1 input channel and 2 output channels. ``channels`` gives the
# feature count per level and ``strides`` the downsampling between levels.
# ``spatial_dims`` replaces the deprecated ``dimensions`` argument, which
# newer MONAI releases no longer accept.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# Alternative weighted Dice + cross-entropy loss, kept for reference. It needs a
# ``calculate_weights`` helper, which this script does not import.
#loss_function = DiceCELoss(to_onehot_y=True, sigmoid=True, squared_pred=True, ce_weight=calculate_weights(1792651250,2510860).to(device))

# NOTE(review): with two output channels and one-hot targets, MONAI examples
# usually pair ``to_onehot_y=True`` with ``softmax=True``; ``sigmoid=True``
# activates each channel independently. Confirm this is intended.
loss_function = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5, amsgrad=True)

if __name__ == '__main__':
    # NOTE(review): 600 is presumably the maximum number of epochs; confirm
    # against the signature of ``utilities.train``.
    train(model, data_in, loss_function, optimizer, 600, model_dir)