@@ -108,10 +108,11 @@ def __call__(self, inputs):
108
108
109
109
110
110
class ResizeLongestSideInputs :
111
- def __init__ (self , target_shape , is_label = False , is_rgb = False ):
111
+ def __init__ (self , target_shape , is_label = False , is_rgb = False , padding_mode = "constant" ):
112
112
self .target_shape = target_shape
113
113
self .is_label = is_label
114
114
self .is_rgb = is_rgb
115
+ self .padding_mode = padding_mode
115
116
116
117
h , w = self .target_shape [- 2 ], self .target_shape [- 1 ]
117
118
if h != w : # We currently support resize feature for square-shaped target shape only.
@@ -135,7 +136,7 @@ def _get_preprocess_shape(self, oldh, oldw):
135
136
newh = int (newh + 0.5 )
136
137
return (newh , neww )
137
138
138
- def convert_transformed_inputs_to_original_shape (self , resized_inputs ):
139
+ def convert_transformed_inputs_to_original_shape (self , resized_inputs , resize_kwargs = None ):
139
140
if not hasattr (self , "pre_pad_shape" ):
140
141
raise RuntimeError (
141
142
"'convert_transformed_inputs_to_original_shape' is only valid after the '__call__' method has run."
@@ -144,8 +145,15 @@ def convert_transformed_inputs_to_original_shape(self, resized_inputs):
144
145
# First step is to remove the padded region
145
146
inputs = resized_inputs [tuple (self .pre_pad_shape )]
146
147
# Next, we resize the inputs to original shape
148
+
149
+ if resize_kwargs is None : # This allows the user to change resize parameters, eg. for labels, if desired.
150
+ resize_kwargs = self .kwargs
151
+ else :
152
+ if not isinstance (resize_kwargs , dict ):
153
+ raise RuntimeError ("If the 'resize_kwargs' are provided, it must be a dictionary." )
154
+
147
155
inputs = resize (
148
- image = inputs , output_shape = self .original_shape , preserve_range = True , ** self . kwargs
156
+ image = inputs , output_shape = self .original_shape , preserve_range = True , ** resize_kwargs
149
157
)
150
158
return inputs
151
159
@@ -181,13 +189,14 @@ def __call__(self, inputs):
181
189
# NOTE: We store this in case we would like to unpad the inputs.
182
190
self .pre_pad_shape = [slice (pw [0 ], - pw [1 ] if pw [1 ] > 0 else None ) for pw in pad_width ]
183
191
184
- inputs = np .pad (array = inputs , pad_width = pad_width , mode = "constant" )
192
+ inputs = np .pad (array = inputs , pad_width = pad_width , mode = self . padding_mode )
185
193
return inputs
186
194
187
195
188
196
class PadIfNecessary :
189
- def __init__ (self , shape ):
197
+ def __init__ (self , shape , padding_mode = "reflect" ):
190
198
self .shape = tuple (shape )
199
+ self .padding_mode = padding_mode
191
200
192
201
def _pad_if_necessary (self , data ):
193
202
if data .ndim == len (self .shape ):
@@ -204,7 +213,7 @@ def _pad_if_necessary(self, data):
204
213
pad_width = [sh - dsh for dsh , sh in zip (data_shape , pad_shape )]
205
214
assert all (pw >= 0 for pw in pad_width )
206
215
pad_width = [(0 , pw ) for pw in pad_width ]
207
- return np .pad (data , pad_width , mode = "reflect" )
216
+ return np .pad (data , pad_width , mode = self . padding_mode )
208
217
209
218
def __call__ (self , * inputs ):
210
219
outputs = tuple (self ._pad_if_necessary (input_ ) for input_ in inputs )
0 commit comments