@@ -138,7 +138,9 @@ def __init__(
138
138
self .deconv3 = Deconv2DBlock (features_decoder [1 ], features_decoder [2 ])
139
139
self .deconv4 = Deconv2DBlock (features_decoder [2 ], features_decoder [3 ])
140
140
141
- self .deconv_out = SingleDeconv2DBlock (features_decoder [- 1 ], features_decoder [- 1 ])
141
+ self .deconv_out = Upsampler2d (
142
+ scale_factor = 2 , in_channels = features_decoder [- 1 ], out_channels = features_decoder [- 1 ]
143
+ )
142
144
143
145
self .decoder_head = ConvBlock2d (2 * features_decoder [- 1 ], features_decoder [- 1 ])
144
146
@@ -274,15 +276,6 @@ def forward(self, x):
274
276
#
275
277
276
278
277
- class SingleDeconv2DBlock (nn .Module ):
278
- def __init__ (self , in_planes , out_planes ):
279
- super ().__init__ ()
280
- self .block = nn .ConvTranspose2d (in_planes , out_planes , kernel_size = 2 , stride = 2 , padding = 0 , output_padding = 0 )
281
-
282
- def forward (self , x ):
283
- return self .block (x )
284
-
285
-
286
279
class SingleConv2DBlock (nn .Module ):
287
280
def __init__ (self , in_planes , out_planes , kernel_size ):
288
281
super ().__init__ ()
@@ -310,7 +303,7 @@ class Deconv2DBlock(nn.Module):
310
303
def __init__ (self , in_planes , out_planes , kernel_size = 3 ):
311
304
super ().__init__ ()
312
305
self .block = nn .Sequential (
313
- SingleDeconv2DBlock ( in_planes , out_planes ),
306
+ Upsampler2d ( scale_factor = 2 , in_channels = in_planes , out_channels = out_planes ),
314
307
SingleConv2DBlock (out_planes , out_planes , kernel_size ),
315
308
nn .BatchNorm2d (out_planes ),
316
309
nn .ReLU (True )
0 commit comments