|
1 | 1 | import numpy as np
|
| 2 | +import torch.nn as nn |
| 3 | +import xarray |
| 4 | + |
| 5 | +import bioimageio.core |
2 | 6 |
|
3 | 7 | from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
|
4 | 8 | from micro_sam.evaluation.instance_segmentation import (
|
@@ -95,20 +99,46 @@ def __init__(self, model, preprocess=None, block_shape=None, halo=None):
|
95 | 99 |
|
96 | 100 | self._is_initialized = False
|
97 | 101 |
|
98 |
| - def initialize(self, data): |
| 102 | + def _initialize_torch(self, data): |
99 | 103 | device = next(iter(self._model.parameters())).device
|
100 | 104 |
|
101 | 105 | if self._block_shape is None:
|
102 |
| - scale_factors = self._model.init_kwargs["scale_factors"] |
103 |
| - min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)] |
| 106 | + if hasattr(self._model, "scale_factors"): |
| 107 | + scale_factors = self._model.init_kwargs["scale_factors"] |
| 108 | + min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)] |
| 109 | + elif hasattr(self._model, "depth"): |
| 110 | + depth = self._model.depth |
| 111 | + min_divisible = [2**depth, 2**depth] |
| 112 | + else: |
| 113 | + raise RuntimeError |
104 | 114 | input_ = self._preprocess(data)
|
105 | 115 | output = predict_with_padding(self._model, input_, min_divisible, device)
|
106 | 116 | else:
|
107 | 117 | output = predict_with_halo(
|
108 | 118 | data, self._model, [device], self._block_shape, self._halo,
|
109 | 119 | preprocess=self._preprocess,
|
110 | 120 | )
|
| 121 | + return output |
| 122 | + |
| 123 | + def _initialize_modelzoo(self, data): |
| 124 | + if self._block_shape is None: |
| 125 | + with bioimageio.core.create_prediction_pipeline(self._model) as pp: |
| 126 | + dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx") |
| 127 | + input_ = xarray.DataArray(data[None, None], dims=dims) |
| 128 | + output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0] |
| 129 | + output = output.squeeze().values |
| 130 | + else: |
| 131 | + raise NotImplementedError |
| 132 | + return output |
| 133 | + |
| 134 | + # TODO refactor all this so that we can have a common base class that takes care of it |
| 135 | + def initialize(self, data): |
| 136 | + if isinstance(self._model, nn.Module): |
| 137 | + output = self._initialize_torch(data) |
| 138 | + else: |
| 139 | + output = self._initialize_modelzoo(data) |
111 | 140 |
|
| 141 | + assert output.shape[0] == 3 |
112 | 142 | self._foreground = output[0]
|
113 | 143 | self._center_distances = output[1]
|
114 | 144 | self._boundary_distances = output[2]
|
|
0 commit comments