Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ad33f1c

Browse files
committedAug 19, 2024·
cpu:aarch64: Extend Arm SVE support for Depthwise Convolution Kernels
1 parent 3355b98 commit ad33f1c

7 files changed

+237
-205
lines changed
 

‎src/cpu/aarch64/cpu_reducer.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,11 @@ struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type, isa> {
217217
const int load_len[nbranches] = {vlen, vlen, typesize};
218218
Label loop_x_label[nbranches + 1];
219219

220-
this->ptrue(preg_all.b);
220+
switch (isa) {
221+
case sve_256: this->ptrue(preg_all.b, VL32); break;
222+
case sve_512: this->ptrue(preg_all.b, VL64); break;
223+
default: assert(!"Unsupported ISA"); break;
224+
}
221225
if (typesize == 4)
222226
this->ptrue(preg_one.s, VL1);
223227
else

‎src/cpu/aarch64/jit_uni_dw_conv_kernel_f32.cpp

+24-26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021-2022 Intel Corporation
3-
* Copyright 2021-2022 FUJITSU LIMITED
3+
* Copyright 2021-2024 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -43,25 +43,23 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) {
4343
const auto ch_blk = jcp.ch_block;
4444
const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
4545
const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
46-
4746
for (int ch = 0; ch < ur_ch_blocks; ch++) {
4847
for (int ow = 0; ow < ur_w; ow++) {
49-
ZReg zreg_acc = get_acc_reg(ch * ur_w + ow);
5048
ZRegS zregs_acc = get_acc_reg_s(ch * ur_w + ow);
5149

5250
int b_off = ch * ch_blk;
5351
if (this->jcp.with_bias) {
5452
add_imm(reg_tmp_addr, reg_bias, b_off * sizeof(float),
5553
reg_tmp_imm);
56-
ldr(zreg_acc, ptr(reg_tmp_addr));
54+
ld1w(zregs_acc, P_ALL_ONE, ptr(reg_tmp_addr));
5755
} else
5856
fmov(zregs_acc); // zero clear
5957

6058
int o_off = ch * ocb_stride + ow * ow_stride;
6159
if (this->jcp.with_sum) {
6260
add_imm(reg_tmp_addr, reg_output, o_off * sizeof(float),
6361
reg_tmp_imm);
64-
ldr(ZReg(0), ptr(reg_tmp_addr));
62+
ld1w(ZRegS(0), P_ALL_ONE, ptr(reg_tmp_addr));
6563
fadd(zregs_acc, zregs_acc, ZRegS(0));
6664
}
6765
}
@@ -96,15 +94,15 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
9694
ldr(aux_reg_input, ptr(aux_reg_input_buffer_ptr));
9795
add(aux_reg_input, aux_reg_input, reg_iw_offset);
9896
}
97+
9998
for (int ch = 0; ch < ur_ch_blocks; ch++) {
10099
for (int kw = 0; kw < jcp.kw; kw++) {
101100
int ker_off = ch * jcp.kh * jcp.kw * ch_blk + kw * ch_blk;
102101

103-
ZReg zreg_ker = get_ker_reg(0);
104102
ZRegS zregs_ker = get_ker_reg_s(0);
105103
add_imm(reg_tmp_addr, aux_reg_kernel, ker_off * sizeof(float),
106104
reg_tmp_imm);
107-
ldr(zreg_ker, ptr(reg_tmp_addr));
105+
ld1w(zregs_ker, P_ALL_ONE, ptr(reg_tmp_addr));
108106

109107
int ow_start = get_ow_start(kw, pad_l);
110108
int ow_end = get_ow_end(ur_w, kw, pad_r);
@@ -113,11 +111,10 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
113111
+ (ow * stride_w - pad_l) * iw_stride
114112
+ kw * dilate_w * iw_stride;
115113

116-
ZReg zreg_src = get_src_reg(0);
117114
ZRegS zregs_src = get_src_reg_s(0);
118115
add_imm(reg_tmp_addr, aux_reg_input,
119116
inp_off * jcp.typesize_in, reg_tmp_imm);
120-
ldr(zreg_src, ptr(reg_tmp_addr));
117+
ld1w(zregs_src, P_ALL_ONE, ptr(reg_tmp_addr));
121118

122119
ZRegS zregs_acc = get_acc_reg_s(ch * ur_w + ow);
123120
fmla(zregs_acc, P_ALL_ONE, zregs_src, zregs_ker);
@@ -164,11 +161,11 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
164161
for (int ow = 0; ow < ur_w; ow++) {
165162
const int o_off = ch * ocb_stride + ow * ow_stride;
166163

167-
ZReg zreg_dst = get_acc_reg(ch * ur_w + ow);
164+
ZRegS zregS_dst = get_acc_reg_s(ch * ur_w + ow);
168165

169166
add_imm(reg_tmp_addr, reg_output, o_off * sizeof(float),
170167
reg_tmp_imm);
171-
str(zreg_dst, ptr(reg_tmp_addr));
168+
st1w(zregS_dst, P_ALL_ONE, ptr(reg_tmp_addr));
172169
}
173170
}
174171
}
@@ -322,8 +319,11 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::ow_loop(int ur_ch_blocks) {
322319

323320
template <cpu_isa_t isa>
324321
void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
322+
const int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
325323
this->preamble();
326-
324+
//TO DO : renaming predicate register (P_ALL_ONE)
325+
if (simd_w_ != cpu_sveLen / sizeof(float))
326+
set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1);
327327
if (jcp.is_fused_conv) {
328328
ldr(reg_input_buffer_ptr, ptr(abi_param1, GET_OFF(src)));
329329
mov(reg_iw_offset, 0);
@@ -366,6 +366,7 @@ void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
366366
}
367367

368368
template struct jit_uni_dw_conv_fwd_kernel_f32<sve_512>;
369+
template struct jit_uni_dw_conv_fwd_kernel_f32<sve_256>;
369370

370371
template <cpu_isa_t isa>
371372
inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
@@ -412,21 +413,19 @@ inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
412413
for (int ch = 0; ch < ur_ch_blocks;
413414
ch++) { // unrolloing channel blocks
414415
int ker_off = ch * kh * kw * ch_blk;
415-
ZReg zreg_ker = get_ker_reg(0);
416416
ZRegS zregs_ker = get_ker_reg_s(0);
417417

418418
add_imm(reg_tmp_addr, aux1_reg_kernel, ker_off * sizeof(float),
419419
reg_tmp_imm);
420-
ldr(zreg_ker, ptr(reg_tmp_addr));
420+
ld1w(zregs_ker, P_ALL_ONE / T_z, ptr(reg_tmp_addr));
421421

422422
for (int w = 0; w < ur_str_w; w++) {
423423
int ddst_off = (ch * oh * ow + w) * ch_blk;
424424

425-
ZReg zreg_src = get_src_reg(0);
426425
ZRegS zregs_src = get_src_reg_s(0);
427426
add_imm(reg_tmp_addr, aux1_reg_ddst,
428427
ddst_off * sizeof(float), reg_tmp_imm);
429-
ldr(zreg_src, ptr(reg_tmp_addr));
428+
ld1w(zregs_src, P_ALL_ONE / T_z, ptr(reg_tmp_addr));
430429

431430
ZRegS zregs_acc = get_acc_reg_s(ch * ur_str_w + w);
432431
fmla(zregs_acc, P_ALL_ONE, zregs_src, zregs_ker);
@@ -465,11 +464,11 @@ inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
465464
for (int ch = 0; ch < ur_ch_blocks; ch++) {
466465
for (int w = 0; w < ur_str_w; w++) {
467466
int dsrc_off = (ch * ih * iw + w * stride_w) * ch_blk;
468-
ZReg zreg_acc = get_acc_reg(ch * ur_str_w + w);
467+
ZRegS zregs_acc = get_acc_reg_s(ch * ur_str_w + w);
469468

470469
add_imm(reg_tmp_addr, reg_dsrc, dsrc_off * sizeof(float),
471470
reg_tmp_imm);
472-
str(zreg_acc, ptr(reg_tmp_addr));
471+
st1w(zregs_acc, P_ALL_ONE / T_z, ptr(reg_tmp_addr));
473472
}
474473
}
475474
}
@@ -569,6 +568,7 @@ void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
569568
}
570569

571570
template struct jit_uni_dw_conv_bwd_data_kernel_f32<sve_512>;
571+
template struct jit_uni_dw_conv_bwd_data_kernel_f32<sve_256>;
572572

573573
template <cpu_isa_t isa>
574574
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() {
@@ -1096,15 +1096,12 @@ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
10961096

10971097
template <cpu_isa_t isa>
10981098
void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
1099+
const int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
10991100
preamble();
1100-
1101-
if (simd_w == 16)
1102-
ptrue(P_ALL_ONE.b);
1103-
else if (simd_w == 8)
1104-
ptrue(P_ALL_ONE.b, VL32);
1105-
else
1106-
assert(!"Unsupport: simd_w != 16, 8");
1107-
1101+
//TO DO : renaming predicate register (P_ALL_ONE)
1102+
if (simd_w_ != cpu_sveLen / sizeof(float))
1103+
set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1);
1104+
if (simd_w_ != 16 || simd_w_ != 8) assert(!"Unsupport: simd_w != 16, 8");
11081105
ldr(reg_input_baddr,
11091106
ptr(abi_param1,
11101107
static_cast<int32_t>(offsetof(jit_dw_conv_call_s, input))));
@@ -1123,6 +1120,7 @@ void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
11231120
}
11241121

