Skip to content

Commit a7412f3

Browse files
dmitry-gorokhovazhai219
authored andcommitted
[FORK][FEATURE] Enable jit sse41 NxN convolution for grayscale input
1 parent 4c8fb2c commit a7412f3

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/cpu/x64/jit_sse41_conv_kernel_f32.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
468468
sum_requires_scale_one, sum_requires_zp_zero));
469469
VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP);
470470

471-
const bool flat = jcp.ic == 3;
471+
const bool flat = jcp.ic == 3 || jcp.ic == 1;
472472
const bool mimo = !flat;
473473

474474
const bool tag_ok = true

src/cpu/x64/jit_sse41_convolution.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ struct jit_sse41_convolution_fwd_t : public primitive_t {
8888
&& IMPLICATION(curr_dst_tag != dat_tag_nxc,
8989
dst_d.format_kind() == format_kind::any)
9090
&& utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
91-
const bool flat = IC() == 3;
91+
const bool flat = IC() == 3 || IC() == 1;;
9292
auto src_tag = is_data_layout_nxc ? dat_tag_nxc
9393
: flat ? dat_tag_ncx
9494
: dat_tag_nCx8c;

0 commit comments

Comments
 (0)