
Commit e80d558

Merge branch 'MIC-DKFZ:master' into seq-inf
2 parents d04a129 + 2eaa371 commit e80d558

File tree

2 files changed (+34, -40 lines)

nnunetv2/training/dataloading/data_loader_2d.py (+17 -20)

@@ -88,26 +88,23 @@ def generate_train_batch(self):
             seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1)
 
         if self.transforms is not None:
-            if torch is not None:
-                torch_nthreads = torch.get_num_threads()
-                torch.set_num_threads(1)
-            with threadpool_limits(limits=1, user_api=None):
-                data_all = torch.from_numpy(data_all).float()
-                seg_all = torch.from_numpy(seg_all).to(torch.int16)
-                images = []
-                segs = []
-                for b in range(self.batch_size):
-                    tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
-                    images.append(tmp['image'])
-                    segs.append(tmp['segmentation'])
-                data_all = torch.stack(images)
-                if isinstance(segs[0], list):
-                    seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
-                else:
-                    seg_all = torch.stack(segs)
-                del segs, images
-            if torch is not None:
-                torch.set_num_threads(torch_nthreads)
+            with torch.no_grad():
+                with threadpool_limits(limits=1, user_api=None):
+
+                    data_all = torch.from_numpy(data_all).float()
+                    seg_all = torch.from_numpy(seg_all).to(torch.int16)
+                    images = []
+                    segs = []
+                    for b in range(self.batch_size):
+                        tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
+                        images.append(tmp['image'])
+                        segs.append(tmp['segmentation'])
+                    data_all = torch.stack(images)
+                    if isinstance(segs[0], list):
+                        seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    else:
+                        seg_all = torch.stack(segs)
+                    del segs, images
 
             return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
 

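For context, the pattern both files now share keeps each background worker single-threaded while it applies the batch transforms: threadpoolctl's threadpool_limits caps the BLAS/OpenMP thread pools for the duration of the with block, so many parallel dataloader workers do not oversubscribe the CPU. A minimal, self-contained sketch of that idea follows; the function and transform names are illustrative and not part of nnU-Net.

import numpy as np
import torch
from threadpoolctl import threadpool_limits


def transform_batch(data_np, transform):
    # Illustrative worker-side helper (hypothetical, not nnU-Net API):
    # cap BLAS/OpenMP thread pools so that many parallel dataloader
    # workers do not oversubscribe the available CPU cores.
    with threadpool_limits(limits=1, user_api=None):
        data = torch.from_numpy(data_np).float()
        return torch.stack([transform(data[b]) for b in range(data.shape[0])])


# Usage: a trivial stand-in "transform" that mirrors the last spatial axis.
batch = transform_batch(np.random.rand(2, 1, 8, 8),
                        lambda x: torch.flip(x, dims=[-1]))
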
nnunetv2/training/dataloading/data_loader_3d.py (+17 -20)

@@ -51,26 +51,23 @@ def generate_train_batch(self):
             seg_all[j] = np.pad(seg, padding, 'constant', constant_values=-1)
 
         if self.transforms is not None:
-            if torch is not None:
-                torch_nthreads = torch.get_num_threads()
-                torch.set_num_threads(1)
-            with threadpool_limits(limits=1, user_api=None):
-                data_all = torch.from_numpy(data_all).float()
-                seg_all = torch.from_numpy(seg_all).to(torch.int16)
-                images = []
-                segs = []
-                for b in range(self.batch_size):
-                    tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
-                    images.append(tmp['image'])
-                    segs.append(tmp['segmentation'])
-                data_all = torch.stack(images)
-                if isinstance(segs[0], list):
-                    seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
-                else:
-                    seg_all = torch.stack(segs)
-                del segs, images
-            if torch is not None:
-                torch.set_num_threads(torch_nthreads)
+            with torch.no_grad():
+                with threadpool_limits(limits=1, user_api=None):
+                    data_all = torch.from_numpy(data_all).float()
+                    seg_all = torch.from_numpy(seg_all).to(torch.int16)
+                    images = []
+                    segs = []
+                    for b in range(self.batch_size):
+                        tmp = self.transforms(**{'image': data_all[b], 'segmentation': seg_all[b]})
+                        images.append(tmp['image'])
+                        segs.append(tmp['segmentation'])
+                    data_all = torch.stack(images)
+                    if isinstance(segs[0], list):
+                        seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    else:
+                        seg_all = torch.stack(segs)
+                    del segs, images
+
             return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
 
         return {'data': data_all, 'target': seg_all, 'keys': selected_keys}

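The other half of the change, wrapping the transforms in torch.no_grad(), ensures that none of the tensor operations executed during augmentation record an autograd graph, which the dataloader never needs since it never backpropagates. A short sketch of the effect; the tensors and operations here are made up for illustration.

import torch

x = torch.rand(2, 1, 8, 8).requires_grad_()

# Outside no_grad, ops on a tensor that requires grad build a graph.
y = (x * 2).mean()
print(y.requires_grad)  # True

# Inside no_grad, the same ops track nothing, saving memory and time.
with torch.no_grad():
    z = (x * 2).mean()
print(z.requires_grad)  # False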