@@ -51,26 +51,23 @@ def generate_train_batch(self):
51
51
seg_all [j ] = np .pad (seg , padding , 'constant' , constant_values = - 1 )
52
52
53
53
if self .transforms is not None :
54
- if torch is not None :
55
- torch_nthreads = torch .get_num_threads ()
56
- torch .set_num_threads (1 )
57
- with threadpool_limits (limits = 1 , user_api = None ):
58
- data_all = torch .from_numpy (data_all ).float ()
59
- seg_all = torch .from_numpy (seg_all ).to (torch .int16 )
60
- images = []
61
- segs = []
62
- for b in range (self .batch_size ):
63
- tmp = self .transforms (** {'image' : data_all [b ], 'segmentation' : seg_all [b ]})
64
- images .append (tmp ['image' ])
65
- segs .append (tmp ['segmentation' ])
66
- data_all = torch .stack (images )
67
- if isinstance (segs [0 ], list ):
68
- seg_all = [torch .stack ([s [i ] for s in segs ]) for i in range (len (segs [0 ]))]
69
- else :
70
- seg_all = torch .stack (segs )
71
- del segs , images
72
- if torch is not None :
73
- torch .set_num_threads (torch_nthreads )
54
+ with torch .no_grad ():
55
+ with threadpool_limits (limits = 1 , user_api = None ):
56
+ data_all = torch .from_numpy (data_all ).float ()
57
+ seg_all = torch .from_numpy (seg_all ).to (torch .int16 )
58
+ images = []
59
+ segs = []
60
+ for b in range (self .batch_size ):
61
+ tmp = self .transforms (** {'image' : data_all [b ], 'segmentation' : seg_all [b ]})
62
+ images .append (tmp ['image' ])
63
+ segs .append (tmp ['segmentation' ])
64
+ data_all = torch .stack (images )
65
+ if isinstance (segs [0 ], list ):
66
+ seg_all = [torch .stack ([s [i ] for s in segs ]) for i in range (len (segs [0 ]))]
67
+ else :
68
+ seg_all = torch .stack (segs )
69
+ del segs , images
70
+
74
71
return {'data' : data_all , 'target' : seg_all , 'keys' : selected_keys }
75
72
76
73
return {'data' : data_all , 'target' : seg_all , 'keys' : selected_keys }
0 commit comments