@@ -207,10 +207,25 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_filter_unrolled(
207
207
L (iter_exit_label);
208
208
}
209
209
210
+ template <typename F>
211
+ static void iterate(const int ur_ch_blocks, const int ur_w,
212
+ const bool mask_tail, const F &f) {
213
+ for (int ch = 0 ; ch < ur_ch_blocks; ch++) {
214
+ const bool mask_flag = mask_tail && ch + 1 == ur_ch_blocks;
215
+ for (int ow = 0 ; ow < ur_w; ow++)
216
+ f (ch, ow, mask_flag);
217
+ }
218
+ }
219
+ template <typename F>
220
+ static void iterate(const int ur_ch_blocks, const int ur_w, const F &f) {
221
+ iterate(ur_ch_blocks, ur_w, false , f);
222
+ }
223
+
210
224
void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_postprocess (
211
- int ur_ch_blocks, int ur_w) {
225
+ int ur_ch_blocks, int ur_w, bool last_ch_block_flag ) {
212
226
int eltwise_inj_idx = 0 ;
213
227
int depthwise_inj_idx = 0 ;
228
+ int binary_inj_idx = 0 ;
214
229
std::size_t post_ops_data_offset = 0 ;
215
230
const auto & p = attr_.post_ops_ ;
216
231
@@ -222,6 +237,63 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_postprocess(
222
237
223
238
eltwise_injectors[eltwise_inj_idx]->compute_vector_range (start_idx, end_idx);
224
239
eltwise_inj_idx++;
240
+ } else if (post_op.is_binary ()) {
241
+ injector_utils::vmm_index_set_t vmm_idxs;
242
+ binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
243
+ rhs_arg_params_tail;
244
+
245
+ const auto dst_layout_nxc = is_dst_layout_nxc ();
246
+ // width per output channel block
247
+ const auto ch_blk = jcp.ch_block ;
248
+ // ncx: next output channel block stride
249
+ const auto ocb_stride
250
+ = dst_layout_nxc ? ch_blk : jcp.od * jcp.oh * jcp.ow * ch_blk;
251
+ // ncx: next w inside a output channel block
252
+ const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
253
+ // ncx: if has tail
254
+ const auto mask_tail_blocked_layout
255
+ = jcp.oc_without_padding % jcp.ch_block && !dst_layout_nxc;
256
+ // tail value
257
+ const auto mask_tail = jcp.oc_without_padding % jcp.ch_block ;
258
+
259
+ iterate(ur_ch_blocks, ur_w, mask_tail,
260
+ [&](int ch, int ow, int mask_flag) {
261
+ const size_t aux_output_l_off = jcp.typesize_out
262
+ * (ch * ocb_stride + ow * ow_stride);
263
+ const auto vmm_idx = get_acc_reg (ch * ur_w + ow).getIdx ();
264
+ vmm_idxs.emplace (vmm_idx);
265
+
266
+ rhs_arg_params_tail.vmm_idx_to_out_reg .emplace (
267
+ vmm_idx, reg_output);
268
+ rhs_arg_params_tail.vmm_idx_to_out_elem_off_val .emplace (
269
+ vmm_idx, aux_output_l_off);
270
+ if (mask_flag)
271
+ rhs_arg_params_tail.vmm_tail_idx_ .emplace (vmm_idx);
272
+ });
273
+ rhs_arg_params = rhs_arg_params_tail;
274
+ rhs_arg_params.vmm_tail_idx_ .clear ();
275
+
276
+ Label postops_done;
277
+ if (mask_tail_blocked_layout) {
278
+ Label postops_no_tail;
279
+ push (aux_reg_blocks_offset);
280
+ mov (aux_reg_blocks_offset, ptr[param1 + GET_OFF (load_work)]);
281
+ cmp (aux_reg_blocks_offset, jcp.nb_ch_blocking * jcp.ch_block );
282
+ pop (aux_reg_blocks_offset);
283
+ jge (postops_no_tail, T_NEAR);
284
+ binary_injector->compute_vector_range (vmm_idxs,
285
+ binary_inj_idx, post_op, rhs_arg_params_tail);
286
+ jmp (postops_done, T_NEAR);
287
+ L (postops_no_tail);
288
+ } else if (last_ch_block_flag) {
289
+ binary_injector->compute_vector_range (vmm_idxs,
290
+ binary_inj_idx, post_op, rhs_arg_params_tail);
291
+ } else {
292
+ binary_injector->compute_vector_range (vmm_idxs,
293
+ binary_inj_idx, post_op, rhs_arg_params);
294
+ L (postops_done);
295
+ }
296
+ binary_inj_idx++;
225
297
} else if (post_op.is_depthwise ()) {
226
298
push (aux_reg_blocks_offset);
227
299
base_post_ops_data_offset += reg64_size;
@@ -244,6 +316,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_postprocess(
244
316
245
317
post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep ();
246
318
depthwise_inj_idx++;
319
+ binary_inj_idx++;
247
320
}
248
321
}
249
322
}
@@ -373,7 +446,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::compute_loop(int ur_w, int ur_ch_b
373
446
} else {
374
447
apply_filter_unrolled (ur_ch_blocks, ur_w, last_ch_block_flag);
375
448
}
376
- apply_postprocess (ur_ch_blocks, ur_w);
449
+ apply_postprocess (ur_ch_blocks, ur_w, last_ch_block_flag );
377
450
store_dst (ur_ch_blocks, ur_w, last_ch_block_flag);
378
451
};
379
452
@@ -492,21 +565,41 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::loop_ow(int ur_ch_blocks) {
492
565
493
566
void jit_avx512_fork_dw_conv_fwd_kernel_bf16::generate () {
494
567
const auto & p = attr_.post_ops_ ;
568
+ bool with_binary = false ;
495
569
for (int i = 0 ; i < p.len (); i++) {
496
570
auto & post_op = p.entry_ [i];
497
571
if (post_op.is_eltwise ()) {
498
572
eltwise_injectors.push_back (new jit_uni_eltwise_injector<avx512_core>(
499
573
this ,
500
574
post_op.eltwise
501
575
));
576
+ } else if (post_op.is_binary ()) {
577
+ with_binary = true ;
502
578
} else if (post_op.is_depthwise ()) {
503
579
depthwise_injectors.push_back (new jit_uni_depthwise_injector_f32<avx512_core>(
504
580
this ,
505
581
post_op
506
582
));
507
583
}
508
584
}
509
-
585
+ if (with_binary) {
586
+ static constexpr bool preserve_gpr = true ;
587
+ // need to tune if no need to preserve it
588
+ static constexpr bool preserve_vmm = true ;
589
+ static constexpr size_t helper_vmm_idx = 31 ;
590
+ const size_t tail_size = jcp.oc_without_padding
591
+ % (cpu_isa_traits<avx512_core>::vlen / sizeof (float ));
592
+ static constexpr bool use_exact_tail_scalar_bcast = false ;
593
+ const binary_injector::rhs_arg_static_params_t rhs_sp {
594
+ helper_vmm_idx, r10, r11, r12, preserve_gpr,
595
+ preserve_vmm, GET_OFF (post_ops_binary_rhs_arg_vec),
596
+ GET_OFF (dst_orig), memory_desc_wrapper (&dst_md_),
597
+ tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast};
598
+ const binary_injector::static_params_t bsp {this ->param1 , rhs_sp};
599
+ binary_injector = utils::make_unique<
600
+ binary_injector::jit_uni_binary_injector_t <avx512_core>>(
601
+ this , bsp);
602
+ }
510
603
this ->preamble ();
511
604
512
605
std::size_t post_ops_pointers_count = 0 ;
0 commit comments