18
18
#include < functional>
19
19
20
20
#include " cpu/x64/injectors/jit_uni_postops_injector.hpp"
21
+ #include " cpu/x64/injectors/jit_uni_binary_injector.hpp"
21
22
#include " cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
22
23
#include " cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp"
23
24
#include " cpu/x64/jit_generator.hpp"
@@ -60,10 +61,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
60
61
vreg_zero = Vmm (idx_compute_vreg_start_++);
61
62
}
62
63
bool only_eltwise_or_sum = true ;
64
+ bool with_binary = false ;
63
65
for (int idx = 0 ; idx < post_ops_.len (); ++idx) {
64
66
const auto &e = post_ops_.entry_ [idx];
65
67
if (e.is_eltwise (true )) {
66
68
do_eltwise_ = true ;
69
+ } else if (e.is_binary ()) {
70
+ with_binary = true ;
71
+ only_eltwise_or_sum = false ;
67
72
} else if (e.is_sum ()) {
68
73
do_sum_ = true ;
69
74
sum_scale_ = e.sum .scale ;
@@ -72,6 +77,24 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
72
77
only_eltwise_or_sum = false ;
73
78
}
74
79
}
80
+ if (with_binary) {
81
+ #define PARAM_OFF (field ) offsetof(ker_args_t , field)
82
+ static constexpr bool preserve_gpr = true ;
83
+ static constexpr bool preserve_vmm = true ;
84
+ static constexpr size_t helper_vmm_idx = 1 ;
85
+ static constexpr size_t tail_size = 0 ;
86
+ static constexpr bool use_exact_tail_scalar_bcast = false ;
87
+ const binary_injector::rhs_arg_static_params_t rhs_sp {
88
+ helper_vmm_idx, r13, r14, preserve_gpr,
89
+ preserve_vmm, PARAM_OFF (post_ops_binary_rhs_arg_vec),
90
+ PARAM_OFF (dst_orig), memory_desc_wrapper (pd->dst_md ()),
91
+ tail_size, kreg_rem_mask_short, use_exact_tail_scalar_bcast};
92
+ #undef PARAM_OFF
93
+ const binary_injector::static_params_t bsp {this ->reg_param_bak , rhs_sp};
94
+ jit_binary_injector_ = utils::make_unique<
95
+ binary_injector::jit_uni_binary_injector_t <isa>>(
96
+ this , bsp);
97
+ }
75
98
if (post_ops_.len () > 0 && !only_eltwise_or_sum) {
76
99
vreg_d_weights = Vmm (idx_compute_vreg_max_--);
77
100
vreg_d_bias = Vmm (idx_compute_vreg_max_--);
@@ -120,7 +143,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
120
143
int g, size_t start, size_t end,
121
144
const zero_point_call_params_t &zp,
122
145
const void * post_ops_binary_rhs_arg_vec,
123
- const void * /* dst_orig */ , const exec_ctx_t &ctx,
146
+ const void * dst_orig, const exec_ctx_t &ctx,
124
147
const memory_desc_t &dst_md,
125
148
const single_gemm_conv_chunk_desc_t &chunk_desc) const override {
126
149
@@ -135,6 +158,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
135
158
args.dst = dst
136
159
+ (os_offset * dst_os_stride_ + oc_offset)
137
160
* dst_data_type_size_;
161
+ args.dst_orig = dst_orig;
138
162
args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
139
163
args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
140
164
args.sum_scale = sum_scale_;
@@ -151,6 +175,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
151
175
152
176
struct ker_args_t {
153
177
char *dst;
178
+ const void * dst_orig;
154
179
const acc_data_t *acc;
155
180
const char *bias;
156
181
const float *scales;
@@ -164,11 +189,14 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
164
189
165
190
nstl::vector<jit_uni_eltwise_injector_f32<isa> *> jit_eltwise_injectors_;
166
191
nstl::vector<jit_uni_depthwise_injector_f32<isa> *> jit_depthwise_injectors_;
192
+ std::unique_ptr<binary_injector::jit_uni_binary_injector_t <isa>>
193
+ jit_binary_injector_;
167
194
168
195
using Vmm = typename cpu_isa_traits<isa>::Vmm;
169
196
static const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof (float );
170
197
171
198
Xbyak::Reg64 reg_param = abi_param1;
199
+ Xbyak::Reg64 reg_param_bak = r11;
172
200
Xbyak::Reg64 reg_dst = rdx;
173
201
Xbyak::Reg64 reg_acc = rax;
174
202
Xbyak::Reg64 reg_bias = rbx;
@@ -265,6 +293,7 @@ void jit_pp_ker_t<isa>::generate() {
265
293
}
266
294
}
267
295
296
+ mov (reg_param_bak, reg_param);
268
297
mov (reg_dst, ptr[reg_param + PARAM_OFF (dst)]);
269
298
mov (reg_acc, ptr[reg_param + PARAM_OFF (acc)]);
270
299
mov (reg_bias, ptr[reg_param + PARAM_OFF (bias)]);
@@ -286,10 +315,11 @@ void jit_pp_ker_t<isa>::generate() {
286
315
if (utils::one_of (isa, avx2, sse41))
287
316
mov (reg_table, l_table);
288
317
289
- auto apply_post_ops = [&](size_t offset, int idx) {
318
+ auto apply_post_ops = [&](size_t offset, int idx, bool apply_mask ) {
290
319
std::size_t post_ops_data_offset = 0 ;
291
320
int eltwise_inj_idx = 0 ;
292
321
int depthwise_inj_idx = 0 ;
322
+ int binary_inj_idx = 0 ;
293
323
for (int i = 0 ; i < post_ops_.len (); i++) {
294
324
auto &post_op = post_ops_.entry_ [i];
295
325
if (post_op.is_sum ()) {
@@ -313,6 +343,18 @@ void jit_pp_ker_t<isa>::generate() {
313
343
uni_vcvtdq2ps (vreg_prev_dst (idx), vreg_prev_dst (idx));
314
344
315
345
uni_vfmadd231ps (vreg_dst (idx), vreg_prev_dst (idx), vreg_sum_scale);
346
+ } else if (post_op.is_binary ()) {
347
+ binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
348
+ auto dst_addr = ptr[reg_dst + offset * dst_data_type_size_];
349
+ rhs_arg_params.vmm_idx_to_out_addr .emplace (idx, dst_addr);
350
+ rhs_arg_params.vmm_idx_to_out_elem_off_val .emplace (
351
+ idx, 0 * sizeof (float ));
352
+ if (mayiuse (avx512_core) && apply_mask)
353
+ rhs_arg_params.vmm_tail_idx_ .emplace (idx);
354
+ jit_binary_injector_->compute_vector (
355
+ idx, binary_inj_idx, post_op, rhs_arg_params);
356
+
357
+ binary_inj_idx++;
316
358
} else if (post_op.is_eltwise ()) {
317
359
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector_range (vreg_dst (idx).getIdx (),
318
360
vreg_dst (idx).getIdx () + 1 );
@@ -330,6 +372,7 @@ void jit_pp_ker_t<isa>::generate() {
330
372
331
373
post_ops_data_offset += jit_depthwise_injectors_[depthwise_inj_idx]->memoryStep ();
332
374
depthwise_inj_idx++;
375
+ binary_inj_idx++;
333
376
} else if (post_op.is_quantization ()) {
334
377
add (reg_oc_offset, reg_g_offset);
335
378
bool do_dequantization = post_op.quantization .alg == alg_kind::quantization_quantize_dequantize;
@@ -404,6 +447,7 @@ void jit_pp_ker_t<isa>::generate() {
404
447
sub (reg_oc_offset, reg_g_offset);
405
448
406
449
post_ops_data_offset += sizeof (float *);
450
+ binary_inj_idx++;
407
451
}
408
452
}
409
453
};
@@ -489,7 +533,7 @@ void jit_pp_ker_t<isa>::generate() {
489
533
if (do_scale_)
490
534
uni_vmulps (vreg_dst (idx), vreg_dst (idx), vreg_scale);
491
535
492
- apply_post_ops (offset, idx);
536
+ apply_post_ops (offset, idx, apply_mask );
493
537
494
538
if (dst_data_type_ != data_type::f32) {
495
539
if (isa == avx512_core) {
0 commit comments