11251122
template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sve_512>;
1123+
template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sve_256>;
11261124

11271125
} // namespace aarch64
11281126
} // namespace cpu

‎src/cpu/aarch64/jit_uni_dw_conv_kernel_f32.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021-2022 Intel Corporation
3-
* Copyright 2021-2022 FUJITSU LIMITED
3+
* Copyright 2021-2024 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -40,8 +40,8 @@ struct jit_uni_dw_conv_fwd_kernel_f32 : public jit_generator {
4040
jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp)
4141
: jcp(ajcp), eltwise_injector_(nullptr) {
4242
if (jcp.with_eltwise)
43-
eltwise_injector_ = new jit_uni_eltwise_injector_f32<sve_512>(
44-
this, jcp.eltwise);
43+
eltwise_injector_
44+
= new jit_uni_eltwise_injector_f32<isa>(this, jcp.eltwise);
4545
}
4646

4747
~jit_uni_dw_conv_fwd_kernel_f32() { delete eltwise_injector_; }
@@ -133,7 +133,7 @@ struct jit_uni_dw_conv_fwd_kernel_f32 : public jit_generator {
133133
format_tag::nwc);
134134
}
135135

136-
jit_uni_eltwise_injector_f32<sve_512> *eltwise_injector_;
136+
jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
137137
void generate() override;
138138
};
139139

