Skip to content

Commit 9a1dc92

Browse files
Radu2k authored and mgouicem committed
cpu: aarch64: Expand brgemm aarch64 unsupported cases handling mechanism (uxlfoundation#2099)
1 parent 4793296 commit 9a1dc92

7 files changed

+63
-61
lines changed

src/cpu/aarch64/acl_deconvolution.hpp

+8-12
Original file line number | Diff line number | Diff line change
@@ -193,8 +193,9 @@ struct acl_deconvolution_fwd_t : public primitive_t {
193193
}
194194

195195
// Data layout
196-
const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC
197-
: arm_compute::DataLayout::NCHW;
196+
const arm_compute::DataLayout acl_layout = is_nspc
197+
? arm_compute::DataLayout::NHWC
198+
: arm_compute::DataLayout::NCHW;
198199

199200
acl_pd_conf.src_info = arm_compute::TensorInfo(is_nspc
200201
? arm_compute::TensorShape(ic, iw, ih, mb)
@@ -243,18 +244,15 @@ struct acl_deconvolution_fwd_t : public primitive_t {
243244
// padding is set for convolution. Otherwise, describe deconvolution as convolution of
244245
// upsampling input with stride = 1 and pad = 0.
245246
arm_compute::ConvolutionMethod conv_method;
246-
arm_compute::TensorInfo *conv_src_info;
247+
arm_compute::TensorInfo conv_src_info(
248+
acl_pd_conf.src_info.clone()->set_is_resizable(true));
247249
unsigned int pad_left = 0;
248250
unsigned int pad_right = 0;
249251
unsigned int pad_top = 0;
250252
unsigned int pad_bottom = 0;
251253
if (sh != 1 || sw != 1) {
252-
arm_compute::TensorInfo scale_out_info(
253-
acl_pd_conf.src_info.clone()
254-
->set_is_resizable(true)
255-
.reset_padding()
256-
.set_tensor_shape(scale_out_shape));
257-
conv_src_info = &scale_out_info;
254+
conv_src_info.reset_padding();
255+
conv_src_info.set_tensor_shape(scale_out_shape);
258256
} else {
259257
// compute correct padding here
260258
pad_left = pr > pl ? pr - pl : 0;
@@ -269,15 +267,13 @@ struct acl_deconvolution_fwd_t : public primitive_t {
269267
pad_right += deconv_pad_x / 2;
270268
pad_top += deconv_pad_y / 2;
271269
pad_bottom += deconv_pad_y / 2;
272-
273-
conv_src_info = &acl_pd_conf.src_info;
274270
}
275271
const arm_compute::PadStrideInfo conv_info(1, 1, pad_left,
276272
pad_right, pad_top, pad_bottom,
277273
arm_compute::DimensionRoundingType::CEIL);
278274
conv_method
279275
= arm_compute::NEConvolutionLayer::get_convolution_method(
280-
conv_src_info, &acl_pd_conf.wei_info,
276+
&conv_src_info, &acl_pd_conf.wei_info,
281277
&acl_pd_conf.dst_info, conv_info,
282278
arm_compute::WeightsInfo(),
283279
arm_compute::Size2D(1U, 1U),

src/cpu/aarch64/brgemm/brgemm.cpp

+5-4
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
/*******************************************************************************
22
* Copyright 2020-2023 Intel Corporation
33
* Copyright 2023-2024 FUJITSU LIMITED
4+
* Copyright 2024 Arm Ltd. and affiliates
45
*
56
* Licensed under the Apache License, Version 2.0 (the "License");
67
* you may not use this file except in compliance with the License.
@@ -170,8 +171,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
170171
if (brg == nullptr) return status::invalid_arguments;
171172
if (transA || transB) return status::unimplemented;
172173

173-
brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
174-
beta, LDA, LDB, LDC, M, N, K, strides);
174+
CHECK(brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout,
175+
alpha, beta, LDA, LDB, LDC, M, N, K, strides));
175176

176177
if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
177178
bool ldx_check = (brg->is_row_major()) ? (LDA < K)
@@ -197,8 +198,8 @@ status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
197198
if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f)
198199
return status::unimplemented;
199200

200-
brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
201-
beta, LDA, LDC, M, N, strides);
201+
CHECK(brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout,
202+
alpha, beta, LDA, LDC, M, N, strides));
202203

203204
const bool ldx_check = (LDA < N || LDC < N);
204205
if (ldx_check) return status::invalid_arguments;

src/cpu/aarch64/brgemm/brgemm_utils.cpp

+29-26
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
/*******************************************************************************
22
* Copyright 2022-2023 Intel Corporation
33
* Copyright 2023-2024 FUJITSU LIMITED
4+
* Copyright 2024 Arm Ltd. and affiliates
45
*
56
* Licensed under the Apache License, Version 2.0 (the "License");
67
* you may not use this file except in compliance with the License.
@@ -47,15 +48,18 @@ impl::data_type_t get_accum_datatype(brgemm_t *brg) {
4748
return brg->is_int8 ? data_type::s32 : data_type::f32;
4849
}
4950

50-
void init_kernel_datatype(
51+
// Derives the kernel data-type family flags of `brg` from the data types of
// the A and B matrices.
//
// Sets brg->is_int8, brg->is_bf16, brg->is_f32 and brg->is_f16.
//
// Returns status::unimplemented when either data type is undefined or when
// the (dt_a, dt_b) pair does not map to any of the int8 / bf16 / f32 / f16
// families; returns status::success otherwise.
status_t init_kernel_datatype(
        brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) {
    // This check replaces `assert(dt_a != undef && dt_b != undef)`, so the
    // error path must be taken when that asserted condition is FALSE.
    if (dt_a == data_type::undef || dt_b == data_type::undef)
        return status::unimplemented;

    brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8)
            && utils::one_of(dt_b, data_type::u8, data_type::s8);
    brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
    brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
    brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);

    // Likewise the negation of the former
    // `assert(is_int8 || is_bf16 || is_f32 || is_f16)`: bail out only when no
    // family matched.
    const bool has_known_family
            = brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16;
    if (!has_known_family) return status::unimplemented;

    return status::success;
}
6064

6165
void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
@@ -88,27 +92,22 @@ void maybe_try_bf32(brgemm_t *brg) {
8892
//
8993
}
9094

91-
void set_isa_impl(brgemm_t *brg) {
95+
// Chooses the ISA used for the kernel and stores it in brg->isa_impl.
//
// Returns status::unimplemented for the bf32, bf16 and f16 families. For f32
// and int8 it assigns brg->isa_impl (sve_512 is tried before sve_256, with
// isa_undef as the fallback) and returns status::success. When none of the
// family flags is set, brg->isa_impl is left untouched and status::success
// is returned.
status_t set_isa_impl(brgemm_t *brg) {
    // An ISA qualifies when mayiuse() accepts it and the user-requested ISA is
    // either unspecified or exactly this one.
    // A superset test (IMPLICATION(brg->isa_user != isa_undef,
    // is_superset(brg->isa_user, isa))) may be preferable, but the API is not
    // clear.
    const auto isa_usable = [&](cpu_isa_t candidate) {
        return mayiuse(candidate)
                && one_of(brg->isa_user, isa_undef, candidate);
    };

    const bool is_unsupported_family
            = brg->is_bf32 || brg->is_bf16 || brg->is_f16;
    if (is_unsupported_family) return status::unimplemented;

    if (brg->is_f32 || brg->is_int8) {
        brg->isa_impl = utils::map(true, isa_undef, isa_usable(sve_512),
                sve_512, isa_usable(sve_256), sve_256);
    }

    return status::success;
}
113112

114113
void set_brg_vmm(brgemm_t *brg) {
@@ -187,7 +186,7 @@ inline size_t data_type_vnni_granularity(data_type_t data_type) {
187186
}
188187
status_t brgemm_blocking(brgemm_t *brg) {
189188

190-
set_isa_impl(brg);
189+
CHECK(set_isa_impl(brg));
191190
if (brg->isa_impl == isa_undef) return status::unimplemented;
192191
assert(!brg->is_dgmm); // should not be called from brdgmm
193192
set_brg_vmm(brg);
@@ -296,18 +295,19 @@ status_t brdgmm_blocking(brgemm_t *brg) {
296295
return status::success;
297296
}
298297

299-
void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
300-
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
301-
float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
302-
dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) {
298+
status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
299+
brgemm_batch_kind_t type, impl::data_type_t dt_a,
300+
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
301+
dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K,
302+
const brgemm_strides_t *strides, bool is_bf32) {
303303

304304
init_common_conf(brg, type, alpha, beta, strides);
305305

306306
brg->layout = layout;
307307

308308
brg->dt_a = brg->is_row_major() ? dt_a : dt_b;
309309
brg->dt_b = brg->is_row_major() ? dt_b : dt_a;
310-
init_kernel_datatype(brg, brg->dt_a, brg->dt_b);
310+
CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b));
311311

312312
brg->dt_c = get_accum_datatype(brg);
313313
brg->dt_d = brg->dt_c;
@@ -319,7 +319,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
319319
brg->typesize_D = types::data_type_size(brg->dt_d);
320320

321321
brg->isa_user = isa;
322-
set_isa_impl(brg);
322+
CHECK(set_isa_impl(brg));
323323
brg->is_bf32 = false;
324324

325325
brg->has_int8_vnni = true;
@@ -352,11 +352,13 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
352352
brg->rd_step = has_no_vnni_compute_instruction
353353
? 1
354354
: data_type_vnni_granularity(brg->dt_b);
355+
return status::success;
355356
}
356357

357-
void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
358-
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
359-
float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
358+
status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa,
359+
brgemm_batch_kind_t type, impl::data_type_t dt_a,
360+
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
361+
dim_t LDA, dim_t LDC, dim_t M, dim_t N,
360362
const brgemm_strides_t *strides) {
361363

362364
init_common_conf(brg, type, alpha, beta, strides);
@@ -365,7 +367,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
365367

366368
brg->dt_a = dt_a;
367369
brg->dt_b = dt_b;
368-
init_kernel_datatype(brg, brg->dt_a, brg->dt_b);
370+
CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b));
369371

370372
brg->dt_c = get_accum_datatype(brg);
371373
brg->dt_d = brg->dt_c;
@@ -394,6 +396,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
394396

395397
brg->bcast_dim = M;
396398
brg->load_dim = N;
399+
return status::success;
397400
}
398401

399402
} // namespace brgemm_utils
@@ -402,4 +405,4 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
402405
} // namespace impl
403406
} // namespace dnnl
404407

405-
//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
408+
//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s

src/cpu/aarch64/brgemm/brgemm_utils.hpp

+10-8
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
/*******************************************************************************
22
* Copyright 2022 Intel Corporation
33
* Copyright 2024 FUJITSU LIMITED
4+
* Copyright 2024 Arm Ltd. and affiliates
45
*
56
* Licensed under the Apache License, Version 2.0 (the "License");
67
* you may not use this file except in compliance with the License.
@@ -44,20 +45,21 @@ status_t brdgmm_blocking(brgemm_t *brg);
4445
* having to depend on BRGeMM's API. An additional feature is that this
4546
* function can be modified depending on needs without requiring changes
4647
* at the API level. */
47-
void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
48-
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
49-
float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
50-
dim_t N, dim_t K, const brgemm_strides_t *strides = nullptr,
51-
bool is_bf32 = false);
48+
status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
49+
brgemm_batch_kind_t type, impl::data_type_t dt_a,
50+
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
51+
dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K,
52+
const brgemm_strides_t *strides = nullptr, bool is_bf32 = false);
5253

5354
/* The purpose of this function is to enable initialization of brgemm values
5455
* and then call additional functions like blocking heuristics without
5556
* having to depend on BRDGeMM's API. An additional feature is that this
5657
* function can be modified depending on needs without requiring changes
5758
* at the API level. */
58-
void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
59-
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
60-
float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
59+
status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa,
60+
brgemm_batch_kind_t type, impl::data_type_t dt_a,
61+
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
62+
dim_t LDA, dim_t LDC, dim_t M, dim_t N,
6163
const brgemm_strides_t *strides = nullptr);
6264

6365
} // namespace brgemm_utils

src/cpu/aarch64/jit_brgemm_conv_utils.cpp

+6-5
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
/*******************************************************************************
22
* Copyright 2021-2023 Intel Corporation
33
* Copyright 2024 FUJITSU LIMITED
4+
* Copyright 2024 Arm Ltd. and affiliates
45
*
56
* Licensed under the Apache License, Version 2.0 (the "License");
67
* you may not use this file except in compliance with the License.
@@ -725,9 +726,9 @@ status_t brg_blocking_t::estimate_brgemm_ur() {
725726
const float alpha = 1.0;
726727
const float beta = 0.0;
727728
brgemm_t brg;
728-
brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
729+
CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
729730
brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr,
730-
is_bf32);
731+
is_bf32));
731732
CHECK(brgemm_utils::brgemm_blocking(&brg));
732733
ur = brg.bd_block;
733734
ur_block = brg.bd_block;
@@ -771,9 +772,9 @@ status_t brg_blocking_t::get_brgemm_ur(
771772
* rnd_up(oc, oc_block) * wei_dsz;
772773
const auto strides_ptr
773774
= (brg_type == brgemm_strd) ? &brg_strides : nullptr;
774-
brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt,
775-
wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB, LDC,
776-
vM, vN, vK, strides_ptr, is_bf32);
775+
CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brg_type,
776+
src_dt, wei_dt, brgemm_row_major, alpha, vbeta, LDA,
777+
LDB, LDC, vM, vN, vK, strides_ptr, is_bf32));
777778
CHECK(brgemm_utils::brgemm_blocking(&brg));
778779

779780
brgemm_attr_t brgattr;

src/cpu/aarch64/matmul/brgemm_matmul.cpp

+1-2
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
/*******************************************************************************
22
* Copyright 2021-2023 Intel Corporation
33
* Copyright 2024 FUJITSU LIMITED
4+
* Copyright 2024 Arm Ltd. and affiliates
45
* Licensed under the Apache License, Version 2.0 (the "License");
56
* you may not use this file except in compliance with the License.
67
* You may obtain a copy of the License at
@@ -642,7 +643,6 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
642643
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
643644
ctx.current_K_start = k;
644645
ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
645-
assert(isa == sve_512);
646646
(*copy_B_kernel_)(&ctx);
647647
}
648648

@@ -654,7 +654,6 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
654654
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
655655
ctx.current_K_start = k;
656656
ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
657-
assert(isa == sve_512);
658657
(*copy_B_kernel_)(&ctx);
659658
}
660659
}

src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp

+4-4
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021-2023 Intel Corporation
3+
* Copyright 2024 Arm Ltd. and affiliates
34
*
45
* Licensed under the Apache License, Version 2.0 (the "License");
56
* you may not use this file except in compliance with the License.
@@ -129,7 +130,7 @@ bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
129130
}
130131

131132
status_t check_isa_with_datatype(
132-
const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) {
133+
const brgemm_matmul_conf_utils_t &bm_conf_utils) {
133134
if (bm_conf_utils.is_f32() && !bm_conf_utils.is_int8()
134135
&& !bm_conf_utils.is_bf16() && !bm_conf_utils.is_f16()
135136
&& !bm_conf_utils.is_int8())
@@ -732,8 +733,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
732733
dst_d.format_kind() == format_kind::any,
733734
bias_md.format_kind == format_kind::any);
734735

735-
VCHECK_BG(check_isa_with_datatype(isa, bm_conf_utils),
736-
VERBOSE_ISA_DT_MISMATCH);
736+
VCHECK_BG(check_isa_with_datatype(bm_conf_utils), VERBOSE_ISA_DT_MISMATCH);
737737

738738
bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt);
739739
bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt);
@@ -1107,4 +1107,4 @@ void init_scratchpad(memory_tracking::registrar_t &scratchpad,
11071107
} // namespace aarch64
11081108
} // namespace cpu
11091109
} // namespace impl
1110-
} // namespace dnnl
1110+
} // namespace dnnl

0 commit comments

Comments
 (0)