Skip to content

Commit fef5a45

Browse files
Adding sequential inference
1 parent d12a0c1 commit fef5a45

File tree

1 file changed

+91
-22
lines changed

1 file changed

+91
-22
lines changed

nnunetv2/inference/predict_from_raw_data.py

+91-22
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
3030
from nnunetv2.utilities.helpers import empty_cache, dummy_context
3131
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
32-
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
32+
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels, convert_labelmap_to_one_hot
3333
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
3434
from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder
3535

@@ -250,13 +250,87 @@ def predict_from_files(self,
250250
if len(list_of_lists_or_source_folder) == 0:
251251
return
252252

253+
if num_processes_preprocessing == 0 and num_processes_segmentation_export == 0:
254+
return self._sequential_prediction(list_of_lists_or_source_folder, seg_from_prev_stage_files,
255+
output_filename_truncated, save_probabilities)
256+
253257
data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
254258
seg_from_prev_stage_files,
255259
output_filename_truncated,
256260
num_processes_preprocessing)
257261

258262
return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)
259263

264+
def _load_data_for_prediction(self, input_file, input_seg, properties, preprocessor, plans_manager,
265+
configuration_manager, dataset_json, label_manager):
266+
if properties is not None:
267+
data, seg = preprocessor.run_case_npy(
268+
input_file,
269+
input_seg,
270+
properties,
271+
plans_manager,
272+
configuration_manager,
273+
dataset_json)
274+
else:
275+
data, seg, properties = preprocessor.run_case(
276+
input_file,
277+
input_seg,
278+
plans_manager,
279+
configuration_manager,
280+
dataset_json)
281+
282+
if input_seg is not None:
283+
seg_onehot = convert_labelmap_to_one_hot(input_seg[0], label_manager.foreground_labels, data.dtype)
284+
data = np.vstack((data, seg_onehot))
285+
286+
data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
287+
if self.device.type == 'cuda':
288+
data = data.pin_memory()
289+
return data, properties
290+
291+
@torch.inference_mode()
292+
def _sequential_prediction(self, input_list_of_lists, seg_from_prev_stage_files,
293+
output_filename_truncated, save_probabilities):
294+
ret = []
295+
configuration_manager = self.configuration_manager
296+
preprocessor = configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)
297+
plans_manager = self.plans_manager
298+
dataset_json = self.dataset_json
299+
label_manager = plans_manager.get_label_manager(dataset_json)
300+
301+
for i in range(len(input_list_of_lists)):
302+
ofile = output_filename_truncated[i] if output_filename_truncated is not None else None
303+
if ofile is not None:
304+
print(f'\nPredicting {os.path.basename(ofile)}:')
305+
else:
306+
print(f'\nPredicting image of shape {data.shape}:')
307+
308+
data, properties = self._load_data_for_prediction(input_list_of_lists[i],
309+
seg_from_prev_stage_files[
310+
i] if seg_from_prev_stage_files is not None else None,
311+
None,
312+
preprocessor, plans_manager, configuration_manager,
313+
dataset_json, label_manager)
314+
315+
prediction = self.predict_logits_from_preprocessed_data(data)
316+
317+
if ofile is not None:
318+
print('resampling and export')
319+
export_prediction_from_logits(
320+
prediction, properties, self.configuration_manager, self.plans_manager, self.dataset_json, ofile,
321+
save_probabilities)
322+
print(f'done with {os.path.basename(ofile)}')
323+
else:
324+
print('resampling')
325+
ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
326+
prediction, self.plans_manager, self.configuration_manager, self.label_manager, properties,
327+
save_probabilities))
328+
print(f'\nDone with image of shape {data.shape}:')
329+
330+
compute_gaussian.cache_clear()
331+
empty_cache(self.device)
332+
return ret
333+
260334
def _internal_get_data_iterator_from_lists_of_filenames(self,
261335
input_list_of_lists: List[List[str]],
262336
seg_from_prev_stage_files: Union[List[str], None],
@@ -418,6 +492,7 @@ def predict_from_data_iterator(self,
418492
empty_cache(self.device)
419493
return ret
420494

495+
@torch.inference_mode()
421496
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
422497
segmentation_previous_stage: np.ndarray = None,
423498
output_file_truncated: str = None,
@@ -435,35 +510,29 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
435510
you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
436511
image_properties must only have a 'spacing' key!
437512
"""
438-
ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
439-
[output_file_truncated],
440-
self.plans_manager, self.dataset_json, self.configuration_manager,
441-
num_threads_in_multithreaded=1, verbose=self.verbose)
442-
if self.verbose:
443-
print('preprocessing')
444-
dct = next(ppa)
513+
data, properties = self._load_data_for_prediction(input_image, segmentation_previous_stage, image_properties,
514+
self.configuration_manager.preprocessor_class(
515+
verbose=self.verbose_preprocessing),
516+
self.plans_manager,
517+
self.configuration_manager,
518+
self.dataset_json,
519+
self.plans_manager.get_label_manager(self.dataset_json))
445520

446-
if self.verbose:
447-
print('predicting')
448-
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
521+
predicted_logits = self.predict_logits_from_preprocessed_data(data)
449522

450523
if self.verbose:
451524
print('resampling to original shape')
452525
if output_file_truncated is not None:
453-
export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,
526+
export_prediction_from_logits(predicted_logits, properties, self.configuration_manager,
454527
self.plans_manager, self.dataset_json, output_file_truncated,
455528
save_or_return_probabilities)
456529
else:
457-
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
458-
self.configuration_manager,
459-
self.label_manager,
460-
dct['data_properties'],
461-
return_probabilities=
462-
save_or_return_probabilities)
463-
if save_or_return_probabilities:
464-
return ret[0], ret[1]
465-
else:
466-
return ret
530+
return convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
531+
self.configuration_manager,
532+
self.label_manager,
533+
properties,
534+
return_probabilities=
535+
save_or_return_probabilities)
467536

468537
def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
469538
"""

0 commit comments

Comments
 (0)