Skip to content

Commit 1de67c7

Browse files
Update grid-search WIP
1 parent 2b45590 commit 1de67c7

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

torch_em/util/grid_search.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import numpy as np
2+
import torch.nn as nn
3+
import xarray
4+
5+
import bioimageio.core
26

37
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
48
from micro_sam.evaluation.instance_segmentation import (
@@ -95,20 +99,46 @@ def __init__(self, model, preprocess=None, block_shape=None, halo=None):
9599

96100
self._is_initialized = False
97101

98-
def initialize(self, data):
102+
def _initialize_torch(self, data):
99103
device = next(iter(self._model.parameters())).device
100104

101105
if self._block_shape is None:
102-
scale_factors = self._model.init_kwargs["scale_factors"]
103-
min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
106+
if hasattr(self._model, "scale_factors"):
107+
scale_factors = self._model.init_kwargs["scale_factors"]
108+
min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
109+
elif hasattr(self._model, "depth"):
110+
depth = self._model.depth
111+
min_divisible = [2**depth, 2**depth]
112+
else:
113+
raise RuntimeError
104114
input_ = self._preprocess(data)
105115
output = predict_with_padding(self._model, input_, min_divisible, device)
106116
else:
107117
output = predict_with_halo(
108118
data, self._model, [device], self._block_shape, self._halo,
109119
preprocess=self._preprocess,
110120
)
121+
return output
122+
123+
def _initialize_modelzoo(self, data):
124+
if self._block_shape is None:
125+
with bioimageio.core.create_prediction_pipeline(self._model) as pp:
126+
dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
127+
input_ = xarray.DataArray(data[None, None], dims=dims)
128+
output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
129+
output = output.squeeze().values
130+
else:
131+
raise NotImplementedError
132+
return output
133+
134+
# TODO refactor all this so that we can have a common base class that takes care of it
135+
def initialize(self, data):
136+
if isinstance(self._model, nn.Module):
137+
output = self._initialize_torch(data)
138+
else:
139+
output = self._initialize_modelzoo(data)
111140

141+
assert output.shape[0] == 3
112142
self._foreground = output[0]
113143
self._center_distances = output[1]
114144
self._boundary_distances = output[2]

0 commit comments

Comments
 (0)