Commit c03b62f

Refactor grid search code
1 parent 1de67c7 commit c03b62f

File tree

1 file changed: +44 -49 lines

torch_em/util/grid_search.py (+44 -49)
@@ -38,7 +38,7 @@ def default_grid_search_values_boundary_based_instance_segmentation(
     }
 
 
-class BoundaryBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
     def __init__(self, model, preprocess=None, block_shape=None, halo=None):
         self._model = model
         self._preprocess = standardize if preprocess is None else preprocess
@@ -47,24 +47,57 @@ def __init__(self, model, preprocess=None, block_shape=None, halo=None):
         self._block_shape = block_shape
         self._halo = halo
 
-        self._foreground = None
-        self._boundaries = None
-
         self._is_initialized = False
 
-    def initialize(self, data):
+    def _initialize_torch(self, data):
         device = next(iter(self._model.parameters())).device
 
         if self._block_shape is None:
-            scale_factors = self._model.init_kwargs["scale_factors"]
-            min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            if hasattr(self._model, "scale_factors"):
+                scale_factors = self._model.init_kwargs["scale_factors"]
+                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            elif hasattr(self._model, "depth"):
+                depth = self._model.depth
+                min_divisible = [2**depth, 2**depth]
+            else:
+                raise RuntimeError
             input_ = self._preprocess(data)
             output = predict_with_padding(self._model, input_, min_divisible, device)
         else:
             output = predict_with_halo(
                 data, self._model, [device], self._block_shape, self._halo,
                 preprocess=self._preprocess,
             )
+        return output
+
+    def _initialize_modelzoo(self, data):
+        if self._block_shape is None:
+            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
+                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
+                input_ = xarray.DataArray(data[None, None], dims=dims)
+                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
+                output = output.squeeze().values
+        else:
+            raise NotImplementedError
+        return output
+
+
+class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
+    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
+        super().__init__(
+            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
+        )
+
+        self._foreground = None
+        self._boundaries = None
+
+    def initialize(self, data):
+        if isinstance(self._model, nn.Module):
+            output = self._initialize_torch(data)
+        else:
+            output = self._initialize_modelzoo(data)
+
+        assert output.shape[0] == 2
 
         self._foreground = output[0]
         self._boundaries = output[1]
@@ -81,57 +114,19 @@ def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="bin
         return segmentation
 
 
-class DistanceBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
     """Over-write micro_sam functionality so that it works for distance based
     segmentation with a U-net.
     """
     def __init__(self, model, preprocess=None, block_shape=None, halo=None):
-        self._model = model
-        self._preprocess = standardize if preprocess is None else preprocess
-
-        assert (block_shape is None) == (halo is None)
-        self._block_shape = block_shape
-        self._halo = halo
+        super().__init__(
+            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
+        )
 
         self._foreground = None
         self._center_distances = None
        self._boundary_distances = None
 
-        self._is_initialized = False
-
-    def _initialize_torch(self, data):
-        device = next(iter(self._model.parameters())).device
-
-        if self._block_shape is None:
-            if hasattr(self._model, "scale_factors"):
-                scale_factors = self._model.init_kwargs["scale_factors"]
-                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
-            elif hasattr(self._model, "depth"):
-                depth = self._model.depth
-                min_divisible = [2**depth, 2**depth]
-            else:
-                raise RuntimeError
-            input_ = self._preprocess(data)
-            output = predict_with_padding(self._model, input_, min_divisible, device)
-        else:
-            output = predict_with_halo(
-                data, self._model, [device], self._block_shape, self._halo,
-                preprocess=self._preprocess,
-            )
-        return output
-
-    def _initialize_modelzoo(self, data):
-        if self._block_shape is None:
-            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
-                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
-                input_ = xarray.DataArray(data[None, None], dims=dims)
-                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
-                output = output.squeeze().values
-        else:
-            raise NotImplementedError
-        return output
-
-    # TODO refactor all this so that we can have a common base class that takes care of it
     def initialize(self, data):
         if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
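
For reference, a minimal usage sketch of the refactored boundary-based class, assuming a trained torch model (`model`, an nn.Module exposing a `depth` attribute, e.g. a 2D U-Net predicting foreground and boundary channels) and a raw 2D numpy image (`image`); both names are placeholders, and the `generate` defaults follow the signature shown in the hunk header above:

    from torch_em.util.grid_search import BoundaryBasedInstanceSegmentation

    # Placeholder inputs: `model` is a trained torch nn.Module, `image` a 2D numpy array.
    segmenter = BoundaryBasedInstanceSegmentation(model)  # preprocess defaults to standardize
    segmenter.initialize(image)  # nn.Module models dispatch to _initialize_torch
    segmentation = segmenter.generate(min_size=50, threshold1=0.5, threshold2=0.5)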