‎src/cpu/aarch64/jit_uni_dw_conv_kernel_utils.hpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021-2022 Intel Corporation
3-
* Copyright 2021-2022 FUJITSU LIMITED
3+
* Copyright 2021-2024 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -107,8 +107,6 @@ status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
107107
const auto wei_tag = isa == sve_512 ? Goihw16g : Goihw8g;
108108
const auto nxc_tag = nhwc;
109109
jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
110-
if ((blocked_tag != nChw16c) || (wei_tag != Goihw16g))
111-
return status::unimplemented;
112110

113111
if (src_d.format_kind() == format_kind::any) {
114112
CHECK(memory_desc_init_by_tag(src_md, blocked_tag));
@@ -146,8 +144,6 @@ status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
146144
if (!mayiuse(isa)) return status::unimplemented;
147145

148146
const int simd_w = isa == sve_512 ? 16 : 8;
149-
if (simd_w != 16) return status::unimplemented;
150-
151147
jcp.prop_kind = cd.prop_kind;
152148

153149
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
@@ -258,7 +254,7 @@ status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
258254
if (dst_d.data_type() == data_type::s32) return status::unimplemented;
259255
}
260256
bool ok_to_pad_channels = true && jcp.oc == jcp.ngroups
261-
&& jcp.ic == jcp.ngroups && isa == sve_512;
257+
&& jcp.ic == jcp.ngroups && (isa == sve_256 || isa == sve_512);
262258
if (ok_to_pad_channels) {
263259
jcp.oc = rnd_up(jcp.oc, simd_w);
264260
jcp.ic = rnd_up(jcp.oc, simd_w);
@@ -286,6 +282,7 @@ void jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_scratchpad(
286282
}
287283

288284
template struct jit_uni_dw_conv_fwd_kernel<sve_512, data_type::f32>;
285+
template struct jit_uni_dw_conv_fwd_kernel<sve_256, data_type::f32>;
289286

290287
template <cpu_isa_t isa, data_type_t kernel_dt>
291288
struct jit_uni_dw_conv_bwd_data_kernel {
@@ -372,7 +369,7 @@ status_t jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_conf(
372369
jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
373370

374371
bool ok_to_pad_channels = true && jcp.oc == jcp.ngroups
375-
&& jcp.ic == jcp.ngroups && isa == sve_512;
372+
&& jcp.ic == jcp.ngroups && (isa == sve_256 || isa == sve_512);
376373
if (ok_to_pad_channels) {
377374
jcp.oc = rnd_up(jcp.oc, simd_w);
378375
jcp.ic = rnd_up(jcp.oc, simd_w);
@@ -418,6 +415,7 @@ void jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_scratchpad(
418415
}
419416

420417
template struct jit_uni_dw_conv_bwd_data_kernel<sve_512, data_type::f32>;
418+
template struct jit_uni_dw_conv_bwd_data_kernel<sve_256, data_type::f32>;
421419

422420
template <cpu_isa_t isa, data_type_t kernel_dt>
423421
struct jit_uni_dw_conv_bwd_weights_kernel {
@@ -589,6 +587,7 @@ void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::balance(
589587
}
590588

591589
template struct jit_uni_dw_conv_bwd_weights_kernel<sve_512, data_type::f32>;
590+
template struct jit_uni_dw_conv_bwd_weights_kernel<sve_256, data_type::f32>;
592591
} // namespace aarch64
593592
} // namespace cpu
594593
} // namespace impl

‎src/cpu/aarch64/jit_uni_dw_convolution.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021 Intel Corporation
3-
* Copyright 2021 FUJITSU LIMITED
3+
* Copyright 2021-2024 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -160,6 +160,7 @@ void jit_uni_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward(
160160
}
161161

162162
template struct jit_uni_dw_convolution_fwd_t<sve_512, data_type::f32>;
163+
template struct jit_uni_dw_convolution_fwd_t<sve_256, data_type::f32>;
163164

164165
template <cpu_isa_t isa, data_type_t diff_dst_type, data_type_t diff_src_type>
165166
void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
@@ -269,6 +270,7 @@ void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
269270
}
270271

271272
template struct jit_uni_dw_convolution_bwd_data_t<sve_512, data_type::f32>;
273+
template struct jit_uni_dw_convolution_bwd_data_t<sve_256, data_type::f32>;
272274

273275
template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type>
274276
jit_uni_dw_convolution_bwd_weights_t<isa, src_type, diff_weights_type>::
@@ -404,7 +406,7 @@ void jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
404406
* this should be explored in the future if further optimizations are required.
405407
*/
406408
template <>
407-
void jit_uni_dw_convolution_bwd_weights_t<sve_512,
409+
void jit_uni_dw_convolution_bwd_weights_t<sve_256,
408410
data_type::bf16>::execute_reduction(const exec_ctx_t &ctx) const {
409411

410412
auto diff_wei_reduction_buf
@@ -527,7 +529,7 @@ void jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
527529
}
528530

529531
template struct jit_uni_dw_convolution_bwd_weights_t<sve_512, data_type::f32>;
530-
532+
template struct jit_uni_dw_convolution_bwd_weights_t<sve_256, data_type::f32>;
531533
} // namespace aarch64
532534
} // namespace cpu
533535
} // namespace impl

