Skip to content

Commit c364f66

Browse files
authored
Minor update to expose padding mode and make resize flexible (#479)
1 parent e0b2356 commit c364f66

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

torch_em/transform/generic.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,11 @@ def __call__(self, inputs):
108108

109109

110110
class ResizeLongestSideInputs:
111-
def __init__(self, target_shape, is_label=False, is_rgb=False):
111+
def __init__(self, target_shape, is_label=False, is_rgb=False, padding_mode="constant"):
112112
self.target_shape = target_shape
113113
self.is_label = is_label
114114
self.is_rgb = is_rgb
115+
self.padding_mode = padding_mode
115116

116117
h, w = self.target_shape[-2], self.target_shape[-1]
117118
if h != w: # We currently support resize feature for square-shaped target shape only.
@@ -135,7 +136,7 @@ def _get_preprocess_shape(self, oldh, oldw):
135136
newh = int(newh + 0.5)
136137
return (newh, neww)
137138

138-
def convert_transformed_inputs_to_original_shape(self, resized_inputs):
139+
def convert_transformed_inputs_to_original_shape(self, resized_inputs, resize_kwargs=None):
139140
if not hasattr(self, "pre_pad_shape"):
140141
raise RuntimeError(
141142
"'convert_transformed_inputs_to_original_shape' is only valid after the '__call__' method has run."
@@ -144,8 +145,15 @@ def convert_transformed_inputs_to_original_shape(self, resized_inputs):
144145
# First step is to remove the padded region
145146
inputs = resized_inputs[tuple(self.pre_pad_shape)]
146147
# Next, we resize the inputs to original shape
148+
149+
if resize_kwargs is None: # This allows the user to change resize parameters, eg. for labels, if desired.
150+
resize_kwargs = self.kwargs
151+
else:
152+
if not isinstance(resize_kwargs, dict):
153+
raise RuntimeError("If the 'resize_kwargs' are provided, it must be a dictionary.")
154+
147155
inputs = resize(
148-
image=inputs, output_shape=self.original_shape, preserve_range=True, **self.kwargs
156+
image=inputs, output_shape=self.original_shape, preserve_range=True, **resize_kwargs
149157
)
150158
return inputs
151159

@@ -181,13 +189,14 @@ def __call__(self, inputs):
181189
# NOTE: We store this in case we would like to unpad the inputs.
182190
self.pre_pad_shape = [slice(pw[0], -pw[1] if pw[1] > 0 else None) for pw in pad_width]
183191

184-
inputs = np.pad(array=inputs, pad_width=pad_width, mode="constant")
192+
inputs = np.pad(array=inputs, pad_width=pad_width, mode=self.padding_mode)
185193
return inputs
186194

187195

188196
class PadIfNecessary:
189-
def __init__(self, shape):
197+
def __init__(self, shape, padding_mode="reflect"):
190198
self.shape = tuple(shape)
199+
self.padding_mode = padding_mode
191200

192201
def _pad_if_necessary(self, data):
193202
if data.ndim == len(self.shape):
@@ -204,7 +213,7 @@ def _pad_if_necessary(self, data):
204213
pad_width = [sh - dsh for dsh, sh in zip(data_shape, pad_shape)]
205214
assert all(pw >= 0 for pw in pad_width)
206215
pad_width = [(0, pw) for pw in pad_width]
207-
return np.pad(data, pad_width, mode="reflect")
216+
return np.pad(data, pad_width, mode=self.padding_mode)
208217

209218
def __call__(self, *inputs):
210219
outputs = tuple(self._pad_if_necessary(input_) for input_ in inputs)

0 commit comments

Comments (0)