Skip to content

Commit 51f0eb5

Browse files
luo-cheng2021EgorDuplensky
authored andcommitted
gemm int8 support binary postops
1 parent 756a64e commit 51f0eb5

3 files changed

+49
-5
lines changed

src/cpu/gemm_x8s8s32x_convolution.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t {
102102
bool ok = true;
103103

104104
for (int i = 0; i < po.len(); i++) {
105-
ok = ok && utils::one_of(po.entry_[i].kind, sum, eltwise, depthwise, quantization);
105+
ok = ok && utils::one_of(po.entry_[i].kind, sum, binary, eltwise, depthwise, quantization);
106106
}
107107
return ok;
108108
};

src/cpu/x64/jit_gemm_convolution_utils.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
7474
const binary_injector::static_params_t bsp {this->reg_abi_bak, rhs_sp};
7575
jit_binary_injector_ = utils::make_unique<
7676
binary_injector::jit_uni_binary_injector_t<isa>>(
77-
this, bsp);
77+
this, bsp);
7878
}
7979
if (post_ops_.len() > 0 && !only_eltwise) {
8080
vreg_d_weights = Vmm(idx_compute_vreg_max_--);

src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp

+47-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <functional>
1919

2020
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
21+
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
2122
#include "cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
2223
#include "cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp"
2324
#include "cpu/x64/jit_generator.hpp"
@@ -60,10 +61,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
6061
vreg_zero = Vmm(idx_compute_vreg_start_++);
6162
}
6263
bool only_eltwise_or_sum = true;
64+
bool with_binary = false;
6365
for (int idx = 0; idx < post_ops_.len(); ++idx) {
6466
const auto &e = post_ops_.entry_[idx];
6567
if (e.is_eltwise(true)) {
6668
do_eltwise_ = true;
69+
} else if (e.is_binary()) {
70+
with_binary = true;
71+
only_eltwise_or_sum = false;
6772
} else if (e.is_sum()) {
6873
do_sum_ = true;
6974
sum_scale_ = e.sum.scale;
@@ -72,6 +77,24 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
7277
only_eltwise_or_sum = false;
7378
}
7479
}
80+
if (with_binary) {
81+
#define PARAM_OFF(field) offsetof(ker_args_t, field)
82+
static constexpr bool preserve_gpr = true;
83+
static constexpr bool preserve_vmm = true;
84+
static constexpr size_t helper_vmm_idx = 1;
85+
static constexpr size_t tail_size = 0;
86+
static constexpr bool use_exact_tail_scalar_bcast = false;
87+
const binary_injector::rhs_arg_static_params_t rhs_sp {
88+
helper_vmm_idx, r13, r14, preserve_gpr,
89+
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
90+
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
91+
tail_size, kreg_rem_mask_short, use_exact_tail_scalar_bcast};
92+
#undef PARAM_OFF
93+
const binary_injector::static_params_t bsp {this->reg_param_bak, rhs_sp};
94+
jit_binary_injector_ = utils::make_unique<
95+
binary_injector::jit_uni_binary_injector_t<isa>>(
96+
this, bsp);
97+
}
7598
if (post_ops_.len() > 0 && !only_eltwise_or_sum) {
7699
vreg_d_weights = Vmm(idx_compute_vreg_max_--);
77100
vreg_d_bias = Vmm(idx_compute_vreg_max_--);
@@ -120,7 +143,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
120143
int g, size_t start, size_t end,
121144
const zero_point_call_params_t &zp,
122145
const void * post_ops_binary_rhs_arg_vec,
123-
const void * /* dst_orig */, const exec_ctx_t &ctx,
146+
const void * dst_orig, const exec_ctx_t &ctx,
124147
const memory_desc_t &dst_md,
125148
const single_gemm_conv_chunk_desc_t &chunk_desc) const override {
126149

@@ -135,6 +158,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
135158
args.dst = dst
136159
+ (os_offset * dst_os_stride_ + oc_offset)
137160
* dst_data_type_size_;
161+
args.dst_orig = dst_orig;
138162
args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
139163
args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
140164
args.sum_scale = sum_scale_;
@@ -151,6 +175,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
151175

152176
struct ker_args_t {
153177
char *dst;
178+
const void* dst_orig;
154179
const acc_data_t *acc;
155180
const char *bias;
156181
const float *scales;
@@ -164,11 +189,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
164189

165190
nstl::vector<jit_uni_eltwise_injector_f32<isa> *> jit_eltwise_injectors_;
166191
nstl::vector<jit_uni_depthwise_injector_f32<isa> *> jit_depthwise_injectors_;
192+
std::unique_ptr<binary_injector::jit_uni_binary_injector_t<isa>>
193+
jit_binary_injector_;
167194

168195
using Vmm = typename cpu_isa_traits<isa>::Vmm;
169196
static const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
170197

171198
Xbyak::Reg64 reg_param = abi_param1;
199+
Xbyak::Reg64 reg_param_bak = r11;
172200
Xbyak::Reg64 reg_dst = rdx;
173201
Xbyak::Reg64 reg_acc = rax;
174202
Xbyak::Reg64 reg_bias = rbx;
@@ -265,6 +293,7 @@ void jit_pp_ker_t<isa>::generate() {
265293
}
266294
}
267295

296+
mov(reg_param_bak, reg_param);
268297
mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
269298
mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
270299
mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
@@ -286,10 +315,11 @@ void jit_pp_ker_t<isa>::generate() {
286315
if (utils::one_of(isa, avx2, sse41))
287316
mov(reg_table, l_table);
288317

289-
auto apply_post_ops = [&](size_t offset, int idx) {
318+
auto apply_post_ops = [&](size_t offset, int idx, bool apply_mask) {
290319
std::size_t post_ops_data_offset = 0;
291320
int eltwise_inj_idx = 0;
292321
int depthwise_inj_idx = 0;
322+
int binary_inj_idx = 0;
293323
for (int i = 0; i < post_ops_.len(); i++) {
294324
auto &post_op = post_ops_.entry_[i];
295325
if (post_op.is_sum()) {
@@ -313,6 +343,18 @@ void jit_pp_ker_t<isa>::generate() {
313343
uni_vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx));
314344

315345
uni_vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
346+
} else if (post_op.is_binary()) {
347+
binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
348+
auto dst_addr = ptr[reg_dst + offset * dst_data_type_size_];
349+
rhs_arg_params.vmm_idx_to_out_addr.emplace(idx, dst_addr);
350+
rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
351+
idx, 0 * sizeof(float));
352+
if (mayiuse(avx512_core) && apply_mask)
353+
rhs_arg_params.vmm_tail_idx_.emplace(idx);
354+
jit_binary_injector_->compute_vector(
355+
idx, binary_inj_idx, post_op, rhs_arg_params);
356+
357+
binary_inj_idx++;
316358
} else if (post_op.is_eltwise()) {
317359
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector_range(vreg_dst(idx).getIdx(),
318360
vreg_dst(idx).getIdx() + 1);
@@ -330,6 +372,7 @@ void jit_pp_ker_t<isa>::generate() {
330372

331373
post_ops_data_offset += jit_depthwise_injectors_[depthwise_inj_idx]->memoryStep();
332374
depthwise_inj_idx++;
375+
binary_inj_idx++;
333376
} else if (post_op.is_quantization()) {
334377
add(reg_oc_offset, reg_g_offset);
335378
bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
@@ -404,6 +447,7 @@ void jit_pp_ker_t<isa>::generate() {
404447
sub(reg_oc_offset, reg_g_offset);
405448

406449
post_ops_data_offset += sizeof(float*);
450+
binary_inj_idx++;
407451
}
408452
}
409453
};
@@ -489,7 +533,7 @@ void jit_pp_ker_t<isa>::generate() {
489533
if (do_scale_)
490534
uni_vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
491535

492-
apply_post_ops(offset, idx);
536+
apply_post_ops(offset, idx, apply_mask);
493537

494538
if (dst_data_type_ != data_type::f32) {
495539
if (isa == avx512_core) {

0 commit comments

Comments
 (0)