Commit 7c26e5a

Enable Pytorch 2.6 (#8309)
Partially addresses #8303.

### Description

This raises the maximum NumPy version bound for testing to below 3.0, allowing NumPy 2.x, since the earlier compatibility issues appear to be resolved by newer versions of dependencies. It also includes fixes for PyTorch 2.6, mostly relating to `torch.load` and `autocast` usage.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1983f27 commit 7c26e5a

49 files changed: +184 additions, -178 deletions
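The sketch below is not part of the commit; it just summarizes the two migration patterns applied across the files that follow, assuming PyTorch 2.x. The network and file names are placeholders.

```python
import torch

net = torch.nn.Linear(4, 2)  # placeholder model, not a MONAI network

# Pattern 1: the deprecated torch.cuda.amp.autocast() context manager is
# replaced by the device-agnostic torch.autocast("cuda", ...).
if torch.cuda.is_available():
    net = net.cuda()
    with torch.autocast("cuda"):
        out = net(torch.randn(1, 4, device="cuda"))

# Pattern 2: PyTorch 2.6 flips the default of torch.load(weights_only=...)
# to True, so every call site now states its intent explicitly.
torch.save(net.state_dict(), "model.pt")
state = torch.load("model.pt", weights_only=True)  # plain tensors: safe loading
net.load_state_dict(state)
# Files holding arbitrary pickled objects (cached samples, full checkpoints)
# keep weights_only=False, as in the dataset and checkpoint handler changes below.
```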

monai/apps/deepedit/interaction.py (+1 -1)

@@ -72,7 +72,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d
             with torch.no_grad():
                 if engine.amp:
-                    with torch.cuda.amp.autocast():
+                    with torch.autocast("cuda"):
                         predictions = engine.inferer(inputs, engine.network)
                 else:
                     predictions = engine.inferer(inputs, engine.network)

monai/apps/deepgrow/interaction.py (+1 -1)

@@ -67,7 +67,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d
             engine.network.eval()
             with torch.no_grad():
                 if engine.amp:
-                    with torch.cuda.amp.autocast():
+                    with torch.autocast("cuda"):
                         predictions = engine.inferer(inputs, engine.network)
                 else:
                     predictions = engine.inferer(inputs, engine.network)

monai/apps/detection/networks/retinanet_detector.py (+1 -1)

@@ -180,7 +180,7 @@ def forward(self, images: torch.Tensor):
                 nesterov=True,
             )
             torch.save(detector.network.state_dict(), 'model.pt')  # save model
-            detector.network.load_state_dict(torch.load('model.pt'))  # load model
+            detector.network.load_state_dict(torch.load('model.pt', weights_only=True))  # load model
     """
 
     def __init__(

monai/apps/detection/networks/retinanet_network.py (+5 -5)

@@ -88,8 +88,8 @@ def __init__(
         for layer in self.conv.children():
             if isinstance(layer, conv_type):  # type: ignore
-                torch.nn.init.normal_(layer.weight, std=0.01)
-                torch.nn.init.constant_(layer.bias, 0)
+                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+                torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]
 
         self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
         torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
@@ -167,8 +167,8 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):
         for layer in self.conv.children():
             if isinstance(layer, conv_type):  # type: ignore
-                torch.nn.init.normal_(layer.weight, std=0.01)
-                torch.nn.init.zeros_(layer.bias)
+                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+                torch.nn.init.zeros_(layer.bias)  # type: ignore[arg-type]
 
     def forward(self, x: list[Tensor]) -> list[Tensor]:
         """
@@ -297,7 +297,7 @@ def __init__(
             )
         self.feature_extractor = feature_extractor
 
-        self.feature_map_channels: int = self.feature_extractor.out_channels
+        self.feature_map_channels: int = self.feature_extractor.out_channels  # type: ignore[assignment]
         self.num_anchors = num_anchors
         self.classification_head = RetinaNetClassificationHead(
             self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims

monai/apps/detection/utils/box_coder.py (+2 -2)

@@ -221,15 +221,15 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor:
             pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
             pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
-            pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
+            pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)  # type: ignore[union-attr]
 
             # When convert float32 to float16, Inf or Nan may occur
             if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
                 raise ValueError("pred_whd_axis is NaN or Inf.")
 
             # Distance from center to box's corner.
             c_to_c_whd_axis = (
-                torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
+                torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis  # type: ignore[arg-type]
             )
 
             pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)

monai/apps/mmars/mmars.py (+1 -1)

@@ -241,7 +241,7 @@ def load_from_mmar(
         return torch.jit.load(_model_file, map_location=map_location)
 
     # loading with `torch.load`
-    model_dict = torch.load(_model_file, map_location=map_location)
+    model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
     if weights_only:
         return model_dict.get(model_key, model_dict)  # model_dict[model_key] or model_dict directly

monai/apps/reconstruction/networks/blocks/varnetblock.py (+1 -1)

@@ -55,7 +55,7 @@ def soft_dc(self, x: Tensor, ref_kspace: Tensor, mask: Tensor) -> Tensor:
         Returns:
             Output of DC block with the same shape as x
         """
-        return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
+        return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight  # type: ignore
 
     def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
         """

monai/bundle/scripts.py (+3 -4)

@@ -760,7 +760,7 @@ def load(
     if load_ts_module is True:
         return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
     # loading with `torch.load`
-    model_dict = torch.load(full_path, map_location=torch.device(device))
+    model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)
 
     if not isinstance(model_dict, Mapping):
         warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
@@ -1279,9 +1279,8 @@ def verify_net_in_out(
     if input_dtype == torch.float16:
         # fp16 can only be executed in gpu mode
         net.to("cuda")
-        from torch.cuda.amp import autocast
 
-        with autocast():
+        with torch.autocast("cuda"):
             output = net(test_data.cuda(), **extra_forward_args_)
         net.to(device_)
     else:
@@ -1330,7 +1329,7 @@ def _export(
         # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
         Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
     else:
-        ckpt = torch.load(ckpt_file)
+        ckpt = torch.load(ckpt_file, weights_only=True)
         copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])
 
     # Use the given converter to convert a model and save with metadata, config content

monai/data/dataset.py (+2 -9)

@@ -22,7 +22,6 @@
 import warnings
 from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
-from inspect import signature
 from multiprocessing.managers import ListProxy
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
@@ -372,10 +371,7 @@ def _cachecheck(self, item_transformed):
         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                if "weights_only" in signature(torch.load).parameters:
-                    return torch.load(hashfile, weights_only=False)
-                else:
-                    return torch.load(hashfile)
+                return torch.load(hashfile, weights_only=False)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -1674,7 +1670,4 @@ def _load_meta_cache(self, meta_hash_file_name):
         if meta_hash_file_name in self._meta_cache:
             return self._meta_cache[meta_hash_file_name]
         else:
-            if "weights_only" in signature(torch.load).parameters:
-                return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
-            else:
-                return torch.load(self.cache_dir / meta_hash_file_name)
+            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
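Context for the `weights_only=False` choice above (an editorial note, not part of the diff): persistent-cache files store whole transformed samples, which can contain classes beyond plain tensors, so safe loading stays off and the old `inspect.signature` feature check becomes unnecessary on supported PyTorch versions. A minimal sketch with a made-up `CacheMeta` class standing in for such objects:

```python
import torch

class CacheMeta:  # hypothetical stand-in for custom classes held in cached samples
    def __init__(self, affine):
        self.affine = affine

item = {"image": torch.zeros(2, 2), "meta": CacheMeta(torch.eye(4))}
torch.save(item, "cache_item.pt")

# weights_only=True would reject the non-allowlisted CacheMeta class; the cache
# is trusted local data written by the same pipeline, so it is loaded unsafely.
restored = torch.load("cache_item.pt", weights_only=False)
```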

monai/data/utils.py (+1 -1)

@@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
     if isinstance(_affine, torch.Tensor):
         spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
     else:
-        spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
+        spacing = np.sqrt(np.sum(_affine * _affine, axis=0))  # type: ignore[operator]
     if suppress_zeros:
         spacing[spacing == 0] = 1.0
     spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)

monai/data/video_dataset.py (+1 -1)

@@ -177,7 +177,7 @@ def get_available_codecs() -> dict[str, str]:
            for codec, ext in all_codecs.items():
                writer = cv2.VideoWriter()
                fname = os.path.join(tmp_dir, f"test{ext}")
-               fourcc = cv2.VideoWriter_fourcc(*codec)
+               fourcc = cv2.VideoWriter_fourcc(*codec)  # type: ignore[attr-defined]
                noviderr = writer.open(fname, fourcc, 1, (10, 10))
                if noviderr:
                    codecs[codec] = ext

monai/engines/evaluator.py (+8 -8)

@@ -82,8 +82,8 @@ class Evaluator(Workflow):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
 
     """
@@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
         compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
             `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
         compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -324,7 +324,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
         # execute forward computation
         with engine.mode(engine.network):
             if engine.amp:
-                with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                with torch.autocast("cuda", **engine.amp_kwargs):
                     engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
             else:
                 engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
@@ -394,8 +394,8 @@ class EnsembleEvaluator(Evaluator):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
 
     """
@@ -487,7 +487,7 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tenso
         for idx, network in enumerate(engine.networks):
             with engine.mode(network):
                 if engine.amp:
-                    with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                    with torch.autocast("cuda", **engine.amp_kwargs):
                         if isinstance(engine.state.output, dict):
                             engine.state.output.update(
                                 {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}

monai/engines/trainer.py (+9 -9)

@@ -125,8 +125,8 @@ class SupervisedTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
         compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
             `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
         compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -249,7 +249,7 @@ def _compute_pred_loss():
         engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_pred_loss()
             engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -335,8 +335,8 @@ class GanTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
 
     """
@@ -512,8 +512,8 @@ class AdversarialTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
     """
 
     def __init__(
@@ -683,7 +683,7 @@ def _compute_generator_loss() -> None:
         engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.state.g_scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_generator_loss()
 
         engine.state.output[Keys.LOSS] = (
@@ -731,7 +731,7 @@ def _compute_discriminator_loss() -> None:
         engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.state.d_scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_discriminator_loss()
 
         engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
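As an editorial aside, the AMP training step these trainer hunks converge on looks roughly like the sketch below, assuming a CUDA device; the network, optimizer, and data are placeholders rather than MONAI engine internals.

```python
import torch

net = torch.nn.Linear(8, 1).cuda()
opt = torch.optim.SGD(net.parameters(), lr=1e-2)
scaler = torch.amp.GradScaler("cuda")
amp_kwargs = {"dtype": torch.float16}  # forwarded the way `engine.amp_kwargs` is

x, y = torch.randn(4, 8, device="cuda"), torch.randn(4, 1, device="cuda")
opt.zero_grad(set_to_none=True)
with torch.autocast("cuda", **amp_kwargs):  # previously torch.cuda.amp.autocast(**amp_kwargs)
    loss = torch.nn.functional.mse_loss(net(x), y)
scaler.scale(loss).backward()  # scaled backward pass, as in the engines
scaler.step(opt)
scaler.update()
```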

monai/engines/utils.py (+1 -1)

@@ -309,7 +309,7 @@ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_nam
         self.scheduler = scheduler
 
     def get_target(self, images, noise, timesteps):
-        return self.scheduler.get_velocity(images, noise, timesteps)
+        return self.scheduler.get_velocity(images, noise, timesteps)  # type: ignore[operator]
 
 
 def default_make_latent(

monai/engines/workflow.py (+2 -2)

@@ -90,8 +90,8 @@ class Workflow(Engine):
             default to `True`.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
 
     Raises:
         TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.

monai/fl/client/monai_algo.py (+1 -1)

@@ -574,7 +574,7 @@ def get_weights(self, extra=None):
                 model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
                 if not os.path.isfile(model_path):
                     raise ValueError(f"No best model checkpoint exists at {model_path}")
-                weights = torch.load(model_path, map_location="cpu")
+                weights = torch.load(model_path, map_location="cpu", weights_only=True)
                 # if weights contain several state dicts, use the one defined by `save_dict_key`
                 if isinstance(weights, dict) and self.save_dict_key in weights:
                     weights = weights.get(self.save_dict_key)

monai/handlers/checkpoint_loader.py (+1 -1)

@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        checkpoint = torch.load(self.load_path, map_location=self.map_location)
+        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
 
         k, _ = list(self.load_dict.items())[0]
         # single object and checkpoint is directly a state_dict
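A note beyond the diff: `CheckpointLoader` handles full training checkpoints, which may contain more than plain tensors, hence `weights_only=False` here instead of `True`. A small illustrative sketch (names are placeholders); only trusted files should be loaded this way.

```python
import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters())

# Hypothetical checkpoint with nested state dicts and extra bookkeeping values.
torch.save({"net": net.state_dict(), "opt": opt.state_dict(), "epoch": 3}, "ckpt.pt")

ckpt = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
net.load_state_dict(ckpt["net"])
opt.load_state_dict(ckpt["opt"])
```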