‎src/cpu/aarch64/jit_uni_dw_convolution.hpp

+187-161
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021 Intel Corporation
3-
* Copyright 2021 FUJITSU LIMITED
3+
* Copyright 2021-2024 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -95,180 +95,206 @@ struct jit_uni_dw_convolution_fwd_t : public primitive_t {
9595
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
9696

9797
std::unique_ptr<jit_uni_dw_conv_fwd_kernel<isa, src_type>> kernel_;
98-
};
99-
100-
using jit_sve_512_dw_convolution_fwd_t
101-
= jit_uni_dw_convolution_fwd_t<sve_512, data_type::f32>;
102-
103-
template <cpu_isa_t isa, data_type_t diff_dst_type,
104-
data_type_t diff_src_type = diff_dst_type>
105-
struct jit_uni_dw_convolution_bwd_data_t : public primitive_t {
106-
struct pd_t : public cpu_convolution_bwd_data_pd_t {
107-
pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
108-
const convolution_fwd_pd_t *hint_fwd_pd)
109-
: cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
110-
111-
DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""),
112-
jit_uni_dw_convolution_bwd_data_t);
113-
114-
status_t init(engine_t *engine) {
115-
bool ok = true && desc()->prop_kind == prop_kind::backward_data
116-
&& set_default_alg_kind(alg_kind::convolution_direct)
117-
&& expect_data_types(diff_src_type, diff_dst_type,
118-
data_type::undef, diff_dst_type, data_type::f32)
119-
&& attr()->has_default_values() && !has_zero_dim_memory()
120-
&& set_default_formats();
121-
122-
if (!ok) return status::unimplemented;
123-
124-
status_t status = jit_uni_dw_conv_bwd_data_kernel<isa,
125-
diff_dst_type>::init_conf(jcp_, *desc(), *diff_src_md(),
126-
*weights_md(), *diff_dst_md());
127-
if (status != status::success) return status;
128-
129-
auto scratchpad = scratchpad_registry().registrar();
130-
jit_uni_dw_conv_bwd_data_kernel<isa,
131-
diff_dst_type>::init_scratchpad(scratchpad, jcp_);
98+
using jit_sve_512_dw_convolution_fwd_t
99+
= jit_uni_dw_convolution_fwd_t<sve_512, data_type::f32>;
100+
using jit_sve_256_dw_convolution_fwd_t
101+
= jit_uni_dw_convolution_fwd_t<sve_256, data_type::f32>;
102+
103+
template <cpu_isa_t isa, data_type_t diff_dst_type,
104+
data_type_t diff_src_type = diff_dst_type>
105+
struct jit_uni_dw_convolution_bwd_data_t : public primitive_t {
106+
struct pd_t : public cpu_convolution_bwd_data_pd_t {
107+
pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
108+
const convolution_fwd_pd_t *hint_fwd_pd)
109+
: cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
110+
, jcp_() {}
111+
112+
DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""),
113+
jit_uni_dw_convolution_bwd_data_t);
114+
115+
status_t init(engine_t *engine) {
116+
bool ok = true && desc()->prop_kind == prop_kind::backward_data
117+
&& set_default_alg_kind(alg_kind::convolution_direct)
118+
&& expect_data_types(diff_src_type, diff_dst_type,
119+
data_type::undef, diff_dst_type, data_type::f32)
120+
&& attr()->has_default_values()
121+
&& !has_zero_dim_memory() && set_default_formats();
122+
123+
if (!ok) return status::unimplemented;
124+
125+
status_t status = jit_uni_dw_conv_bwd_data_kernel<isa,
126+
diff_dst_type>::init_conf(jcp_, *desc(), *diff_src_md(),
127+
*weights_md(), *diff_dst_md());
128+
if (status != status::success) return status;
129+
130+
auto scratchpad = scratchpad_registry().registrar();
131+
jit_uni_dw_conv_bwd_data_kernel<isa,
132+
diff_dst_type>::init_scratchpad(scratchpad, jcp_);
133+
134+
return status::success;
135+
}
136+
137+
jit_conv_conf_t jcp_;
138+
139+
protected:
140+
bool set_default_formats() {
141+
142+
using namespace format_tag;
143+
format_tag_t dat_tag, wei_tag;
144+
switch (isa) {
145+
case sve_512:
146+
dat_tag = nChw16c;
147+
wei_tag = Goihw16g;
148+
break;
149+
case sve_256:
150+
dat_tag = nChw8c;
151+
wei_tag = Goihw8g;
152+
break;
153+
default: return false;
154+
}
155+
return set_default_formats_common(dat_tag, wei_tag, dat_tag);
156+
}
157+
};
158+
159+
jit_uni_dw_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
160+
161+
typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
162+
typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
163+
typedef typename prec_traits<diff_dst_type>::type wei_data_t;
164+
165+
status_t init(engine_t *engine) override {
166+
CHECK(safe_ptr_assign(kernel_,
167+
new jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>(
168+
pd()->jcp_)));
169+
return kernel_->create_kernel();
170+
}
132171

