Skip to content

Commit bc48caa

Browse files
luo-cheng2021 authored and EgorDuplensky committed
gemm int8 support binary postops
1 parent 9148ff4 commit bc48caa

8 files changed

+57
-8
lines changed

src/cpu/gemm_x8s8s32x_convolution.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t {
9797
bool ok = true;
9898

9999
for (int i = 0; i < po.len(); i++) {
100-
ok = ok && utils::one_of(po.entry_[i].kind, sum, eltwise, depthwise, quantization);
100+
ok = ok && utils::one_of(po.entry_[i].kind, sum, binary, eltwise, depthwise, quantization);
101101
}
102102
return ok;
103103
};

src/cpu/x64/gemm_bf16_convolution.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
111111
static constexpr size_t tail_size = 0;
112112
static constexpr bool use_exact_tail_scalar_bcast = false;
113113
const binary_injector::rhs_arg_static_params_t rhs_sp {
114-
helper_vmm_idx, r13, r14, preserve_gpr,
114+
helper_vmm_idx, reserved_eltwise_gpr, r13, r14, preserve_gpr,
115115
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
116116
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
117117
tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast};

src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::generate() {
591591
% (cpu_isa_traits<avx512_core>::vlen / sizeof(float));
592592
static constexpr bool use_exact_tail_scalar_bcast = false;
593593
const binary_injector::rhs_arg_static_params_t rhs_sp {
594-
helper_vmm_idx, r10, r11, preserve_gpr,
594+
helper_vmm_idx, reserved_eltwise_gpr, r10, r11, preserve_gpr,
595595
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
596596
GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_),
597597
tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast};

src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ struct jit_avx512_fork_dw_conv_fwd_kernel_bf16 : public jit_generator {
9595
mask_t ktail_mask = k_oc_tail_mask;
9696
mask_t k_ch_tail_mask_extended = Xbyak::Opmask(3);
9797

98+
Xbyak::Reg64 reserved_eltwise_gpr = r10;
99+
98100
Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0);
99101
Xbyak::Zmm zmm_src_reg = Xbyak::Zmm(1);
100102
Xbyak::Zmm zmm_prev_dst = Xbyak::Zmm(31);

