Skip to content

Commit b2cdc2c

Browse files
committed
[FORK][FIX][x64] Add proper post op checks to gemm_conv
1 parent 4e29b77 commit b2cdc2c

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

src/cpu/gemm_convolution_utils.cpp

+16-2
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,17 @@ struct ref_pp_kernel_t : pp_kernel_t {
8080
ref_depthwise_injectors_.clear();
8181
}
8282

83+
// Reports whether the post-op chain attached to `pd` consists solely of
// eltwise, depthwise and quantization entries.
//
// @param pd convolution primitive descriptor whose attr()->post_ops_ is scanned.
// @return false as soon as an entry of any other kind is found; true otherwise
//         (including when the post-op list is empty).
static bool post_ops_ok(const convolution_pd_t *pd) {
84+
using namespace dnnl::impl::primitive_kind;
85+
const auto& po = pd->attr()->post_ops_;
86+
for (int i = 0; i < po.len(); i++) {
87+
// Any kind outside the three listed here makes the whole chain unacceptable.
if (!utils::one_of(po.entry_[i].kind, eltwise, depthwise, quantization)) {
88+
return false;
89+
}
90+
}
91+
return true;
92+
}
93+
8394
virtual void operator()(float *dst_orig, float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride,
8495
const std::vector<const void *>& post_ops_binary_rhs_arg_vec) const override;
8596

@@ -197,9 +208,12 @@ pp_kernel_t *pp_kernel_t::create(
197208
if (res) return res;
198209
#endif
199210

200-
return new ref_pp_kernel_t(pd, jcp);
201-
}
211+
if (ref_pp_kernel_t::post_ops_ok(pd)) {
212+
return new ref_pp_kernel_t(pd, jcp);
213+
}
202214

215+
return nullptr;
216+
}
203217
} // namespace gemm_convolution_utils
204218

205219
namespace jit_gemm_convolution_utils {

src/cpu/x64/jit_gemm_convolution_utils.cpp

+20-3
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,23 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
110110
}
111111
}
112112

113+
// Reports whether every binary post-op attached to `pd` passes
// binary_injector::is_supported() for this kernel's `isa`.
//
// Only entries for which is_binary() is true are examined; all other post-op
// kinds are skipped by this check.
//
// @param pd convolution primitive descriptor; its attr()->post_ops_ and
//           dst_md() are read.
// @return false on the first binary post-op that is_supported() rejects;
//         true otherwise (including when there are no binary post-ops).
static bool post_ops_ok(const convolution_pd_t *pd) {
114+
const auto& post_ops = pd->attr()->post_ops_;
115+
auto dst_md = pd->dst_md();
116+
for (int i = 0; i < post_ops.len(); i++) {
117+
const auto &post_op = post_ops.entry_[i];
118+
if (post_op.is_binary()) {
119+
// The src1 descriptor is derived from the post-op together with the
// destination descriptor, then checked using default_strategies().
if (!binary_injector::is_supported(isa,
120+
binary_injector::get_src1_desc(post_op, *dst_md),
121+
*dst_md,
122+
default_strategies())) {
123+
return false;
124+
}
125+
}
126+
}
127+
return true;
128+
}
129+
113130
private:
114131
void generate() override;
115132

@@ -391,11 +408,11 @@ void jit_pp_kernel_t<isa>::generate() {
391408

392409
pp_kernel_t *jit_pp_kernel_create(
393410
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
394-
if (mayiuse(avx512_core)) {
411+
if (mayiuse(avx512_core) && jit_pp_kernel_t<avx512_core>::post_ops_ok(pd)) {
395412
return new jit_pp_kernel_t<avx512_core>(pd, jcp);
396-
} else if (mayiuse(avx2)) {
413+
} else if (mayiuse(avx2) && jit_pp_kernel_t<avx2>::post_ops_ok(pd)) {
397414
return new jit_pp_kernel_t<avx2>(pd, jcp);
398-
} else if (mayiuse(sse41)) {
415+
} else if (mayiuse(sse41) && jit_pp_kernel_t<sse41>::post_ops_ok(pd)) {
399416
return new jit_pp_kernel_t<sse41>(pd, jcp);
400417
}
401418
return nullptr;

0 commit comments

Comments
 (0)