@@ -104,14 +104,14 @@ status_t jit_uni_fork_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
104
104
const memory_desc_wrapper bias_d (&bias_md);
105
105
106
106
const int ndims = src_d.ndims ();
107
- // Currently this kernel only supports 2D and 3D convolutions.
108
- if (ndims != 4 && ndims != 5 ) return status::unimplemented;
109
107
110
- const auto blocked_tag = (ndims == 5 ) ? one_of (isa, avx512_core, avx512_core) ? nCdhw16c : nCdhw8c
111
- : one_of (isa, avx512_core, avx512_core) ? nChw16c : nChw8c;
112
- const auto wei_tag = (ndims == 5 ) ? one_of (isa, avx512_core, avx512_core) ? Goidhw16g : Goidhw8g
113
- : one_of (isa, avx512_core, avx512_core) ? Goihw16g : Goihw8g;
114
- const auto nxc_tag = (ndims == 5 ) ? ndhwc : nhwc;
108
+ const auto blocked_tag = one_of (isa, avx512_core) ?
109
+ pick (ndims - 3 , nCw16c, nChw16c, nCdhw16c) :
110
+ pick (ndims - 3 , nCw8c, nChw8c, nCdhw8c);
111
+ const auto wei_tag = one_of (isa, avx512_core) ?
112
+ pick (ndims - 3 , Goiw16g, Goihw16g, Goidhw16g) :
113
+ pick (ndims - 3 , Goiw8g, Goihw8g, Goidhw8g);
114
+ const auto nxc_tag = pick (ndims - 3 , nwc, nhwc, ndhwc);
115
115
116
116
jcp.with_bias = cd.bias_desc .format_kind != format_kind::undef;
117
117
@@ -172,29 +172,29 @@ status_t jit_uni_fork_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
172
172
jcp.ic = src_d.dims ()[1 ];
173
173
174
174
jcp.id = (ndims == 5 ) ? src_d.dims ()[2 ] : 1 ;
175
- jcp.ih = src_d.dims ()[ndims - 2 ];
175
+ jcp.ih = (ndims == 3 ) ? 1 : src_d.dims ()[ndims - 2 ];
176
176
jcp.iw = src_d.dims ()[ndims - 1 ];
177
177
jcp.od = (ndims == 5 ) ? dst_d.dims ()[2 ] : 1 ;
178
- jcp.oh = dst_d.dims ()[ndims - 2 ];
178
+ jcp.oh = (ndims == 3 ) ? 1 : dst_d.dims ()[ndims - 2 ];
179
179
jcp.ow = dst_d.dims ()[ndims - 1 ];
180
180
181
181
jcp.kd = (ndims == 5 ) ? weights_d.dims ()[3 ] : 1 ;
182
- jcp.kh = weights_d.dims ()[ndims - 1 ];
182
+ jcp.kh = (ndims == 3 ) ? 1 : weights_d.dims ()[ndims - 1 ];
183
183
jcp.kw = weights_d.dims ()[ndims];
184
184
185
185
jcp.f_pad = (ndims == 5 ) ? cd.padding [0 ][0 ] : 0 ;
186
- jcp.t_pad = cd.padding [0 ][ndims - 4 ];
186
+ jcp.t_pad = (ndims == 3 ) ? 0 : cd.padding [0 ][ndims - 4 ];
187
187
jcp.l_pad = cd.padding [0 ][ndims - 3 ];
188
188
jcp.back_pad = (ndims == 5 ) ? cd.padding [1 ][0 ] : 0 ;
189
- jcp.b_pad = cd.padding [1 ][ndims - 4 ];
189
+ jcp.b_pad = (ndims == 3 ) ? 0 : cd.padding [1 ][ndims - 4 ];
190
190
jcp.r_pad = cd.padding [1 ][ndims - 3 ];
191
191
192
192
jcp.stride_d = (ndims == 5 ) ? cd.strides [0 ] : 1 ;
193
- jcp.stride_h = cd.strides [ndims - 4 ];
193
+ jcp.stride_h = (ndims == 3 ) ? 1 : cd.strides [ndims - 4 ];
194
194
jcp.stride_w = cd.strides [ndims - 3 ];
195
195
196
196
jcp.dilate_d = (ndims == 5 ) ? cd.dilates [0 ] : 0 ;
197
- jcp.dilate_h = cd.dilates [ndims - 4 ];
197
+ jcp.dilate_h = (ndims == 3 ) ? 0 : cd.dilates [ndims - 4 ];
198
198
jcp.dilate_w = cd.dilates [ndims - 3 ];
199
199
200
200
jcp.loop_order = loop_ngcw;
0 commit comments