Skip to content

Commit f97a0e9

Browse files
KumoLiupre-commit-ci[bot]monai-bot
authored
Remove deprecated functionality for v1.5 (#8430)
remove deprecated functionality Part of #8421 ### Description Removed Functionality: - metrics: Removed `compute_percent_hausdorff_distance`. Use `compute_hausdorff_distance` with the `percentile` argument instead. - bundle: Removed `net_name`, `net_kwargs`, and `return_state_dict` arguments from `load()`. Use the `model` argument for network instantiation. - bundle: Removed `workflow` argument from `BundleWorkflow` and `ConfigWorkflow`. Use `workflow_type` instead. - networks: Removed `img_size` argument from `SwinUNETR`. Input size checks are now performed during `forward()`. Default Value Changes: - `GeneralizedDiceScore`: Changed default `reduction` from `MEAN_BATCH` to `MEAN`. - `CropForeground` / `CropForegroundd`: Changed default `allow_smaller` from `True` to `False`. - `get_mask_edges`: Changed default `always_return_as_numpy` from `True` to `False`. - `generate_spatial_bounding_box`: Changed default `allow_smaller` from `True` to `False`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: monai-bot <monai.miccai2019@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent 4305bb8 commit f97a0e9

File tree

14 files changed

+34
-166
lines changed

14 files changed

+34
-166
lines changed

docs/source/metrics.rst

-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ Metrics
9898
`Hausdorff distance`
9999
--------------------
100100
.. autofunction:: compute_hausdorff_distance
101-
.. autofunction:: compute_percent_hausdorff_distance
102101

103102
.. autoclass:: HausdorffDistanceMetric
104103
:members:

monai/apps/deepgrow/transforms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ def __call__(self, data):
441441

442442
if np.all(np.less(current_size, self.spatial_size)):
443443
cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
444-
box_start = np.array([s.start for s in cropper.slices])
445-
box_end = np.array([s.stop for s in cropper.slices])
444+
box_start = [s.start for s in cropper.slices]
445+
box_end = [s.stop for s in cropper.slices]
446446
else:
447447
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
448448

monai/bundle/scripts.py

+1-24
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
from monai._version import get_versions
3333
from monai.apps.utils import _basename, download_url, extractall, get_logger
34-
from monai.bundle.config_item import ConfigComponent
3534
from monai.bundle.config_parser import ConfigParser
3635
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
3736
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
@@ -48,7 +47,6 @@
4847
from monai.utils import (
4948
IgniteInfo,
5049
check_parent_dir,
51-
deprecated_arg,
5250
ensure_tuple,
5351
get_equivalent_dtype,
5452
min_version,
@@ -629,9 +627,6 @@ def download(
629627
_check_monai_version(bundle_dir_, name_)
630628

631629

632-
@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
633-
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
634-
@deprecated_arg("return_state_dict", since="1.2", removed="1.5")
635630
def load(
636631
name: str,
637632
model: torch.nn.Module | None = None,
@@ -650,10 +645,7 @@ def load(
650645
workflow_name: str | BundleWorkflow | None = None,
651646
args_file: str | None = None,
652647
copy_model_args: dict | None = None,
653-
return_state_dict: bool = True,
654648
net_override: dict | None = None,
655-
net_name: str | None = None,
656-
**net_kwargs: Any,
657649
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
658650
"""
659651
Load model weights or TorchScript module of a bundle.
@@ -699,12 +691,7 @@ def load(
699691
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
700692
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
701693
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
702-
return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network
703-
from `_workflow.network_def` will be instantiated and load the achieved weights.
704694
net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
705-
net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
706-
This argument only works when loading weights.
707-
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.
708695
709696
Returns:
710697
1. If `load_ts_module` is `False` and `model` is `None`,
@@ -719,9 +706,6 @@ def load(
719706
when `model` and `net_name` are all `None`.
720707
721708
"""
722-
if return_state_dict and (model is not None or net_name is not None):
723-
warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.")
724-
725709
bundle_dir_ = _process_bundle_dir(bundle_dir)
726710
net_override = {} if net_override is None else net_override
727711
copy_model_args = {} if copy_model_args is None else copy_model_args
@@ -757,11 +741,8 @@ def load(
757741
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
758742
model_dict = get_state_dict(model_dict)
759743

760-
if return_state_dict:
761-
return model_dict
762-
763744
_workflow = None
764-
if model is None and net_name is None:
745+
if model is None:
765746
bundle_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
766747
if bundle_config_file.is_file():
767748
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
@@ -781,10 +762,6 @@ def load(
781762
return model_dict
782763
else:
783764
model = _workflow.network_def
784-
elif net_name is not None:
785-
net_kwargs["_target_"] = net_name
786-
configer = ConfigComponent(config=net_kwargs)
787-
model = configer.instantiate() # type: ignore
788765

789766
model.to(device) # type: ignore
790767

monai/bundle/workflows.py

+1-27
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties
2828
from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
2929
from monai.config import PathLike
30-
from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, ensure_tuple
30+
from monai.utils import BundleProperty, BundlePropertyConfig, ensure_tuple
3131

3232
__all__ = ["BundleWorkflow", "ConfigWorkflow"]
3333

@@ -45,10 +45,6 @@ class BundleWorkflow(ABC):
4545
or "infer", "inference", "eval", "evaluation" for a inference workflow,
4646
other unsupported string will raise a ValueError.
4747
default to `None` for only using meta properties.
48-
workflow: specifies the workflow type: "train" or "training" for a training workflow,
49-
or "infer", "inference", "eval", "evaluation" for a inference workflow,
50-
other unsupported string will raise a ValueError.
51-
default to `None` for common workflow.
5248
properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
5349
loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
5450
properties will default to loading from "meta". If `properties_path` is None, default properties
@@ -65,17 +61,9 @@ class BundleWorkflow(ABC):
6561
supported_train_type: tuple = ("train", "training")
6662
supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")
6763

68-
@deprecated_arg(
69-
"workflow",
70-
since="1.2",
71-
removed="1.5",
72-
new_name="workflow_type",
73-
msg_suffix="please use `workflow_type` instead.",
74-
)
7564
def __init__(
7665
self,
7766
workflow_type: str | None = None,
78-
workflow: str | None = None,
7967
properties_path: PathLike | None = None,
8068
meta_file: str | Sequence[str] | None = None,
8169
logging_file: str | None = None,
@@ -102,7 +90,6 @@ def __init__(
10290
)
10391
meta_file = None
10492

105-
workflow_type = workflow if workflow is not None else workflow_type
10693
if workflow_type is not None:
10794
if workflow_type.lower() in self.supported_train_type:
10895
workflow_type = "train"
@@ -403,10 +390,6 @@ class ConfigWorkflow(BundleWorkflow):
403390
or "infer", "inference", "eval", "evaluation" for a inference workflow,
404391
other unsupported string will raise a ValueError.
405392
default to `None` for common workflow.
406-
workflow: specifies the workflow type: "train" or "training" for a training workflow,
407-
or "infer", "inference", "eval", "evaluation" for a inference workflow,
408-
other unsupported string will raise a ValueError.
409-
default to `None` for common workflow.
410393
properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
411394
loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
412395
properties will default to loading from "train". If `properties_path` is None, default properties
@@ -419,13 +402,6 @@ class ConfigWorkflow(BundleWorkflow):
419402
420403
"""
421404

422-
@deprecated_arg(
423-
"workflow",
424-
since="1.2",
425-
removed="1.5",
426-
new_name="workflow_type",
427-
msg_suffix="please use `workflow_type` instead.",
428-
)
429405
def __init__(
430406
self,
431407
config_file: str | Sequence[str],
@@ -436,11 +412,9 @@ def __init__(
436412
final_id: str = "finalize",
437413
tracking: str | dict | None = None,
438414
workflow_type: str | None = "train",
439-
workflow: str | None = None,
440415
properties_path: PathLike | None = None,
441416
**override: Any,
442417
) -> None:
443-
workflow_type = workflow if workflow is not None else workflow_type
444418
if config_file is not None:
445419
_config_files = ensure_tuple(config_file)
446420
config_root_path = Path(_config_files[0]).parent

monai/metrics/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .fid import FIDMetric, compute_frechet_distance
2020
from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score
2121
from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
22-
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
22+
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance
2323
from .loss_metric import LossMetric
2424
from .meandice import DiceHelper, DiceMetric, compute_dice
2525
from .meaniou import MeanIoU, compute_iou

monai/metrics/generalized_dice.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515

1616
from monai.metrics.utils import do_metric_reduction, ignore_background
17-
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option
17+
from monai.utils import MetricReduction, Weight, deprecated_arg, look_up_option
1818

1919
from .metric import CumulativeIterationMetric
2020

@@ -37,28 +37,19 @@ class GeneralizedDiceScore(CumulativeIterationMetric):
3737
reduction: Define mode of reduction to the metrics. Available reduction modes:
3838
{``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
3939
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
40+
Default value is changed from `MetricReduction.MEAN_BATCH` to `MetricReduction.MEAN` in v1.5.0.
41+
Old versions computed `mean` when `mean_batch` was provided due to bug in reduction.
4042
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
4143
ground truth volume into a weight factor. Defaults to ``"square"``.
4244
4345
Raises:
4446
ValueError: When the `reduction` is not one of MetricReduction enum.
4547
"""
4648

47-
@deprecated_arg_default(
48-
"reduction",
49-
old_default=MetricReduction.MEAN_BATCH,
50-
new_default=MetricReduction.MEAN,
51-
since="1.4.0",
52-
replaced="1.5.0",
53-
msg_suffix=(
54-
"Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
55-
"If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
56-
),
57-
)
5849
def __init__(
5950
self,
6051
include_background: bool = True,
61-
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
52+
reduction: MetricReduction | str = MetricReduction.MEAN,
6253
weight_type: Weight | str = Weight.SQUARE,
6354
) -> None:
6455
super().__init__()

monai/metrics/hausdorff_distance.py

+3-37
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,12 @@
1717
import numpy as np
1818
import torch
1919

20-
from monai.metrics.utils import (
21-
do_metric_reduction,
22-
get_edge_surface_distance,
23-
get_surface_distance,
24-
ignore_background,
25-
prepare_spacing,
26-
)
27-
from monai.utils import MetricReduction, convert_data_type, deprecated
20+
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
21+
from monai.utils import MetricReduction, convert_data_type
2822

2923
from .metric import CumulativeIterationMetric
3024

31-
__all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance", "compute_percent_hausdorff_distance"]
25+
__all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance"]
3226

3327

3428
class HausdorffDistanceMetric(CumulativeIterationMetric):
@@ -216,31 +210,3 @@ def _compute_percentile_hausdorff_distance(
216210
if 0 <= percentile <= 100:
217211
return torch.quantile(surface_distance, percentile / 100)
218212
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
219-
220-
221-
@deprecated(since="1.3.0", removed="1.5.0")
222-
def compute_percent_hausdorff_distance(
223-
edges_pred: np.ndarray,
224-
edges_gt: np.ndarray,
225-
distance_metric: str = "euclidean",
226-
percentile: float | None = None,
227-
spacing: int | float | np.ndarray | Sequence[int | float] | None = None,
228-
) -> float:
229-
"""
230-
This function is used to compute the directed Hausdorff distance.
231-
"""
232-
233-
surface_distance: np.ndarray = get_surface_distance( # type: ignore
234-
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing
235-
)
236-
237-
# for both pred and gt do not have foreground
238-
if surface_distance.shape == (0,):
239-
return np.nan
240-
241-
if not percentile:
242-
return surface_distance.max() # type: ignore[no-any-return]
243-
244-
if 0 <= percentile <= 100:
245-
return np.percentile(surface_distance, percentile) # type: ignore[no-any-return]
246-
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")

monai/metrics/utils.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
convert_to_numpy,
3131
convert_to_tensor,
3232
deprecated_arg,
33-
deprecated_arg_default,
3433
ensure_tuple_rep,
3534
look_up_option,
3635
optional_import,
@@ -131,9 +130,6 @@ def do_metric_reduction(
131130
return f, not_nans
132131

133132

134-
@deprecated_arg_default(
135-
name="always_return_as_numpy", since="1.3.0", replaced="1.5.0", old_default=True, new_default=False
136-
)
137133
@deprecated_arg(
138134
name="always_return_as_numpy",
139135
since="1.5.0",
@@ -146,7 +142,7 @@ def get_mask_edges(
146142
label_idx: int = 1,
147143
crop: bool = True,
148144
spacing: Sequence | None = None,
149-
always_return_as_numpy: bool = True,
145+
always_return_as_numpy: bool = False,
150146
) -> tuple[NdarrayTensor, NdarrayTensor]:
151147
"""
152148
Compute edges from binary segmentation masks. This
@@ -175,6 +171,7 @@ def get_mask_edges(
175171
otherwise `scipy`'s binary erosion is used to calculate the edges.
176172
always_return_as_numpy: whether to a numpy array regardless of the input type.
177173
If False, return the same type as inputs.
174+
The default value is changed from `True` to `False` in v1.5.0.
178175
"""
179176
# move in the funciton to avoid using all the GPUs
180177
cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion")

monai/networks/nets/swin_unetr.py

+4-20
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
2626
from monai.networks.layers import DropPath, trunc_normal_
2727
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
28-
from monai.utils.deprecate_utils import deprecated_arg
2928

3029
rearrange, _ = optional_import("einops", name="rearrange")
3130

@@ -50,16 +49,8 @@ class SwinUNETR(nn.Module):
5049
<https://arxiv.org/abs/2201.01266>"
5150
"""
5251

53-
@deprecated_arg(
54-
name="img_size",
55-
since="1.3",
56-
removed="1.5",
57-
msg_suffix="The img_size argument is not required anymore and "
58-
"checks on the input size are run during forward().",
59-
)
6052
def __init__(
6153
self,
62-
img_size: Sequence[int] | int,
6354
in_channels: int,
6455
out_channels: int,
6556
patch_size: int = 2,
@@ -83,10 +74,6 @@ def __init__(
8374
) -> None:
8475
"""
8576
Args:
86-
img_size: spatial dimension of input image.
87-
This argument is only used for checking that the input image size is divisible by the patch size.
88-
The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are divisible by 2**5.
89-
It will be removed in an upcoming version.
9077
in_channels: dimension of input channels.
9178
out_channels: dimension of output channels.
9279
patch_size: size of the patch token.
@@ -113,13 +100,13 @@ def __init__(
113100
Examples::
114101
115102
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
116-
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
103+
>>> net = SwinUNETR(in_channels=1, out_channels=4, feature_size=48)
117104
118105
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
119-
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
106+
>>> net = SwinUNETR(in_channels=4, out_channels=3, depths=(2,4,2,2))
120107
121108
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
122-
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
109+
>>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
123110
124111
"""
125112

@@ -130,12 +117,9 @@ def __init__(
130117

131118
self.patch_size = patch_size
132119

133-
img_size = ensure_tuple_rep(img_size, spatial_dims)
134120
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
135121
window_size = ensure_tuple_rep(window_size, spatial_dims)
136122

137-
self._check_input_size(img_size)
138-
139123
if not (0 <= drop_rate <= 1):
140124
raise ValueError("dropout rate should be between 0 and 1.")
141125

@@ -1109,7 +1093,7 @@ def filter_swinunetr(key, value):
11091093
from monai.networks.utils import copy_model_state
11101094
from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr
11111095
1112-
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
1096+
model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)
11131097
resource = (
11141098
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
11151099
)

0 commit comments

Comments
 (0)