Skip to content

Commit f00b262

Browse files
committed
Update UNETR - to use bilinear interpolation for upsampling
1 parent f5c24a5 commit f00b262

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

torch_em/model/unetr.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ def __init__(
138138
self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
139139
self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])
140140

141-
self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
141+
self.deconv_out = Upsampler2d(
142+
scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
143+
)
142144

143145
self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
144146

@@ -274,15 +276,6 @@ def forward(self, x):
274276
#
275277

276278

277-
class SingleDeconv2DBlock(nn.Module):
278-
def __init__(self, in_planes, out_planes):
279-
super().__init__()
280-
self.block = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)
281-
282-
def forward(self, x):
283-
return self.block(x)
284-
285-
286279
class SingleConv2DBlock(nn.Module):
287280
def __init__(self, in_planes, out_planes, kernel_size):
288281
super().__init__()
@@ -310,7 +303,7 @@ class Deconv2DBlock(nn.Module):
310303
def __init__(self, in_planes, out_planes, kernel_size=3):
311304
super().__init__()
312305
self.block = nn.Sequential(
313-
SingleDeconv2DBlock(in_planes, out_planes),
306+
Upsampler2d(scale_factor=2, in_channels=in_planes, out_channels=out_planes),
314307
SingleConv2DBlock(out_planes, out_planes, kernel_size),
315308
nn.BatchNorm2d(out_planes),
316309
nn.ReLU(True)

0 commit comments

Comments
 (0)