src/cpu/x64/jit_gemm_convolution_utils.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
6666
static constexpr size_t tail_size = 0;
6767
static constexpr bool use_exact_tail_scalar_bcast = false;
6868
const binary_injector::rhs_arg_static_params_t rhs_sp {
69-
helper_vmm_idx, r13, r14, preserve_gpr,
69+
helper_vmm_idx, reserved_eltwise_gpr, r13, r14, preserve_gpr,
7070
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
7171
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
7272
tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast};
7373
#undef PARAM_OFF
7474
const binary_injector::static_params_t bsp {this->reg_abi_bak, rhs_sp};
7575
jit_binary_injector_ = utils::make_unique<
7676
binary_injector::jit_uni_binary_injector_t<isa>>(
77-
this, bsp);
77+
this, bsp);
7878
}
7979
if (post_ops_.len() > 0 && !only_eltwise) {
8080
vreg_d_weights = Vmm(idx_compute_vreg_max_--);
@@ -155,6 +155,7 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
155155
Xbyak::Reg64 reg_d_bias = r15;
156156
Xbyak::Reg64 reg_post_ops_data = rax;
157157
Vmm vreg_d_weights, vreg_d_bias;
158+
Xbyak::Reg64 reserved_eltwise_gpr = r10;
158159

159160
int idx_compute_vreg_start_;
160161
int idx_compute_vreg_max_;

src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp

+46-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <functional>
1919

2020
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
21+
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
2122
#include "cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
2223
#include "cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp"
2324
#include "cpu/x64/jit_generator.hpp"
@@ -60,10 +61,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
6061
vreg_zero = Vmm(idx_compute_vreg_start_++);
6162
}
6263
bool only_eltwise_or_sum = true;
64+
bool with_binary = false;
6365
for (int idx = 0; idx < post_ops_.len(); ++idx) {
6466
const auto &e = post_ops_.entry_[idx];
6567
if (e.is_eltwise(true)) {
6668
do_eltwise_ = true;
69+
} else if (e.is_binary()) {
70+
with_binary = true;
71+
only_eltwise_or_sum = false;
6772
} else if (e.is_sum()) {
6873
do_sum_ = true;
6974
sum_scale_ = e.sum.scale;
@@ -72,6 +77,24 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
7277
only_eltwise_or_sum = false;
7378
}
7479
}
80+
if (with_binary) {
81+
#define PARAM_OFF(field) offsetof(ker_args_t, field)
82+
static constexpr bool preserve_gpr = true;
83+
static constexpr bool preserve_vmm = true;
84+
static constexpr size_t helper_vmm_idx = 1;
85+
static constexpr size_t tail_size = 0;
86+
static constexpr bool use_exact_tail_scalar_bcast = false;
87+
const binary_injector::rhs_arg_static_params_t rhs_sp {
88+
helper_vmm_idx, r13, r14, r15, preserve_gpr,
89+
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
90+
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
91+
tail_size, kreg_rem_mask_short, use_exact_tail_scalar_bcast};
92+
#undef PARAM_OFF
93+
const binary_injector::static_params_t bsp {this->reg_param_bak, rhs_sp};
94+
jit_binary_injector_ = utils::make_unique<
95+
binary_injector::jit_uni_binary_injector_t<isa>>(
96+
this, bsp);
97+
}
7598
if (post_ops_.len() > 0 && !only_eltwise_or_sum) {
7699
vreg_d_weights = Vmm(idx_compute_vreg_max_--);
77100
vreg_d_bias = Vmm(idx_compute_vreg_max_--);
@@ -133,6 +156,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
133156
args.dst = dst
134157
+ (os_offset * dst_os_stride_ + oc_offset)
135158
* dst_data_type_size_;
159+
args.dst_orig = dst_orig;
136160
args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
137161
args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
138162
args.sum_scale = sum_scale_;
@@ -149,6 +173,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
149173

150174
struct ker_args_t {
151175
char *dst;
176+
const void* dst_orig;
152177
const acc_data_t *acc;
153178
const char *bias;
154179
const float *scales;
@@ -163,11 +188,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
163188

164189
nstl::vector<jit_uni_eltwise_injector_f32<isa> *> jit_eltwise_injectors_;
165190
nstl::vector<jit_uni_depthwise_injector_f32<isa> *> jit_depthwise_injectors_;
191+
std::unique_ptr<binary_injector::jit_uni_binary_injector_t<isa>>
192+
jit_binary_injector_;
166193

167194
using Vmm = typename cpu_isa_traits<isa>::Vmm;
168195
static const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
169196

170197
Xbyak::Reg64 reg_param = abi_param1;
198+
Xbyak::Reg64 reg_param_bak = r11;
171199
Xbyak::Reg64 reg_dst = rdx;
172200
Xbyak::Reg64 reg_acc = rax;
173201
Xbyak::Reg64 reg_bias = rbx;
@@ -264,6 +292,7 @@ void jit_pp_ker_t<isa>::generate() {
264292
}
265293
}
266294

295+
mov(reg_param_bak, reg_param);
267296
mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
268297
mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
269298
mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
@@ -285,10 +314,11 @@ void jit_pp_ker_t<isa>::generate() {
285314
if (utils::one_of(isa, avx2, sse41))
286315
mov(reg_table, l_table);
287316

288-
auto apply_post_ops = [&](size_t offset, int idx) {
317+
auto apply_post_ops = [&](size_t offset, int idx, bool apply_mask) {
289318
std::size_t post_ops_data_offset = 0;
290319
int eltwise_inj_idx = 0;
291320
int depthwise_inj_idx = 0;
321+
int binary_inj_idx = 0;
292322
for (int i = 0; i < post_ops_.len(); i++) {
293323
auto &post_op = post_ops_.entry_[i];
294324
if (post_op.is_sum()) {
@@ -312,6 +342,18 @@ void jit_pp_ker_t<isa>::generate() {
312342
uni_vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx));
313343

314344
uni_vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
345+
} else if (post_op.is_binary()) {
346+
binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
347+
auto dst_addr = ptr[reg_dst + offset * dst_data_type_size_];
348+
rhs_arg_params.vmm_idx_to_out_addr.emplace(idx, dst_addr);
349+
rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
350+
idx, 0 * sizeof(float));
351+
if (mayiuse(avx512_core) && apply_mask)
352+
rhs_arg_params.vmm_tail_idx_.emplace(idx);
353+
jit_binary_injector_->compute_vector(
354+
idx, binary_inj_idx, post_op, rhs_arg_params);
355+
356+
binary_inj_idx++;
315357
} else if (post_op.is_eltwise()) {
316358
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector_range(vreg_dst(idx).getIdx(),
317359
vreg_dst(idx).getIdx() + 1);
@@ -329,6 +371,7 @@ void jit_pp_ker_t<isa>::generate() {
329371

330372
post_ops_data_offset += jit_depthwise_injectors_[depthwise_inj_idx]->memoryStep();
331373
depthwise_inj_idx++;
374+
binary_inj_idx++;
332375
} else if (post_op.is_quantization()) {
333376
add(reg_oc_offset, reg_g_offset);
334377
bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
@@ -403,6 +446,7 @@ void jit_pp_ker_t<isa>::generate() {
403446
sub(reg_oc_offset, reg_g_offset);
404447

405448
post_ops_data_offset += sizeof(float*);
449+
binary_inj_idx++;
406450
}
407451
}
408452
};
@@ -488,7 +532,7 @@ void jit_pp_ker_t<isa>::generate() {
488532
if (do_scale_)
489533
uni_vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
490534

491-
apply_post_ops(offset, idx);
535+
apply_post_ops(offset, idx, apply_mask);
492536

493537
if (dst_data_type_ != data_type::f32) {
494538
if (isa == avx512_core) {

src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::generate() {
772772
% (cpu_isa_traits<isa>::vlen / sizeof(float));
773773
static constexpr bool use_exact_tail_scalar_bcast = false;
774774
const binary_injector::rhs_arg_static_params_t rhs_sp {
775-
helper_vmm_idx, r10, r11, preserve_gpr,
775+
helper_vmm_idx, reserved_eltwise_gpr, r10, r11, preserve_gpr,
776776
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
777777
GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_),
778778
tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast};

src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ struct jit_uni_fork_dw_conv_fwd_kernel_f32 : public jit_generator {
8787
reg64_t aux_reg_ch_blocks = reg_ur_w;
8888
reg64_t aux_reg_blocks_offset = abi_not_param1;
8989

90+
Xbyak::Reg64 reserved_eltwise_gpr = r10;
91+
9092
reg64_t reg_d_weights = imm_addr64;
9193
reg64_t reg_d_bias = iter_kh;
9294
int base_post_ops_data_offset = 0;

0 commit comments

Comments
 (0)