@@ -1,4 +1,8 @@
 import numpy as np
+import torch.nn as nn
+import xarray
+
+import bioimageio.core
 
 from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
 from micro_sam.evaluation.instance_segmentation import (
@@ -34,7 +38,7 @@ def default_grid_search_values_boundary_based_instance_segmentation(
     }
 
 
-class BoundaryBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
     def __init__(self, model, preprocess=None, block_shape=None, halo=None):
         self._model = model
         self._preprocess = standardize if preprocess is None else preprocess
@@ -43,24 +47,57 @@ def __init__(self, model, preprocess=None, block_shape=None, halo=None):
         self._block_shape = block_shape
         self._halo = halo
 
-        self._foreground = None
-        self._boundaries = None
-
         self._is_initialized = False
 
-    def initialize(self, data):
+    def _initialize_torch(self, data):
         device = next(iter(self._model.parameters())).device
 
         if self._block_shape is None:
-            scale_factors = self._model.init_kwargs["scale_factors"]
-            min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            if hasattr(self._model, "scale_factors"):
+                scale_factors = self._model.init_kwargs["scale_factors"]
+                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            elif hasattr(self._model, "depth"):
+                depth = self._model.depth
+                min_divisible = [2 ** depth, 2 ** depth]
+            else:
+                raise RuntimeError
             input_ = self._preprocess(data)
             output = predict_with_padding(self._model, input_, min_divisible, device)
         else:
             output = predict_with_halo(
                 data, self._model, [device], self._block_shape, self._halo,
                 preprocess=self._preprocess,
             )
+        return output
+
+    def _initialize_modelzoo(self, data):
+        if self._block_shape is None:
+            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
+                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
+                input_ = xarray.DataArray(data[None, None], dims=dims)
+                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
+                output = output.squeeze().values
+        else:
+            raise NotImplementedError
+        return output
+
+
+class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
+    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
+        super().__init__(
+            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
+        )
+
+        self._foreground = None
+        self._boundaries = None
+
+    def initialize(self, data):
+        if isinstance(self._model, nn.Module):
+            output = self._initialize_torch(data)
+        else:
+            output = self._initialize_modelzoo(data)
+
+        assert output.shape[0] == 2
 
         self._foreground = output[0]
         self._boundaries = output[1]
@@ -77,38 +114,26 @@ def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="bin
         return segmentation
 
 
-class DistanceBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
     """Over-write micro_sam functionality so that it works for distance based
     segmentation with a U-net.
     """
     def __init__(self, model, preprocess=None, block_shape=None, halo=None):
-        self._model = model
-        self._preprocess = standardize if preprocess is None else preprocess
-
-        assert (block_shape is None) == (halo is None)
-        self._block_shape = block_shape
-        self._halo = halo
+        super().__init__(
+            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
+        )
 
         self._foreground = None
         self._center_distances = None
         self._boundary_distances = None
 
-        self._is_initialized = False
-
     def initialize(self, data):
-        device = next(iter(self._model.parameters())).device
-
-        if self._block_shape is None:
-            scale_factors = self._model.init_kwargs["scale_factors"]
-            min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
-            input_ = self._preprocess(data)
-            output = predict_with_padding(self._model, input_, min_divisible, device)
+        if isinstance(self._model, nn.Module):
+            output = self._initialize_torch(data)
         else:
-            output = predict_with_halo(
-                data, self._model, [device], self._block_shape, self._halo,
-                preprocess=self._preprocess,
-            )
+            output = self._initialize_modelzoo(data)
 
+        assert output.shape[0] == 3
         self._foreground = output[0]
         self._center_distances = output[1]
         self._boundary_distances = output[2]
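For context, here is a minimal usage sketch of the refactored classes; it is not part of the diff. The checkpoint path, the image path, and the torch_em load_model helper are illustrative assumptions. The only behavior taken from the PR is that initialize dispatches to _initialize_torch for torch nn.Module models and to _initialize_modelzoo otherwise.

import imageio.v3 as imageio
from torch_em.util import load_model  # assumption: the U-Net comes from a torch_em checkpoint

# Hypothetical paths, for illustration only.
model = load_model("checkpoints/distance-unet")   # a torch.nn.Module, so _initialize_torch is used
image = imageio.imread("data/example-image.tif")

segmenter = DistanceBasedInstanceSegmentation(model)
segmenter.initialize(image)       # predicts foreground, center distances and boundary distances
instances = segmenter.generate()  # run the instance segmentation post-processing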