Skip to content

Commit f5c24a5

Browse files
Update grid-search (#209)
* Update grid-search WIP * Refactor grid search code
1 parent 2b45590 commit f5c24a5

File tree

1 file changed

+52
-27
lines changed

1 file changed

+52
-27
lines changed

torch_em/util/grid_search.py

+52-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import numpy as np
2+
import torch.nn as nn
3+
import xarray
4+
5+
import bioimageio.core
26

37
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
48
from micro_sam.evaluation.instance_segmentation import (
@@ -34,7 +38,7 @@ def default_grid_search_values_boundary_based_instance_segmentation(
3438
}
3539

3640

37-
class BoundaryBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
41+
class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
3842
def __init__(self, model, preprocess=None, block_shape=None, halo=None):
3943
self._model = model
4044
self._preprocess = standardize if preprocess is None else preprocess
@@ -43,24 +47,57 @@ def __init__(self, model, preprocess=None, block_shape=None, halo=None):
4347
self._block_shape = block_shape
4448
self._halo = halo
4549

46-
self._foreground = None
47-
self._boundaries = None
48-
4950
self._is_initialized = False
5051

51-
def initialize(self, data):
52+
def _initialize_torch(self, data):
5253
device = next(iter(self._model.parameters())).device
5354

5455
if self._block_shape is None:
55-
scale_factors = self._model.init_kwargs["scale_factors"]
56-
min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
56+
if hasattr(self._model, "scale_factors"):
57+
scale_factors = self._model.init_kwargs["scale_factors"]
58+
min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
59+
elif hasattr(self._model, "depth"):
60+
depth = self._model.depth
61+
min_divisible = [2**depth, 2**depth]
62+
else:
63+
raise RuntimeError
5764
input_ = self._preprocess(data)
5865
output = predict_with_padding(self._model, input_, min_divisible, device)
5966
else:
6067
output = predict_with_halo(
6168
data, self._model, [device], self._block_shape, self._halo,
6269
preprocess=self._preprocess,
6370
)
71+
return output
72+
73+
def _initialize_modelzoo(self, data):
74+
if self._block_shape is None:
75+
with bioimageio.core.create_prediction_pipeline(self._model) as pp:
76+
dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
77+
input_ = xarray.DataArray(data[None, None], dims=dims)
78+
output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
79+
output = output.squeeze().values
80+
else:
81+
raise NotImplementedError
82+
return output
83+
84+
85+
class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
86+
def __init__(self, model, preprocess=None, block_shape=None, halo=None):
87+
super().__init__(
88+
model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
89+
)
90+
91+
self._foreground = None
92+
self._boundaries = None
93+
94+
def initialize(self, data):
95+
if isinstance(self._model, nn.Module):
96+
output = self._initialize_torch(data)
97+
else:
98+
output = self._initialize_modelzoo(data)
99+
100+
assert output.shape[0] == 2
64101

65102
self._foreground = output[0]
66103
self._boundaries = output[1]
@@ -77,38 +114,26 @@ def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="bin
77114
return segmentation
78115

79116

80-
class DistanceBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
117+
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
81118
"""Over-write micro_sam functionality so that it works for distance based
82119
segmentation with a U-net.
83120
"""
84121
def __init__(self, model, preprocess=None, block_shape=None, halo=None):
85-
self._model = model
86-
self._preprocess = standardize if preprocess is None else preprocess
87-
88-
assert (block_shape is None) == (halo is None)
89-
self._block_shape = block_shape
90-
self._halo = halo
122+
super().__init__(
123+
model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
124+
)
91125

92126
self._foreground = None
93127
self._center_distances = None
94128
self._boundary_distances = None
95129

96-
self._is_initialized = False
97-
98130
def initialize(self, data):
99-
device = next(iter(self._model.parameters())).device
100-
101-
if self._block_shape is None:
102-
scale_factors = self._model.init_kwargs["scale_factors"]
103-
min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
104-
input_ = self._preprocess(data)
105-
output = predict_with_padding(self._model, input_, min_divisible, device)
131+
if isinstance(self._model, nn.Module):
132+
output = self._initialize_torch(data)
106133
else:
107-
output = predict_with_halo(
108-
data, self._model, [device], self._block_shape, self._halo,
109-
preprocess=self._preprocess,
110-
)
134+
output = self._initialize_modelzoo(data)
111135

136+
assert output.shape[0] == 3
112137
self._foreground = output[0]
113138
self._center_distances = output[1]
114139
self._boundary_distances = output[2]

0 commit comments

Comments
 (0)