Skip to content

Commit 637eb68

Browse files
luo-cheng2021, azhai219
authored and committed Nov 29, 2024
[FORK][FEATURE] avx512 fork bf16 dw support binary postops
1 parent 40107a6 commit 637eb68

2 files changed

+102
-5
lines changed
 

‎src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp

+96-3
Original file line number | Diff line number | Diff line change
@@ -207,10 +207,25 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_filter_unrolled(
207207
L(iter_exit_label);
208208
}
209209

210+
/// Walks every (channel block, output column) pair of an unrolled tile and
/// calls `f(ch, ow, mask_flag)` for each one, channel blocks outermost.
///
/// @param ur_ch_blocks number of channel blocks to visit
/// @param ur_w         number of output columns to visit per channel block
/// @param mask_tail    when true, `mask_flag` is raised for the last channel
///                     block (and only for it); when false it is never raised
/// @param f            callable invoked as `f(int ch, int ow, bool mask_flag)`
template <typename F>
static void iterate(const int ur_ch_blocks, const int ur_w,
        const bool mask_tail, const F &f) {
    for (int ch_blk_idx = 0; ch_blk_idx < ur_ch_blocks; ++ch_blk_idx) {
        // Only the final channel block can carry a tail.
        const bool is_last_block = (ur_ch_blocks - ch_blk_idx == 1);
        const bool mask_flag = mask_tail ? is_last_block : false;
        int w_idx = 0;
        while (w_idx < ur_w) {
            f(ch_blk_idx, w_idx, mask_flag);
            ++w_idx;
        }
    }
}

/// Convenience overload: same traversal as above with tail masking disabled,
/// so `f` always receives `mask_flag == false`.
template <typename F>
static void iterate(const int ur_ch_blocks, const int ur_w, const F &f) {
    constexpr bool no_tail_mask = false;
    iterate(ur_ch_blocks, ur_w, no_tail_mask, f);
}
223+
210224
void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_postprocess(
211-
int ur_ch_blocks, int ur_w) {
225+
int ur_ch_blocks, int ur_w, bool last_ch_block_flag) {
212226
int eltwise_inj_idx = 0;
213227
int depthwise_inj_idx = 0;
228+
int binary_inj_idx = 0;
214229
std::size_t post_ops_data_offset = 0;
215230
const auto& p = attr_.post_ops_;
216231

@@ -222,6 +237,63 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_postprocess(
222237

223238
eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, end_idx);
224239
eltwise_inj_idx++;
240+
} else if (post_op.is_binary()) {
241+
injector_utils::vmm_index_set_t vmm_idxs;
242+
binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
243+
rhs_arg_params_tail;
244+
245+
const auto dst_layout_nxc = is_dst_layout_nxc();
246+
// width per output channel block
247+
const auto ch_blk = jcp.ch_block;
248+
// ncx: next output channel block stride
249+
const auto ocb_stride
250+
= dst_layout_nxc ? ch_blk : jcp.od * jcp.oh * jcp.ow * ch_blk;
251+
// ncx: next w inside a output channel block
252+
const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
253+
// ncx: if has tail
254+
const auto mask_tail_blocked_layout
255+
= jcp.oc_without_padding % jcp.ch_block && !dst_layout_nxc;
256+
// tail value
257+
const auto mask_tail = jcp.oc_without_padding % jcp.ch_block;
258+
259+
iterate(ur_ch_blocks, ur_w, mask_tail,
260+
[&](int ch, int ow, int mask_flag) {
261+
const size_t aux_output_l_off = jcp.typesize_out
262+
* (ch * ocb_stride + ow * ow_stride);
263+
const auto vmm_idx = get_acc_reg(ch * ur_w + ow).getIdx();
264+
vmm_idxs.emplace(vmm_idx);
265+
266+
rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
267+
vmm_idx, reg_output);
268+
rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
269+
vmm_idx, aux_output_l_off);
270+
if (mask_flag)
271+
rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
272+
});
273+
rhs_arg_params = rhs_arg_params_tail;
274+
rhs_arg_params.vmm_tail_idx_.clear();
275+
276+
Label postops_done;
277+
if (mask_tail_blocked_layout) {
278+
Label postops_no_tail;
279+
push(aux_reg_blocks_offset);
280+
mov(aux_reg_blocks_offset, ptr[param1 + GET_OFF(load_work)]);
281+
cmp(aux_reg_blocks_offset, jcp.nb_ch_blocking * jcp.ch_block);
282+
pop(aux_reg_blocks_offset);
283+
jge(postops_no_tail, T_NEAR);
284+
binary_injector->compute_vector_range(vmm_idxs,
285+
binary_inj_idx, post_op, rhs_arg_params_tail);
286+
jmp(postops_done, T_NEAR);
287+
L(postops_no_tail);
288+
} else if (last_ch_block_flag) {
289+
binary_injector->compute_vector_range(vmm_idxs,
290+
binary_inj_idx, post_op, rhs_arg_params_tail);
291+
} else {
292+
binary_injector->compute_vector_range(vmm_idxs,
293+
binary_inj_idx, post_op, rhs_arg_params);
294+
L(postops_done);
295+
}
296+
binary_inj_idx++;
225297
} else if (post_op.is_depthwise()) {
226298
push(aux_reg_blocks_offset);
227299
base_post_ops_data_offset += reg64_size;
@@ -244,6 +316,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::apply_postprocess(
244316

245317
post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep();
246318
depthwise_inj_idx++;
319+
binary_inj_idx++;
247320
}
248321
}
249322
}
@@ -373,7 +446,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::compute_loop(int ur_w, int ur_ch_b
373446
} else {
374447
apply_filter_unrolled(ur_ch_blocks, ur_w, last_ch_block_flag);
375448
}
376-
apply_postprocess(ur_ch_blocks, ur_w);
449+
apply_postprocess(ur_ch_blocks, ur_w, last_ch_block_flag);
377450
store_dst(ur_ch_blocks, ur_w, last_ch_block_flag);
378451
};
379452

