29
29
from nnunetv2 .utilities .find_class_by_name import recursive_find_python_class
30
30
from nnunetv2 .utilities .helpers import empty_cache , dummy_context
31
31
from nnunetv2 .utilities .json_export import recursive_fix_for_json_export
32
- from nnunetv2 .utilities .label_handling .label_handling import determine_num_input_channels
32
+ from nnunetv2 .utilities .label_handling .label_handling import determine_num_input_channels , convert_labelmap_to_one_hot
33
33
from nnunetv2 .utilities .plans_handling .plans_handler import PlansManager , ConfigurationManager
34
34
from nnunetv2 .utilities .utils import create_lists_from_splitted_dataset_folder
35
35
@@ -250,13 +250,87 @@ def predict_from_files(self,
250
250
if len (list_of_lists_or_source_folder ) == 0 :
251
251
return
252
252
253
+ if num_processes_preprocessing == 0 and num_processes_segmentation_export == 0 :
254
+ return self ._sequential_prediction (list_of_lists_or_source_folder , seg_from_prev_stage_files ,
255
+ output_filename_truncated , save_probabilities )
256
+
253
257
data_iterator = self ._internal_get_data_iterator_from_lists_of_filenames (list_of_lists_or_source_folder ,
254
258
seg_from_prev_stage_files ,
255
259
output_filename_truncated ,
256
260
num_processes_preprocessing )
257
261
258
262
return self .predict_from_data_iterator (data_iterator , save_probabilities , num_processes_segmentation_export )
259
263
264
+ def _load_data_for_prediction (self , input_file , input_seg , properties , preprocessor , plans_manager ,
265
+ configuration_manager , dataset_json , label_manager ):
266
+ if properties is not None :
267
+ data , seg = preprocessor .run_case_npy (
268
+ input_file ,
269
+ input_seg ,
270
+ properties ,
271
+ plans_manager ,
272
+ configuration_manager ,
273
+ dataset_json )
274
+ else :
275
+ data , seg , properties = preprocessor .run_case (
276
+ input_file ,
277
+ input_seg ,
278
+ plans_manager ,
279
+ configuration_manager ,
280
+ dataset_json )
281
+
282
+ if input_seg is not None :
283
+ seg_onehot = convert_labelmap_to_one_hot (input_seg [0 ], label_manager .foreground_labels , data .dtype )
284
+ data = np .vstack ((data , seg_onehot ))
285
+
286
+ data = torch .from_numpy (data ).to (dtype = torch .float32 , memory_format = torch .contiguous_format )
287
+ if self .device .type == 'cuda' :
288
+ data = data .pin_memory ()
289
+ return data , properties
290
+
291
+ @torch .inference_mode ()
292
+ def _sequential_prediction (self , input_list_of_lists , seg_from_prev_stage_files ,
293
+ output_filename_truncated , save_probabilities ):
294
+ ret = []
295
+ configuration_manager = self .configuration_manager
296
+ preprocessor = configuration_manager .preprocessor_class (verbose = self .verbose_preprocessing )
297
+ plans_manager = self .plans_manager
298
+ dataset_json = self .dataset_json
299
+ label_manager = plans_manager .get_label_manager (dataset_json )
300
+
301
+ for i in range (len (input_list_of_lists )):
302
+ ofile = output_filename_truncated [i ] if output_filename_truncated is not None else None
303
+ if ofile is not None :
304
+ print (f'\n Predicting { os .path .basename (ofile )} :' )
305
+ else :
306
+ print (f'\n Predicting image of shape { data .shape } :' )
307
+
308
+ data , properties = self ._load_data_for_prediction (input_list_of_lists [i ],
309
+ seg_from_prev_stage_files [
310
+ i ] if seg_from_prev_stage_files is not None else None ,
311
+ None ,
312
+ preprocessor , plans_manager , configuration_manager ,
313
+ dataset_json , label_manager )
314
+
315
+ prediction = self .predict_logits_from_preprocessed_data (data )
316
+
317
+ if ofile is not None :
318
+ print ('resampling and export' )
319
+ export_prediction_from_logits (
320
+ prediction , properties , self .configuration_manager , self .plans_manager , self .dataset_json , ofile ,
321
+ save_probabilities )
322
+ print (f'done with { os .path .basename (ofile )} ' )
323
+ else :
324
+ print ('resampling' )
325
+ ret .append (convert_predicted_logits_to_segmentation_with_correct_shape (
326
+ prediction , self .plans_manager , self .configuration_manager , self .label_manager , properties ,
327
+ save_probabilities ))
328
+ print (f'\n Done with image of shape { data .shape } :' )
329
+
330
+ compute_gaussian .cache_clear ()
331
+ empty_cache (self .device )
332
+ return ret
333
+
260
334
def _internal_get_data_iterator_from_lists_of_filenames (self ,
261
335
input_list_of_lists : List [List [str ]],
262
336
seg_from_prev_stage_files : Union [List [str ], None ],
@@ -418,6 +492,7 @@ def predict_from_data_iterator(self,
418
492
empty_cache (self .device )
419
493
return ret
420
494
495
+ @torch .inference_mode ()
421
496
def predict_single_npy_array (self , input_image : np .ndarray , image_properties : dict ,
422
497
segmentation_previous_stage : np .ndarray = None ,
423
498
output_file_truncated : str = None ,
@@ -435,35 +510,29 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
435
510
you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
436
511
image_properties must only have a 'spacing' key!
437
512
"""
438
- ppa = PreprocessAdapterFromNpy ([ input_image ], [ segmentation_previous_stage ], [ image_properties ] ,
439
- [ output_file_truncated ],
440
- self . plans_manager , self . dataset_json , self .configuration_manager ,
441
- num_threads_in_multithreaded = 1 , verbose = self .verbose )
442
- if self .verbose :
443
- print ( 'preprocessing' )
444
- dct = next ( ppa )
513
+ data , properties = self . _load_data_for_prediction ( input_image , segmentation_previous_stage , image_properties ,
514
+ self . configuration_manager . preprocessor_class (
515
+ verbose = self .verbose_preprocessing ) ,
516
+ self .plans_manager ,
517
+ self .configuration_manager ,
518
+ self . dataset_json ,
519
+ self . plans_manager . get_label_manager ( self . dataset_json ) )
445
520
446
- if self .verbose :
447
- print ('predicting' )
448
- predicted_logits = self .predict_logits_from_preprocessed_data (dct ['data' ]).cpu ()
521
+ predicted_logits = self .predict_logits_from_preprocessed_data (data )
449
522
450
523
if self .verbose :
451
524
print ('resampling to original shape' )
452
525
if output_file_truncated is not None :
453
- export_prediction_from_logits (predicted_logits , dct [ 'data_properties' ] , self .configuration_manager ,
526
+ export_prediction_from_logits (predicted_logits , properties , self .configuration_manager ,
454
527
self .plans_manager , self .dataset_json , output_file_truncated ,
455
528
save_or_return_probabilities )
456
529
else :
457
- ret = convert_predicted_logits_to_segmentation_with_correct_shape (predicted_logits , self .plans_manager ,
458
- self .configuration_manager ,
459
- self .label_manager ,
460
- dct ['data_properties' ],
461
- return_probabilities =
462
- save_or_return_probabilities )
463
- if save_or_return_probabilities :
464
- return ret [0 ], ret [1 ]
465
- else :
466
- return ret
530
+ return convert_predicted_logits_to_segmentation_with_correct_shape (predicted_logits , self .plans_manager ,
531
+ self .configuration_manager ,
532
+ self .label_manager ,
533
+ properties ,
534
+ return_probabilities =
535
+ save_or_return_probabilities )
467
536
468
537
def predict_logits_from_preprocessed_data (self , data : torch .Tensor ) -> torch .Tensor :
469
538
"""
0 commit comments