18
18
#include < functional>
19
19
20
20
#include " cpu/x64/injectors/jit_uni_postops_injector.hpp"
21
+ #include " cpu/x64/injectors/jit_uni_binary_injector.hpp"
21
22
#include " cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
22
23
#include " cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp"
23
24
#include " cpu/x64/jit_generator.hpp"
@@ -60,10 +61,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
60
61
vreg_zero = Vmm (idx_compute_vreg_start_++);
61
62
}
62
63
bool only_eltwise_or_sum = true ;
64
+ bool with_binary = false ;
63
65
for (int idx = 0 ; idx < post_ops_.len (); ++idx) {
64
66
const auto &e = post_ops_.entry_ [idx];
65
67
if (e.is_eltwise (true )) {
66
68
do_eltwise_ = true ;
69
+ } else if (e.is_binary ()) {
70
+ with_binary = true ;
71
+ only_eltwise_or_sum = false ;
67
72
} else if (e.is_sum ()) {
68
73
do_sum_ = true ;
69
74
sum_scale_ = e.sum .scale ;
@@ -72,6 +77,24 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
72
77
only_eltwise_or_sum = false ;
73
78
}
74
79
}
80
+ if (with_binary) {
81
+ #define PARAM_OFF (field ) offsetof(ker_args_t , field)
82
+ static constexpr bool preserve_gpr = true ;
83
+ static constexpr bool preserve_vmm = true ;
84
+ static constexpr size_t helper_vmm_idx = 1 ;
85
+ static constexpr size_t tail_size = 0 ;
86
+ static constexpr bool use_exact_tail_scalar_bcast = false ;
87
+ const binary_injector::rhs_arg_static_params_t rhs_sp {
88
+ helper_vmm_idx, r13, r14, r15, preserve_gpr,
89
+ preserve_vmm, PARAM_OFF (post_ops_binary_rhs_arg_vec),
90
+ PARAM_OFF (dst_orig), memory_desc_wrapper (pd->dst_md ()),
91
+ tail_size, kreg_rem_mask_short, use_exact_tail_scalar_bcast};
92
+ #undef PARAM_OFF
93
+ const binary_injector::static_params_t bsp {this ->reg_param_bak , rhs_sp};
94
+ jit_binary_injector_ = utils::make_unique<
95
+ binary_injector::jit_uni_binary_injector_t <isa>>(
96
+ this , bsp);
97
+ }
75
98
if (post_ops_.len () > 0 && !only_eltwise_or_sum) {
76
99
vreg_d_weights = Vmm (idx_compute_vreg_max_--);
77
100
vreg_d_bias = Vmm (idx_compute_vreg_max_--);
@@ -133,6 +156,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
133
156
args.dst = dst
134
157
+ (os_offset * dst_os_stride_ + oc_offset)
135
158
* dst_data_type_size_;
159
+ args.dst_orig = dst_orig;
136
160
args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
137
161
args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
138
162
args.sum_scale = sum_scale_;
@@ -149,6 +173,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
149
173
150
174
struct ker_args_t {
151
175
char *dst;
176
+ const void * dst_orig;
152
177
const acc_data_t *acc;
153
178
const char *bias;
154
179
const float *scales;
@@ -163,11 +188,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
163
188
164
189
nstl::vector<jit_uni_eltwise_injector_f32<isa> *> jit_eltwise_injectors_;
165
190
nstl::vector<jit_uni_depthwise_injector_f32<isa> *> jit_depthwise_injectors_;
191
+ std::unique_ptr<binary_injector::jit_uni_binary_injector_t <isa>>
192
+ jit_binary_injector_;
166
193
167
194
using Vmm = typename cpu_isa_traits<isa>::Vmm;
168
195
static const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof (float );
169
196
170
197
Xbyak::Reg64 reg_param = abi_param1;
198
+ Xbyak::Reg64 reg_param_bak = r11;
171
199
Xbyak::Reg64 reg_dst = rdx;
172
200
Xbyak::Reg64 reg_acc = rax;
173
201
Xbyak::Reg64 reg_bias = rbx;
@@ -264,6 +292,7 @@ void jit_pp_ker_t<isa>::generate() {
264
292
}
265
293
}
266
294
295
+ mov (reg_param_bak, reg_param);
267
296
mov (reg_dst, ptr[reg_param + PARAM_OFF (dst)]);
268
297
mov (reg_acc, ptr[reg_param + PARAM_OFF (acc)]);
269
298
mov (reg_bias, ptr[reg_param + PARAM_OFF (bias)]);
@@ -285,10 +314,11 @@ void jit_pp_ker_t<isa>::generate() {
285
314
if (utils::one_of (isa, avx2, sse41))
286
315
mov (reg_table, l_table);
287
316
288
- auto apply_post_ops = [&](size_t offset, int idx) {
317
+ auto apply_post_ops = [&](size_t offset, int idx, bool apply_mask ) {
289
318
std::size_t post_ops_data_offset = 0 ;
290
319
int eltwise_inj_idx = 0 ;
291
320
int depthwise_inj_idx = 0 ;
321
+ int binary_inj_idx = 0 ;
292
322
for (int i = 0 ; i < post_ops_.len (); i++) {
293
323
auto &post_op = post_ops_.entry_ [i];
294
324
if (post_op.is_sum ()) {
@@ -312,6 +342,18 @@ void jit_pp_ker_t<isa>::generate() {
312
342
uni_vcvtdq2ps (vreg_prev_dst (idx), vreg_prev_dst (idx));
313
343
314
344
uni_vfmadd231ps (vreg_dst (idx), vreg_prev_dst (idx), vreg_sum_scale);
345
+ } else if (post_op.is_binary ()) {
346
+ binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
347
+ auto dst_addr = ptr[reg_dst + offset * dst_data_type_size_];
348
+ rhs_arg_params.vmm_idx_to_out_addr .emplace (idx, dst_addr);
349
+ rhs_arg_params.vmm_idx_to_out_elem_off_val .emplace (
350
+ idx, 0 * sizeof (float ));
351
+ if (mayiuse (avx512_core) && apply_mask)
352
+ rhs_arg_params.vmm_tail_idx_ .emplace (idx);
353
+ jit_binary_injector_->compute_vector (
354
+ idx, binary_inj_idx, post_op, rhs_arg_params);
355
+
356
+ binary_inj_idx++;
315
357
} else if (post_op.is_eltwise ()) {
316
358
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector_range (vreg_dst (idx).getIdx (),
317
359
vreg_dst (idx).getIdx () + 1 );
@@ -329,6 +371,7 @@ void jit_pp_ker_t<isa>::generate() {
329
371
330
372
post_ops_data_offset += jit_depthwise_injectors_[depthwise_inj_idx]->memoryStep ();
331
373
depthwise_inj_idx++;
374
+ binary_inj_idx++;
332
375
} else if (post_op.is_quantization ()) {
333
376
add (reg_oc_offset, reg_g_offset);
334
377
bool do_dequantization = post_op.quantization .alg == alg_kind::quantization_quantize_dequantize;
@@ -403,6 +446,7 @@ void jit_pp_ker_t<isa>::generate() {
403
446
sub (reg_oc_offset, reg_g_offset);
404
447
405
448
post_ops_data_offset += sizeof (float *);
449
+ binary_inj_idx++;
406
450
}
407
451
}
408
452
};
@@ -488,7 +532,7 @@ void jit_pp_ker_t<isa>::generate() {
488
532
if (do_scale_)
489
533
uni_vmulps (vreg_dst (idx), vreg_dst (idx), vreg_scale);
490
534
491
- apply_post_ops (offset, idx);
535
+ apply_post_ops (offset, idx, apply_mask );
492
536
493
537
if (dst_data_type_ != data_type::f32) {
494
538
if (isa == avx512_core) {
0 commit comments