@@ -492,21 +565,41 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::loop_ow(int ur_ch_blocks) {
492565

493566
void jit_avx512_fork_dw_conv_fwd_kernel_bf16::generate() {
494567
const auto& p = attr_.post_ops_;
568+
bool with_binary = false;
495569
for (int i = 0; i < p.len(); i++) {
496570
auto& post_op = p.entry_[i];
497571
if (post_op.is_eltwise()) {
498572
eltwise_injectors.push_back(new jit_uni_eltwise_injector<avx512_core>(
499573
this,
500574
post_op.eltwise
501575
));
576+
} else if (post_op.is_binary()) {
577+
with_binary = true;
502578
} else if (post_op.is_depthwise()) {
503579
depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_core>(
504580
this,
505581
post_op
506582
));
507583
}
508584
}
509-
585+
if (with_binary) {
586+
static constexpr bool preserve_gpr = true;
587+
// need to tune if no need to preserve it
588+
static constexpr bool preserve_vmm = true;
589+
static constexpr size_t helper_vmm_idx = 31;
590+
const size_t tail_size = jcp.oc_without_padding
591+
% (cpu_isa_traits<avx512_core>::vlen / sizeof(float));
592+
static constexpr bool use_exact_tail_scalar_bcast = false;
593+
const binary_injector::rhs_arg_static_params_t rhs_sp {
594+
helper_vmm_idx, r10, r11, r12, preserve_gpr,
595+
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
596+
GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_),
597+
tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast};
598+
const binary_injector::static_params_t bsp {this->param1, rhs_sp};
599+
binary_injector = utils::make_unique<
600+
binary_injector::jit_uni_binary_injector_t<avx512_core>>(
601+
this, bsp);
602+
}
510603
this->preamble();
511604

512605
std::size_t post_ops_pointers_count = 0;

‎src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.hpp

+6-2
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626

2727
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
2828
#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp"
29+
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
2930

3031
namespace dnnl {
3132
namespace impl {
@@ -36,7 +37,7 @@ struct jit_avx512_fork_dw_conv_fwd_kernel_bf16 : public jit_generator {
3637
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_fork_dw_conv_fwd_kernel_bf16)
3738

3839
jit_avx512_fork_dw_conv_fwd_kernel_bf16(const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t& attr)
39-
: jit_generator(jit_name()), jcp(ajcp), attr_(attr), bf16_emu_(nullptr) {
40+
: jit_generator(jit_name()), jcp(ajcp), attr_(attr), bf16_emu_(nullptr), dst_md_(dst_md) {
4041
if (!isa_has_bf16(jcp.isa))
4142
bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1,
4243
bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
@@ -125,14 +126,17 @@ struct jit_avx512_fork_dw_conv_fwd_kernel_bf16 : public jit_generator {
125126
inline void compute_loop(int ur_w, int ur_ch_blocks);
126127
inline void apply_filter(int ur_ch_blocks, int ur_w, bool last_ch_block_flag);
127128
inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w, bool last_ch_block_flag);
128-
inline void apply_postprocess(int ur_ch_blocks, int ur_w);
129+
inline void apply_postprocess(int ur_ch_blocks, int ur_w, bool last_ch_block_flag);
129130
inline void store_dst(int ur_ch_blocks, int ur_w, bool last_ch_block_flag);
130131
inline void loop_ow(int ur_ch_blocks);
131132

132133
nstl::vector<jit_uni_eltwise_injector<avx512_core>*> eltwise_injectors;
133134
nstl::vector<jit_uni_depthwise_injector_f32<avx512_core>*> depthwise_injectors;
135+
std::unique_ptr<binary_injector::jit_uni_binary_injector_t<avx512_core>>
136+
binary_injector;
134137

135138
bf16_emulation_t *bf16_emu_;
139+
memory_desc_t dst_md_;
136140

137141
void generate() override;
138142
};

0 commit comments

Comments
 (0)
Please sign in to comment.