@@ -38,7 +38,7 @@ def default_grid_search_values_boundary_based_instance_segmentation(
     }


-class BoundaryBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
     def __init__(self, model, preprocess=None, block_shape=None, halo=None):
         self._model = model
         self._preprocess = standardize if preprocess is None else preprocess
@@ -47,24 +47,57 @@ def __init__(self, model, preprocess=None, block_shape=None, halo=None):
         self._block_shape = block_shape
         self._halo = halo

-        self._foreground = None
-        self._boundaries = None
-
         self._is_initialized = False

-    def initialize(self, data):
+    def _initialize_torch(self, data):
         device = next(iter(self._model.parameters())).device

         if self._block_shape is None:
-            scale_factors = self._model.init_kwargs["scale_factors"]
-            min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            if hasattr(self._model, "scale_factors"):
+                scale_factors = self._model.init_kwargs["scale_factors"]
+                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            elif hasattr(self._model, "depth"):
+                depth = self._model.depth
+                min_divisible = [2 ** depth, 2 ** depth]
+            else:
+                raise RuntimeError
             input_ = self._preprocess(data)
             output = predict_with_padding(self._model, input_, min_divisible, device)
         else:
             output = predict_with_halo(
                 data, self._model, [device], self._block_shape, self._halo,
                 preprocess=self._preprocess,
             )
+        return output
+
+    def _initialize_modelzoo(self, data):
+        if self._block_shape is None:
+            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
+                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
+                input_ = xarray.DataArray(data[None, None], dims=dims)
+                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
+                output = output.squeeze().values
+        else:
+            raise NotImplementedError
+        return output
+
+
+class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
+    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
+        super().__init__(
+            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
+        )
+
+        self._foreground = None
+        self._boundaries = None
+
+    def initialize(self, data):
+        if isinstance(self._model, nn.Module):
+            output = self._initialize_torch(data)
+        else:
+            output = self._initialize_modelzoo(data)
+
+        assert output.shape[0] == 2

         self._foreground = output[0]
         self._boundaries = output[1]
@@ -81,57 +114,19 @@ def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="bin
         return segmentation


-class DistanceBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
     """Over-write micro_sam functionality so that it works for distance based
     segmentation with a U-net.
     """
     def __init__(self, model, preprocess=None, block_shape=None, halo=None):
-        self._model = model
-        self._preprocess = standardize if preprocess is None else preprocess
-
-        assert (block_shape is None) == (halo is None)
-        self._block_shape = block_shape
-        self._halo = halo
+        super().__init__(
+            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
+        )

         self._foreground = None
         self._center_distances = None
         self._boundary_distances = None

-        self._is_initialized = False
-
-    def _initialize_torch(self, data):
-        device = next(iter(self._model.parameters())).device
-
-        if self._block_shape is None:
-            if hasattr(self._model, "scale_factors"):
-                scale_factors = self._model.init_kwargs["scale_factors"]
-                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
-            elif hasattr(self._model, "depth"):
-                depth = self._model.depth
-                min_divisible = [2 ** depth, 2 ** depth]
-            else:
-                raise RuntimeError
-            input_ = self._preprocess(data)
-            output = predict_with_padding(self._model, input_, min_divisible, device)
-        else:
-            output = predict_with_halo(
-                data, self._model, [device], self._block_shape, self._halo,
-                preprocess=self._preprocess,
-            )
-        return output
-
-    def _initialize_modelzoo(self, data):
-        if self._block_shape is None:
-            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
-                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
-                input_ = xarray.DataArray(data[None, None], dims=dims)
-                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
-                output = output.squeeze().values
-        else:
-            raise NotImplementedError
-        return output
-
-    # TODO refactor all this so that we can have a common base class that takes care of it
     def initialize(self, data):
         if isinstance(self._model, nn.Module):
             output = self._initialize_torch(data)
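For orientation (not part of the commit), a minimal usage sketch of the refactored torch branch. DummyNet is a hypothetical stand-in that only provides what _initialize_torch needs from a model without scale_factors: an nn.Module with a depth attribute and a two-channel (foreground, boundaries) output. A real workflow would pass a trained U-Net instead, and the generate arguments used here are the defaults visible in the hunk header above.

# Hypothetical sketch, not part of the diff: exercise the torch branch end to end.
# BoundaryBasedInstanceSegmentation is assumed to be in scope from the module changed above.
import numpy as np
import torch
import torch.nn as nn


class DummyNet(nn.Module):
    """Stand-in for a trained U-Net: maps one input channel to two output channels."""
    def __init__(self, depth=4):
        super().__init__()
        self.depth = depth  # _initialize_torch derives min_divisible = [2 ** depth, 2 ** depth] from this
        self.conv = nn.Conv2d(1, 2, kernel_size=1)  # channel 0: foreground, channel 1: boundaries

    def forward(self, x):
        return torch.sigmoid(self.conv(x))


image = np.random.rand(256, 256).astype("float32")  # 256 is divisible by 2 ** depth = 16

segmenter = BoundaryBasedInstanceSegmentation(DummyNet())
segmenter.initialize(image)  # nn.Module -> _initialize_torch -> predict_with_padding
# min_size / threshold1 / threshold2 are the defaults from the generate signature in the hunk header.
segmentation = segmenter.generate(min_size=50, threshold1=0.5, threshold2=0.5)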
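The same isinstance(self._model, nn.Module) dispatch also covers models exported to the bioimage.io model zoo: anything that is not a torch module is routed through _initialize_modelzoo, which wraps the bioimageio.core calls used in the diff (create_prediction_pipeline and prediction.predict_with_padding). A hedged sketch of that branch, with a placeholder model path and a loading call assumed from the same bioimageio.core API generation:

# Hypothetical modelzoo branch; the path is a placeholder and load_resource_description
# is assumed from the same bioimageio.core version as the prediction calls in the diff.
import bioimageio.core

model = bioimageio.core.load_resource_description("path/to/exported_model.zip")
segmenter = BoundaryBasedInstanceSegmentation(model)  # not an nn.Module -> _initialize_modelzoo
segmenter.initialize(image)  # note: block_shape/halo tiling raises NotImplementedError in this branch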