172+
status_t execute(const exec_ctx_t &ctx) const override {
173+
execute_backward_data(ctx);
133174
return status::success;
134175
}
135176

136-
jit_conv_conf_t jcp_;
137-
138-
protected:
139-
bool set_default_formats() {
140-
using namespace format_tag;
177+
private:
178+
void execute_backward_data(const exec_ctx_t &ctx) const;
179+
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
141180

142-
auto dat_tag = nChw16c;
143-
auto wei_tag = Goihw16g;
144-
145-
return set_default_formats_common(dat_tag, wei_tag, dat_tag);
146-
}
181+
std::unique_ptr<jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>>
182+
kernel_;
147183
};
148184

149-
jit_uni_dw_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
150-
151-
typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
152-
typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
153-
typedef typename prec_traits<diff_dst_type>::type wei_data_t;
154-
155-
status_t init(engine_t *engine) override {
156-
CHECK(safe_ptr_assign(kernel_,
157-
new jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>(
158-
pd()->jcp_)));
159-
return kernel_->create_kernel();
160-
}
161-
162-
status_t execute(const exec_ctx_t &ctx) const override {
163-
execute_backward_data(ctx);
164-
return status::success;
165-
}
166-
167-
private:
168-
void execute_backward_data(const exec_ctx_t &ctx) const;
169-
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
170-
171-
std::unique_ptr<jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>>
172-
kernel_;
173-
};
174-
175-
using jit_sve_512_dw_convolution_bwd_data_t
176-
= jit_uni_dw_convolution_bwd_data_t<sve_512, data_type::f32>;
177-
178-
template <cpu_isa_t isa, data_type_t src_type,
179-
data_type_t diff_weights_type = src_type>
180-
struct jit_uni_dw_convolution_bwd_weights_t : public primitive_t {
181-
struct pd_t : public cpu_convolution_bwd_weights_pd_t {
182-
pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
183-
const convolution_fwd_pd_t *hint_fwd_pd)
184-
: cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
185-
, jcp_() {}
186-
using jit_uni_dw_convolution_bwd_weights
187-
= jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
188-
diff_weights_type>;
189-
DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""),
190-
jit_uni_dw_convolution_bwd_weights);
191-
192-
status_t init(engine_t *engine) {
193-
bool ok = true && desc()->prop_kind == prop_kind::backward_weights
194-
&& set_default_alg_kind(alg_kind::convolution_direct)
195-
&& expect_data_types(src_type, diff_weights_type,
196-
data_type::undef, src_type, data_type::f32)
197-
&& IMPLICATION(this->with_bias(),
198-
utils::one_of(
199-
this->desc()->diff_bias_desc.data_type,
200-
data_type::f32, data_type::bf16))
201-
&& attr()->has_default_values() && !has_zero_dim_memory()
202-
&& set_default_formats();
203-
if (!ok) return status::unimplemented;
204-
205-
const int max_threads
206-
= dnnl_in_parallel() ? 1 : dnnl_get_max_threads();
207-
208-
status_t status = jit_uni_dw_conv_bwd_weights_kernel<isa,
209-
src_type>::init_conf(jcp_, *desc(), *src_md(),
210-
*diff_weights_md(), *diff_dst_md(), max_threads);
211-
if (status != status::success) return status;
212-
213-
auto scratchpad = scratchpad_registry().registrar();
214-
jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>::init_scratchpad(
215-
scratchpad, jcp_);
216-
185+
using jit_sve_512_dw_convolution_bwd_data_t
186+
= jit_uni_dw_convolution_bwd_data_t<sve_512, data_type::f32>;
187+
using jit_sve_256_dw_convolution_bwd_data_t
188+
= jit_uni_dw_convolution_bwd_data_t<sve_256, data_type::f32>;
189+
190+
template <cpu_isa_t isa, data_type_t src_type,
191+
data_type_t diff_weights_type = src_type>
192+
struct jit_uni_dw_convolution_bwd_weights_t : public primitive_t {
193+
struct pd_t : public cpu_convolution_bwd_weights_pd_t {
194+
pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
195+
const convolution_fwd_pd_t *hint_fwd_pd)
196+
: cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
197+
, jcp_() {}
198+
using jit_uni_dw_convolution_bwd_weights
199+
= jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
200+
diff_weights_type>;
201+
DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""),
202+
jit_uni_dw_convolution_bwd_weights);
203+
204+
status_t init(engine_t *engine) {
205+
bool ok = true
206+
&& desc()->prop_kind == prop_kind::backward_weights
207+
&& set_default_alg_kind(alg_kind::convolution_direct)
208+
&& expect_data_types(src_type, diff_weights_type,
209+
data_type::undef, src_type, data_type::f32)
210+
&& IMPLICATION(this->with_bias(),
211+
utils::one_of(
212+
this->desc()->diff_bias_desc.data_type,
213+
data_type::f32, data_type::bf16))
214+
&& attr()->has_default_values()
215+
&& !has_zero_dim_memory() && set_default_formats();
216+
if (!ok) return status::unimplemented;
217+
218+
const int max_threads
219+
= dnnl_in_parallel() ? 1 : dnnl_get_max_threads();
220+
221+
status_t status = jit_uni_dw_conv_bwd_weights_kernel<isa,
222+
src_type>::init_conf(jcp_, *desc(), *src_md(),
223+
*diff_weights_md(), *diff_dst_md(), max_threads);
224+
if (status != status::success) return status;
225+
226+
auto scratchpad = scratchpad_registry().registrar();
227+
jit_uni_dw_conv_bwd_weights_kernel<isa,
228+
src_type>::init_scratchpad(scratchpad, jcp_);
229+
230+
return status::success;
231+
}
232+
233+
jit_conv_conf_t jcp_;
234+
235+
protected:
236+
bool set_default_formats() {
237+
using namespace format_tag;
238+
format_tag_t dat_tag, wei_tag;
239+
switch (isa) {
240+
case sve_512:
241+
dat_tag = nChw16c;
242+
wei_tag = Goihw16g;
243+
break;
244+
case sve_256:
245+
dat_tag = nChw8c;
246+
wei_tag = Goihw8g;
247+
break;
248+
default: return false;
249+
}
250+
251+
return set_default_formats_common(dat_tag, wei_tag, dat_tag);
252+
}
253+
};
254+
255+
jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd);
256+
257+
typedef typename prec_traits<data_type::bf16>::type bf16_data_t;
258+
typedef typename prec_traits<data_type::f32>::type f32_data_t;
259+
typedef typename prec_traits<src_type>::type src_data_t;
260+
typedef typename prec_traits<src_type>::type diff_dst_data_t;
261+
typedef typename prec_traits<diff_weights_type>::type
262+
diff_weights_data_t;
263+
264+
status_t init(engine_t *engine) override {
265+
CHECK(safe_ptr_assign(kernel_,
266+
new jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>(
267+
pd()->jcp_)));
268+
CHECK(kernel_->create_kernel());
269+
270+
if (pd()->jcp_.nthr_mb > 1) {
271+
CHECK(safe_ptr_assign(acc_ker_,
272+
new cpu_accumulator_1d_t<data_type::f32, isa>()));
273+
CHECK(acc_ker_->create_kernel());
274+
}
217275
return status::success;
218276
}
219277

