Commit 80f7c3d 1 parent 0dd755a commit 80f7c3d Copy full SHA for 80f7c3d
File tree 1 file changed +11
-0
lines changed
nnunetv2/training/nnUNetTrainer
1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -238,13 +238,24 @@ def initialize(self):
238
238
def _do_i_compile (self ):
239
239
# new default: compile is enabled!
240
240
241
+ # compile does not work on mps
242
+ if self .device == torch .device ('mps' ):
243
+ if 'nnUNet_compile' in os .environ .keys () and os .environ ['nnUNet_compile' ].lower () in ('true' , '1' , 't' ):
244
+ self .print_to_log_file ("INFO: torch.compile disabled because of unsupported mps device" )
245
+ return False
246
+
241
247
# CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable
242
248
if self .device == torch .device ('cpu' ):
249
+ if 'nnUNet_compile' in os .environ .keys () and os .environ ['nnUNet_compile' ].lower () in ('true' , '1' , 't' ):
250
+ self .print_to_log_file ("INFO: torch.compile disabled because device is CPU" )
243
251
return False
244
252
245
253
# default torch.compile doesn't work on windows because there are apparently no triton wheels for it
246
254
# https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2
247
255
if os .name == 'nt' :
256
+ if 'nnUNet_compile' in os .environ .keys () and os .environ ['nnUNet_compile' ].lower () in ('true' , '1' , 't' ):
257
+ self .print_to_log_file ("INFO: torch.compile disabled because Windows is not natively supported. If "
258
+ "you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2" )
248
259
return False
249
260
250
261
if 'nnUNet_compile' not in os .environ .keys ():
You can’t perform that action at this time.
0 commit comments