18
18
#include < functional>
19
19
20
20
#include " cpu/x64/injectors/jit_uni_postops_injector.hpp"
21
+ #include " cpu/x64/injectors/jit_uni_binary_injector.hpp"
21
22
#include " cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
22
23
#include " cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp"
23
24
#include " cpu/x64/jit_generator.hpp"
@@ -60,10 +61,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
60
61
vreg_zero = Vmm (idx_compute_vreg_start_++);
61
62
}
62
63
bool only_eltwise_or_sum = true ;
64
+ bool with_binary = false ;
63
65
for (int idx = 0 ; idx < post_ops_.len (); ++idx) {
64
66
const auto &e = post_ops_.entry_ [idx];
65
67
if (e.is_eltwise (true )) {
66
68
do_eltwise_ = true ;
69
+ } else if (e.is_binary ()) {
70
+ with_binary = true ;
71
+ only_eltwise_or_sum = false ;
67
72
} else if (e.is_sum ()) {
68
73
do_sum_ = true ;
69
74
sum_scale_ = e.sum .scale ;
@@ -72,6 +77,24 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
72
77
only_eltwise_or_sum = false ;
73
78
}
74
79
}
80
+ if (with_binary) {
81
+ #define PARAM_OFF (field ) offsetof(ker_args_t , field)
82
+ static constexpr bool preserve_gpr = true ;
83
+ static constexpr bool preserve_vmm = true ;
84
+ static constexpr size_t helper_vmm_idx = 1 ;
85
+ static constexpr size_t tail_size = 0 ;
86
+ static constexpr bool use_exact_tail_scalar_bcast = false ;
87
+ const binary_injector::rhs_arg_static_params_t rhs_sp {
88
+ helper_vmm_idx, r13, r14, r15, preserve_gpr,
89
+ preserve_vmm, PARAM_OFF (post_ops_binary_rhs_arg_vec),
90
+ PARAM_OFF (dst_orig), memory_desc_wrapper (pd->dst_md ()),
91
+ tail_size, kreg_rem_mask_short, use_exact_tail_scalar_bcast};
92
+ #undef PARAM_OFF
93
+ const binary_injector::static_params_t bsp {this ->reg_param_bak , rhs_sp};
94
+ jit_binary_injector_ = utils::make_unique<
95
+ binary_injector::jit_uni_binary_injector_t <isa>>(
96
+ this , bsp);
97
+ }
75
98
if (post_ops_.len () > 0 && !only_eltwise_or_sum) {
76
99
vreg_d_weights = Vmm (idx_compute_vreg_max_--);
77
100
vreg_d_bias = Vmm (idx_compute_vreg_max_--);
@@ -213,12 +236,15 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
213
236
214
237
nstl::vector<jit_uni_eltwise_injector_f32<isa> *> jit_eltwise_injectors_;
215
238
nstl::vector<jit_uni_depthwise_injector_f32<isa> *> jit_depthwise_injectors_;
239
+ std::unique_ptr<binary_injector::jit_uni_binary_injector_t <isa>>
240
+ jit_binary_injector_;
216
241
217
242
size_t number_of_reserved_zmm_regs_;
218
243
using Vmm = typename cpu_isa_traits<isa>::Vmm;
219
244
static const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof (float );
220
245
221
246
Xbyak::Reg64 reg_param = abi_param1;
247
+ Xbyak::Reg64 reg_param_bak = r11;
222
248
Xbyak::Reg64 reg_dst = rdx;
223
249
Xbyak::Reg64 reg_acc = rax;
224
250
Xbyak::Reg64 reg_bias = rbx;
@@ -329,6 +355,7 @@ void jit_pp_ker_t<isa>::generate() {
329
355
}
330
356
}
331
357
358
+ mov (reg_param_bak, reg_param);
332
359
mov (reg_dst, ptr[reg_param + PARAM_OFF (dst)]);
333
360
mov (reg_acc, ptr[reg_param + PARAM_OFF (acc)]);
334
361
mov (reg_bias, ptr[reg_param + PARAM_OFF (bias)]);
@@ -358,10 +385,11 @@ void jit_pp_ker_t<isa>::generate() {
358
385
if (utils::one_of (isa, avx2, sse41))
359
386
mov (reg_table, l_table);
360
387
361
- auto apply_post_ops = [&](size_t offset, int idx) {
388
+ auto apply_post_ops = [&](size_t offset, int idx, bool apply_mask ) {
362
389
std::size_t post_ops_data_offset = 0 ;
363
390
int eltwise_inj_idx = 0 ;
364
391
int depthwise_inj_idx = 0 ;
392
+ int binary_inj_idx = 0 ;
365
393
for (int i = 0 ; i < post_ops_.len (); i++) {
366
394
auto &post_op = post_ops_.entry_ [i];
367
395
if (post_op.is_sum ()) {
@@ -385,6 +413,18 @@ void jit_pp_ker_t<isa>::generate() {
385
413
uni_vcvtdq2ps (vreg_prev_dst (idx), vreg_prev_dst (idx));
386
414
387
415
uni_vfmadd231ps (vreg_dst (idx), vreg_prev_dst (idx), vreg_sum_scale);
416
+ } else if (post_op.is_binary ()) {
417
+ binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
418
+ auto dst_addr = ptr[reg_dst + offset * dst_data_type_size_];
419
+ rhs_arg_params.vmm_idx_to_out_addr .emplace (idx, dst_addr);
420
+ rhs_arg_params.vmm_idx_to_out_elem_off_val .emplace (
421
+ idx, 0 * sizeof (float ));
422
+ if (mayiuse (avx512_core) && apply_mask)
423
+ rhs_arg_params.vmm_tail_idx_ .emplace (idx);
424
+ jit_binary_injector_->compute_vector (
425
+ idx, binary_inj_idx, post_op, rhs_arg_params);
426
+
427
+ binary_inj_idx++;
388
428
} else if (post_op.is_eltwise ()) {
389
429
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector_range (vreg_dst (idx).getIdx (),
390
430
vreg_dst (idx).getIdx () + 1 );
@@ -402,6 +442,7 @@ void jit_pp_ker_t<isa>::generate() {
402
442
403
443
post_ops_data_offset += jit_depthwise_injectors_[depthwise_inj_idx]->memoryStep ();
404
444
depthwise_inj_idx++;
445
+ binary_inj_idx++;
405
446
} else if (post_op.is_quantization ()) {
406
447
add (reg_oc_offset, reg_g_offset);
407
448
bool do_dequantization = post_op.quantization .alg == alg_kind::quantization_quantize_dequantize;
@@ -476,6 +517,7 @@ void jit_pp_ker_t<isa>::generate() {
476
517
sub (reg_oc_offset, reg_g_offset);
477
518
478
519
post_ops_data_offset += sizeof (float *);
520
+ binary_inj_idx++;
479
521
}
480
522
}
481
523
};
@@ -561,7 +603,7 @@ void jit_pp_ker_t<isa>::generate() {
561
603
if (do_scale_)
562
604
uni_vmulps (vreg_dst (idx), vreg_dst (idx), vreg_scale);
563
605
564
- apply_post_ops (offset, idx);
606
+ apply_post_ops (offset, idx, apply_mask );
565
607
566
608
if (jcp_.with_dst_scale ) {
567
609
uni_vmulps (vreg_dst_, vreg_dst (idx), vreg_dst_scale);
0 commit comments