220-
jit_conv_conf_t jcp_;
221-
222-
protected:
223-
bool set_default_formats() {
224-
using namespace format_tag;
225-
226-
auto dat_tag = isa == sve_512 ? nChw16c : nChw8c;
227-
auto wei_tag = isa == sve_512 ? Goihw16g : Goihw8g;
228-
229-
return set_default_formats_common(dat_tag, wei_tag, dat_tag);
230-
}
231-
};
232-
233-
jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd);
234-
235-
typedef typename prec_traits<data_type::bf16>::type bf16_data_t;
236-
typedef typename prec_traits<data_type::f32>::type f32_data_t;
237-
typedef typename prec_traits<src_type>::type src_data_t;
238-
typedef typename prec_traits<src_type>::type diff_dst_data_t;
239-
typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;
240-
241-
status_t init(engine_t *engine) override {
242-
CHECK(safe_ptr_assign(kernel_,
243-
new jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>(
244-
pd()->jcp_)));
245-
CHECK(kernel_->create_kernel());
246-
247-
if (pd()->jcp_.nthr_mb > 1) {
248-
CHECK(safe_ptr_assign(
249-
acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
250-
CHECK(acc_ker_->create_kernel());
278+
status_t execute(const exec_ctx_t &ctx) const override {
279+
execute_backward_weights(ctx);
280+
execute_reduction(ctx);
281+
return status::success;
251282
}
252-
return status::success;
253-
}
254283

255-
status_t execute(const exec_ctx_t &ctx) const override {
256-
execute_backward_weights(ctx);
257-
execute_reduction(ctx);
258-
return status::success;
259-
}
284+
private:
285+
void execute_backward_weights(const exec_ctx_t &ctx) const;
286+
void execute_reduction(const exec_ctx_t &ctx) const;
287+
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
260288

261-
private:
262-
void execute_backward_weights(const exec_ctx_t &ctx) const;
263-
void execute_reduction(const exec_ctx_t &ctx) const;
264-
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
265-
266-
std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
267-
std::unique_ptr<jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>> kernel_;
268-
};
289+
std::unique_ptr<cpu_accumulator_1d_t<data_type::f32, isa>> acc_ker_;
290+
std::unique_ptr<jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>>
291+
kernel_;
292+
};
269293

270-
using jit_sve_512_dw_convolution_bwd_weights_t
271-
= jit_uni_dw_convolution_bwd_weights_t<sve_512, data_type::f32>;
294+
using jit_sve_512_dw_convolution_bwd_weights_t
295+
= jit_uni_dw_convolution_bwd_weights_t<sve_512, data_type::f32>;
296+
using jit_sve_256_dw_convolution_bwd_weights_t
297+
= jit_uni_dw_convolution_bwd_weights_t<sve_256, data_type::f32>;
272298
} // namespace aarch64
273299
} // namespace cpu
274300
} // namespace impl

