Skip to content

Commit 80f7c3d

Browse files
committed
disable torch.compile for mps, give clearer error messages, fix #2244
1 parent 0dd755a commit 80f7c3d

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

+11
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,24 @@ def initialize(self):
238238
def _do_i_compile(self):
239239
# new default: compile is enabled!
240240

241+
# compile does not work on mps
242+
if self.device == torch.device('mps'):
243+
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
244+
self.print_to_log_file("INFO: torch.compile disabled because of unsupported mps device")
245+
return False
246+
241247
# CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable
242248
if self.device == torch.device('cpu'):
249+
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
250+
self.print_to_log_file("INFO: torch.compile disabled because device is CPU")
243251
return False
244252

245253
# default torch.compile doesn't work on windows because there are apparently no triton wheels for it
246254
# https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2
247255
if os.name == 'nt':
256+
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
257+
self.print_to_log_file("INFO: torch.compile disabled because Windows is not natively supported. If "
258+
"you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2")
248259
return False
249260

250261
if 'nnUNet_compile' not in os.environ.keys():

0 commit comments

Comments
 (0)