Commit 0a72ae8

Introduces support for 3D and 4D inputs in ACL acl_lowp and acl_lowp_sq matmul. (#2846)
1 parent 8d80048 commit 0a72ae8

File tree

src/cpu/aarch64/matmul/acl_lowp_matmul.cpp
src/cpu/aarch64/matmul/acl_lowp_matmul.hpp
src/cpu/aarch64/matmul/acl_lowp_matmul_sq.cpp
src/cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp

4 files changed: +143 −57 lines changed
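
Note (not part of the commit): a rough usage sketch of the case this change enables. The shapes, data types, and variable names below are illustrative; on an aarch64 build of oneDNN with ACL enabled, a 3D s8 matmul like this is the kind of problem that can now be handled by the ACL lowp matmul implementations instead of requiring 2D inputs.

    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
        using namespace dnnl;
        engine eng(engine::kind::cpu, 0);
        stream strm(eng);

        // 3D case: src {batch, M, K}, weights {1, K, N}, dst {batch, M, N}.
        memory::desc src_md({2, 16, 32}, memory::data_type::s8, memory::format_tag::abc);
        memory::desc wei_md({1, 32, 64}, memory::data_type::s8, memory::format_tag::abc);
        memory::desc dst_md({2, 16, 64}, memory::data_type::f32, memory::format_tag::abc);

        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
        matmul prim(pd);

        memory src_mem(src_md, eng), wei_mem(wei_md, eng), dst_mem(dst_md, eng);
        prim.execute(strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
                                   {DNNL_ARG_DST, dst_mem}});
        strm.wait();
        return 0;
    }

Whether this actually lands on the ACL kernels still depends on the dispatch checks added in the diffs below (supported format tags, i8mm availability, and a weights batch of 1 for the _sq variant).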

4 files changed

+143
-57
lines changed

src/cpu/aarch64/matmul/acl_lowp_matmul.cpp

+61 −25
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024 Arm Ltd. and affiliates
+* Copyright 2024-2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -89,6 +89,14 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) {
     const memory_desc_wrapper bia_d(bias_md_);
     const memory_desc_wrapper dst_d(dst_md_);
 
+    cpu::matmul::matmul_helper_t helper(src_d, wei_d, dst_d);
+    const dim_t M = helper.M();
+    const dim_t N = helper.N();
+    const dim_t K = helper.K();
+    const dim_t dst_batch = helper.batch();
+    const dim_t src_batch = helper.src_batch();
+    const dim_t wei_batch = helper.wei_batch();
+
     using namespace data_type;
 
     // Note that has_default_values checks the argument for default zero
@@ -107,39 +115,66 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) {
             VERBOSE_UNSUPPORTED_DT_CFG);
     almc_.dst_is_s8 = dst_d.data_type() == s8;
 
-    VDISPATCH_MATMUL(src_d.matches_tag(format_tag::ab)
-                    && wei_d.matches_tag(format_tag::ab)
-                    && dst_d.matches_tag(format_tag::ab),
-            VERBOSE_UNSUPPORTED_TAG);
+    // reject in case the op is running on a cpu that have i8mm instruction set.
+    // this is a temporary fix until the issue is resolved.
+    VDISPATCH_MATMUL(
+            arm_compute::CPUInfo::get().has_i8mm() || dst_d.data_type() != s8,
+            "Op not supported on CPUs without i8mm instructions when dest "
+            "datatype is s8");
+
+    using namespace format_tag;
+    auto src_tag = memory_desc_matches_one_of_tag(src_md_, abcd, abc, ab);
+    auto wei_tag = memory_desc_matches_one_of_tag(weights_md_, abcd, abc, ab);
+    auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, abcd, abc, ab);
+
+    ACL_CHECK_SUPPORT(
+            utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag),
+            "Format tag is undefined");
 
-    VDISPATCH_MATMUL_SC(
-            memory_desc_init_by_tag(bias_md_, bias_md_.ndims, bias_md_.dims,
-                    bias_md_.data_type, format_tag::ab),
+    VDISPATCH_MATMUL_SC(memory_desc_init_by_tag(bias_md_, bias_md_.ndims,
+                                bias_md_.dims, bias_md_.data_type, dst_tag),
             VERBOSE_UNSUPPORTED_BIAS_CFG);
 
     // We set the QuantizationInfo to be dynamic because it is re-set in run()
-    almc_.src_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(K(), M()), 1,
-                    arm_compute::DataType::QASYMM8_SIGNED,
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+    almc_.src_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(K, M, 1, src_batch), 1,
+            arm_compute::DataType::QASYMM8_SIGNED,
+            arm_compute::QuantizationInfo(1.0, 0, true));
     almc_.src_tensor_info.set_are_values_constant(false);
 
     almc_.wei_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), K()), 1,
-                    arm_compute::DataType::QASYMM8_SIGNED,
+            = arm_compute::TensorInfo(arm_compute::TensorShape(N, K, wei_batch),
+                    1, arm_compute::DataType::QASYMM8_SIGNED,
                     arm_compute::QuantizationInfo(1.0, 0, true));
     almc_.wei_tensor_info.set_are_values_constant(false);
 
     almc_.bia_tensor_info = arm_compute::TensorInfo(
             arm_compute::TensorShape(), 1, arm_compute::DataType::F32);
     almc_.with_bias = bia_d.format_kind() != format_kind::undef;
+
     if (almc_.with_bias) {
-        // This is not currently guarded in ACL
-        VDISPATCH_MATMUL(bia_d.ndims() == 2 && bia_d.dims()[0] == 1
-                        && bia_d.dims()[1] == N(),
-                "Only 1xN bias is supported");
-        almc_.bia_tensor_info.set_tensor_shape(
-                arm_compute::TensorShape(bia_d.dims()[1], bia_d.dims()[0]));
+        switch (bia_d.ndims()) {
+            case 2:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == N,
+                        "Only 1xN bias is supported for 2D input");
+                almc_.bia_tensor_info.set_tensor_shape(
+                        arm_compute::TensorShape(bia_d.dims()[1], 1));
+                break;
+            case 3:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1
+                                && bia_d.dims()[2] == N,
+                        "Only 1x1xN bias is supported for 3D input");
+                almc_.bia_tensor_info.set_tensor_shape(
+                        arm_compute::TensorShape(bia_d.dims()[2], 1, 1));
+                break;
+            case 4:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1
+                                && bia_d.dims()[2] == 1 && bia_d.dims()[3] == N,
+                        "Only 1x1x1xN bias is supported for 4D input");
+                almc_.bia_tensor_info.set_tensor_shape(
+                        arm_compute::TensorShape(bia_d.dims()[3], 1, 1, 1));
+                break;
+        }
     }
 
     // We can fuse sum if it is the first post op
@@ -173,14 +208,15 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) {
             almc_.gemm_info.accumulate() ? 1 : 0));
 
     almc_.dst_tensor_info = arm_compute::TensorInfo(
-            arm_compute::TensorShape(N(), M()), arm_compute::Format::F32);
+            arm_compute::TensorShape(N, M, 1, dst_batch),
+            arm_compute::Format::F32);
 
     almc_.dst_cast_tensor_info = almc_.dst_tensor_info;
 
-    almc_.dst_s8_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), M()), 1,
-                    arm_compute::DataType::QASYMM8_SIGNED,
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+    almc_.dst_s8_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(N, M, 1, dst_batch), 1,
+            arm_compute::DataType::QASYMM8_SIGNED,
+            arm_compute::QuantizationInfo(1.0, 0, true));
 
     ACL_CHECK_VALID(arm_compute::NEGEMMLowpMatrixMultiplyCore::validate(
             &almc_.src_tensor_info, &almc_.wei_tensor_info,
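
A note on the shape mapping above (an observation, not text from the patch): ACL's TensorShape lists dimensions innermost first, so a row-major oneDNN tensor with logical dims {batch, M, K} becomes TensorShape(K, M, 1, batch), with the collapsed oneDNN batch carried in the trailing ACL dimension for src and dst. A minimal sketch of that mapping, using a hypothetical helper name:

    #include <cstdint>

    #include "arm_compute/core/TensorShape.h"

    // Hypothetical helper, not from the patch: map a row-major (rows x cols)
    // matrix with a collapsed batch count to an ACL shape. ACL orders
    // dimensions innermost-first, so the column count comes first.
    static arm_compute::TensorShape to_acl_shape(
            int64_t rows, int64_t cols, int64_t batch) {
        // e.g. oneDNN src {B, M, K} -> to_acl_shape(M, K, B)
        //      -> TensorShape(K, M, 1, B)
        return arm_compute::TensorShape(static_cast<size_t>(cols),
                static_cast<size_t>(rows), size_t {1},
                static_cast<size_t>(batch));
    }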

src/cpu/aarch64/matmul/acl_lowp_matmul.hpp

+2
@@ -21,9 +21,11 @@
 #include "cpu/matmul/cpu_matmul_pd.hpp"
 #include "cpu/matmul/matmul_utils.hpp"
 
+#include "arm_compute/core/CPP/CPPTypes.h"
 #include "arm_compute/runtime/NEON/functions/NEDequantizationLayer.h"
 #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
 #include "arm_compute/runtime/NEON/functions/NEQuantizationLayer.h"
+
 #include "cpu/aarch64/acl_post_ops.hpp"
 #include "cpu/aarch64/acl_utils.hpp"
 
src/cpu/aarch64/matmul/acl_lowp_matmul_sq.cpp

+78 −32
@@ -74,52 +74,95 @@ status_t acl_lowp_matmul_sq_t::pd_t::init(engine_t *engine) {
     const memory_desc_wrapper wei_d(weights_md_);
     const memory_desc_wrapper bia_d(bias_md_);
     const memory_desc_wrapper dst_d(dst_md_);
+
+    cpu::matmul::matmul_helper_t helper(src_d, wei_d, dst_d);
+    const dim_t M = helper.M();
+    const dim_t N = helper.N();
+    const dim_t K = helper.K();
+    const dim_t dst_batch = helper.batch();
+    const dim_t src_batch = helper.src_batch();
+    const dim_t wei_batch = helper.wei_batch();
+
     using namespace data_type;
     VDISPATCH_MATMUL(utils::one_of(src_d.data_type(), s8, u8)
-                    && wei_d.data_type() == s8
-                    && src_d.data_type() == s8
-                    ? dst_d.data_type() == s8
-                    : dst_d.data_type() == u8,
+                    && wei_d.data_type() == s8
+                    && (src_d.data_type() == s8 ? dst_d.data_type() == s8
+                                                : dst_d.data_type() == u8),
             VERBOSE_UNSUPPORTED_DT_CFG);
     VDISPATCH_MATMUL(utils::one_of(bia_d.data_type(), f32, undef),
             VERBOSE_UNSUPPORTED_DT_CFG);
-    // reject in case the op is running in a Neoverse-N1.
+
+    // reject in case the op is running on a cpu that have i8mm instruction set.
+    // this is a temporary fix until the issue is resolved.
     VDISPATCH_MATMUL(arm_compute::CPUInfo::get().has_i8mm(),
-            "Neoverse-N1 not supported");
-    VDISPATCH_MATMUL(src_d.matches_tag(format_tag::ab)
-                    && wei_d.matches_tag(format_tag::ab)
-                    && dst_d.matches_tag(format_tag::ab),
-            VERBOSE_UNSUPPORTED_TAG);
-    VDISPATCH_MATMUL_SC(
-            memory_desc_init_by_tag(bias_md_, bias_md_.ndims, bias_md_.dims,
-                    bias_md_.data_type, format_tag::ab),
+            "Op not supported on CPUs without i8mm instructions");
+
+    // ACL batch dimension only support s32 for 3D and 4D
+    VDISPATCH_MATMUL(
+            wei_batch == 1, "Batch dimension must be 1 for the weights");
+
+    using namespace format_tag;
+    auto src_tag = memory_desc_matches_one_of_tag(src_md_, abcd, abc, ab);
+    auto wei_tag = memory_desc_matches_one_of_tag(weights_md_, abcd, abc, ab);
+    auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, abcd, abc, ab);
+
+    ACL_CHECK_SUPPORT(
+            utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag),
+            "Format tag is undefined");
+
+    VDISPATCH_MATMUL_SC(memory_desc_init_by_tag(bias_md_, bias_md_.ndims,
+                                bias_md_.dims, bias_md_.data_type, dst_tag),
             VERBOSE_UNSUPPORTED_BIAS_CFG);
-    // We set the QuantizationInfo to be dynamic because it is re-set in run()
-    almc_.src_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(K(), M()), 1,
-                    acl_utils::get_acl_data_t(src_d.data_type(), true),
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+
+    almc_.bia_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(), 1, arm_compute::DataType::S32);
+    almc_.with_bias = bia_d.format_kind() != format_kind::undef;
+
+    almc_.src_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(K, M, 1, src_batch), 1,
+            acl_utils::get_acl_data_t(src_d.data_type(), true),
+            arm_compute::QuantizationInfo(1.0, 0, true));
     almc_.src_tensor_info.set_are_values_constant(false);
-    almc_.wei_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), K()), 1,
-                    acl_utils::get_acl_data_t(wei_d.data_type(), true),
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+
+    almc_.wei_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(N, K, 1, wei_batch), 1,
+            acl_utils::get_acl_data_t(wei_d.data_type(), true),
+            arm_compute::QuantizationInfo(1.0, 0, true));
     almc_.wei_tensor_info.set_are_values_constant(false);
-    almc_.dst_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), M()), 1,
-                    acl_utils::get_acl_data_t(dst_d.data_type(), true),
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+    almc_.dst_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(N, M, 1, dst_batch), 1,
+            acl_utils::get_acl_data_t(dst_d.data_type(), true),
+            arm_compute::QuantizationInfo(1.0, 0, true));
+
     almc_.bia_tensor_info = arm_compute::TensorInfo(
             arm_compute::TensorShape(), 1, arm_compute::DataType::S32);
     almc_.with_bias = bia_d.format_kind() != format_kind::undef;
+
     if (almc_.with_bias) {
-        // This is not currently guarded in ACL
-        VDISPATCH_MATMUL(bia_d.ndims() == 2 && bia_d.dims()[0] == 1
-                        && bia_d.dims()[1] == N(),
-                "Only 1xN bias is supported");
-        almc_.bia_tensor_info.set_tensor_shape(
-                arm_compute::TensorShape(bia_d.dims()[1], bia_d.dims()[0]));
+        switch (bia_d.ndims()) {
+            case 2:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == N,
+                        "Only 1xN bias is supported for 2D input");
+                almc_.bia_tensor_info.set_tensor_shape(arm_compute::TensorShape(
+                        bia_d.dims()[1], bia_d.dims()[0]));
+                break;
+            case 3:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1
+                                && bia_d.dims()[2] == N,
+                        "Only 1x1xN bias is supported for 3D input");
+                almc_.bia_tensor_info.set_tensor_shape(
+                        arm_compute::TensorShape(bia_d.dims()[2], 1, 1));
+                break;
+            case 4:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1
+                                && bia_d.dims()[2] == 1 && bia_d.dims()[3] == N,
+                        "Only 1x1x1xN bias is supported for 4D input");
+                almc_.bia_tensor_info.set_tensor_shape(
+                        arm_compute::TensorShape(bia_d.dims()[3], 1, 1, 1));
+                break;
+        }
     }
+
     arm_compute::GEMMLowpOutputStageInfo info;
     info.type = arm_compute::GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
     info.gemmlowp_multiplier = 1073741824;
@@ -132,15 +175,18 @@ status_t acl_lowp_matmul_sq_t::pd_t::init(engine_t *engine) {
     auto scratchpad = scratchpad_registry().registrar();
     const dnnl::impl::memory_desc_t dst_md_ {desc_.dst_desc};
     arm_compute::ActivationLayerInfo act_info;
+
     CHECK(init_scratchpad(engine, scratchpad, acl_post_ops, attr_.post_ops_,
             act_info, dst_md_));
     almc_.gemm_info.set_activation_info(act_info);
+
     ACL_CHECK_VALID(arm_compute::NEGEMMLowpMatrixMultiplyCore::validate(
             &almc_.src_tensor_info, &almc_.wei_tensor_info,
             almc_.with_bias ? &almc_.bia_tensor_info : nullptr,
             &almc_.dst_tensor_info, almc_.gemm_info));
     return status::success;
 }
+
 status_t acl_lowp_matmul_sq_t::pd_t::init_scratchpad(engine_t *engine,
         memory_tracking::registrar_t &scratchpad, acl_post_ops_t &post_ops,
         dnnl::impl::post_ops_t &attr_post_ops,
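
The bias handling in both .cpp files accepts only a bias that broadcasts over everything except N, i.e. 1xN, 1x1xN, or 1x1x1xN, and collapses it to an ACL shape whose leading dimension is N. A standalone restatement of that check, as a sketch with a hypothetical name rather than code from the patch:

    #include <cstdint>

    // Hypothetical helper mirroring the switch in the patch: a bias is accepted
    // only when every dimension except the last equals 1 and the last equals N,
    // i.e. shapes 1xN, 1x1xN, or 1x1x1xN.
    static bool bias_is_broadcastable_over_n(
            const int64_t *dims, int ndims, int64_t N) {
        if (ndims < 2 || ndims > 4) return false;
        for (int d = 0; d < ndims - 1; ++d)
            if (dims[d] != 1) return false;
        return dims[ndims - 1] == N;
    }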

src/cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp

+2
@@ -25,6 +25,8 @@
 
 #include "cpu/aarch64/acl_post_ops.hpp"
 
+#include "arm_compute/core/CPP/CPPTypes.h"
+
 namespace dnnl {
 namespace impl {
 namespace cpu {
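
The new arm_compute/core/CPP/CPPTypes.h includes in both headers make arm_compute::CPUInfo visible to the i8mm dispatch checks added above. A minimal sketch of the same runtime query in isolation (the surrounding program is illustrative):

    #include <cstdio>

    #include "arm_compute/core/CPP/CPPTypes.h"

    int main() {
        // Same feature query the patch uses to gate the s8 paths:
        // i8mm availability detected by ACL at runtime.
        const bool has_i8mm = arm_compute::CPUInfo::get().has_i8mm();
        std::printf("i8mm supported: %s\n", has_i8mm ? "yes" : "no");
        return 0;
    }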