‎src/cpu/cpu_convolution_list.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,10 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
138138
CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t)
139139
CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t<sve_512>)
140140
CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t<sve_512>)
141-
CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t)
141+
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t<sve_512,data_type::f32>)
142142
CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t)
143143
CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t<f32,f32,f32,sve_512>)
144+
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t<sve_256,data_type::f32>)
144145
CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t)
145146
CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
146147
CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<f32>)
@@ -250,9 +251,10 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
250251
CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_data_t)
251252
CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_data_t)
252253
CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_data_t)
253-
CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_bwd_data_t)
254+
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t<sve_512,data_type::f32>)
254255
CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_data_f32_t)
255256
CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t<f32,f32,f32,sve_512>)
257+
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t<sve_256,data_type::f32>)
256258
CPU_INSTANCE(gemm_convolution_bwd_data_t)
257259
CPU_INSTANCE(ref_convolution_bwd_data_t)
258260
nullptr,
@@ -339,9 +341,10 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
339341
CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_weights_t)
340342
CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_weights_t)
341343
CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_weights_t)
342-
CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_bwd_weights_t)
344+
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t<sve_512,data_type::f32>)
343345
CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_weights_t)
344346
CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t<f32,f32,f32,sve_512>)
347+
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t<sve_256,data_type::f32>)
345348
CPU_INSTANCE(gemm_convolution_bwd_weights_t)
346349
CPU_INSTANCE(ref_convolution_bwd_weights_t)
347350
nullptr,

0 commit comments

Comments
 (0)
Please sign in to comment.