Skip to content

Commit 608d92f

Browse files
Sqvidvpirogov
authored and committed
cpu: aarch64 make binary ops use stateless ACL interface
Signed-off-by: Siddhartha Menon <siddhartha.menon@arm.com>
1 parent 912851f commit 608d92f

File tree

4 files changed

+205
-228
lines changed

4 files changed

+205
-228
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ On a CPU based on Arm AArch64 architecture, oneDNN CPU engine can be built with
173173
machine learning applications and provides AArch64 optimized implementations
174174
of core functions. This functionality currently requires that ACL is downloaded
175175
and built separately. See [Build from Source] section of the Developer Guide for
176-
details. oneDNN only supports Compute Library versions 24.08 or later.
176+
details. oneDNN only supports Compute Library versions 24.08.1 or later.
177177

178178
[Arm Compute Library (ACL)]: https://github.com/arm-software/ComputeLibrary
179179

cmake/ACL.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ endif()
3131

3232
find_package(ACL REQUIRED)
3333

34-
set(ACL_MINIMUM_VERSION "24.08")
34+
set(ACL_MINIMUM_VERSION "24.08.1")
3535

3636
if(ACL_FOUND)
3737
file(GLOB_RECURSE ACL_VERSION_FILE ${ACL_INCLUDE_DIR}/*/arm_compute_version.embed)

src/cpu/aarch64/acl_binary.cpp

+190-16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022 Arm Ltd. and affiliates
2+
* Copyright 2022, 2024 Arm Ltd. and affiliates
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -14,34 +14,200 @@
1414
* limitations under the License.
1515
*******************************************************************************/
1616

17-
#include "cpu/aarch64/acl_binary.hpp"
17+
#include "acl_binary.hpp"
18+
19+
#include "arm_compute/core/ITensorPack.h"
20+
#include "arm_compute/core/experimental/Types.h"
21+
#include "arm_compute/runtime/Tensor.h"
22+
#include "arm_compute/runtime/experimental/operators/CpuAdd.h"
23+
#include "arm_compute/runtime/experimental/operators/CpuElementwise.h"
24+
#include "arm_compute/runtime/experimental/operators/CpuMul.h"
25+
#include "arm_compute/runtime/experimental/operators/CpuSub.h"
1826

1927
namespace dnnl {
2028
namespace impl {
2129
namespace cpu {
2230
namespace aarch64 {
2331

32+
status_t acl_binary_t::pd_t::init(engine_t *engine) {
    using namespace acl_utils;

    // Supported destination data types: f16, f32 and s32 only.
    const data_type_t dst_dt = dst_md(0)->data_type;
    if (!utils::one_of(dst_dt, data_type::f16, data_type::f32, data_type::s32))
        return status::unimplemented;

    // Both sources must match the destination type for now.
    if (src_md(0)->data_type != dst_dt || src_md(1)->data_type != dst_dt)
        return status::unimplemented;

    // Resolves a 'format_kind::any' dst to the src_md(0) blocking desc.
    CHECK(set_default_params());

    if (!attr()->has_default_values()) return status::unimplemented;

    asp_.alg = desc()->alg_kind;

    // Algorithms we implement via ACL.
    if (!utils::one_of(asp_.alg, alg_kind::binary_add, alg_kind::binary_sub,
                alg_kind::binary_mul, alg_kind::binary_div,
                alg_kind::binary_max, alg_kind::binary_min))
        return status::unimplemented;

    // s32 division in ACL does not round the way oneDNN expects.
    if (dst_dt == data_type::s32 && asp_.alg == alg_kind::binary_div)
        return status::unimplemented;

    // ACL pointwise arithmetic operators assume the innermost dimensions
    // are dense for src0, src1 and dst. Reordering the logical dimensions
    // by stride achieves this (when reordered_dims >= 1) and also keeps
    // memory accesses contiguous in ACL without any physical data
    // reordering.
    memory_desc_t perm_src0, perm_src1, perm_dst;
    const int reordered_dims = reorder_dimensions_by_stride(
            {&perm_src0, &perm_src1, &perm_dst},
            {src_md(0), src_md(1), dst_md()});
    if (reordered_dims < 1) return status::unimplemented;

    // Build the ACL tensor infos from the permuted descriptors.
    CHECK(tensor_info(asp_.src0_info, perm_src0));
    CHECK(tensor_info(asp_.src1_info, perm_src1));
    CHECK(tensor_info(asp_.dst_info, perm_dst));

    // ACL tries to treat src0 and src1 as a flat 1D array in this case,
    // but fails because the byte strides differ. TODO: remove once fixed
    // in ACL.
    const bool shapes_match
            = asp_.src0_info.tensor_shape() == asp_.src1_info.tensor_shape();
    if (asp_.alg == alg_kind::binary_add && shapes_match
            && asp_.src0_info.strides_in_bytes()
                    != asp_.src1_info.strides_in_bytes())
        return status::unimplemented;

    // Keep ACL from parallelising small workloads; this is a temporary
    // fix and should be removed in future versions (TODO).
    const memory_desc_wrapper dst_d(dst_md());
    if (dst_d.nelems() < 40000) {
        const size_t acl_y_axis = 1;
        CHECK(insert_singleton_dimension(asp_.src0_info, acl_y_axis));
        CHECK(insert_singleton_dimension(asp_.src1_info, acl_y_axis));
        CHECK(insert_singleton_dimension(asp_.dst_info, acl_y_axis));
    }

    // Operator-specific support check in ACL.
    ACL_CHECK_VALID(validate(asp_));

    return status::success;
}
102+
103+
// Dispatch to the stateless ACL operator's validate() that matches the
// configured algorithm; returns an error Status for unknown alg kinds.
arm_compute::Status acl_binary_t::pd_t::validate(const acl_binary_conf_t &asp) {
    namespace acl_op = arm_compute::experimental::op;
    const arm_compute::ITensorInfo *s0 = &asp.src0_info;
    const arm_compute::ITensorInfo *s1 = &asp.src1_info;
    const arm_compute::ITensorInfo *d = &asp.dst_info;

    switch (asp.alg) {
        case alg_kind::binary_add:
            return acl_op::CpuAdd::validate(
                    s0, s1, d, arm_compute::ConvertPolicy::SATURATE);
        case alg_kind::binary_sub:
            return acl_op::CpuSub::validate(
                    s0, s1, d, arm_compute::ConvertPolicy::SATURATE);
        case alg_kind::binary_div:
            return acl_op::CpuElementwiseDivision::validate(s0, s1, d);
        case alg_kind::binary_mul:
            return acl_op::CpuMul::validate(s0, s1, d, 1.0f,
                    arm_compute::ConvertPolicy::SATURATE,
                    arm_compute::RoundingPolicy::TO_ZERO);
        case alg_kind::binary_min:
            return acl_op::CpuElementwiseMin::validate(s0, s1, d);
        case alg_kind::binary_max:
            return acl_op::CpuElementwiseMax::validate(s0, s1, d);
        default:
            return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
                    "unsupported alg_kind");
    }
}
132+
133+
// Construct and configure the stateless ACL operator matching the
// algorithm recorded in the pd during pd_t::init.
status_t acl_binary_t::init(engine_t *engine) {
    namespace acl_op = arm_compute::experimental::op;

    // Take a local copy of the conf so the addresses handed to
    // configure() never point into pd state.
    auto asp = pd()->asp_;
    arm_compute::ITensorInfo *s0 = &asp.src0_info;
    arm_compute::ITensorInfo *s1 = &asp.src1_info;
    arm_compute::ITensorInfo *d = &asp.dst_info;

    switch (asp.alg) {
        case alg_kind::binary_add: {
            auto op = std::make_unique<acl_op::CpuAdd>();
            op->configure(s0, s1, d, arm_compute::ConvertPolicy::SATURATE);
            binary_op_ = std::move(op);
            break;
        }
        case alg_kind::binary_sub: {
            auto op = std::make_unique<acl_op::CpuSub>();
            op->configure(s0, s1, d, arm_compute::ConvertPolicy::SATURATE);
            binary_op_ = std::move(op);
            break;
        }
        case alg_kind::binary_div: {
            auto op = std::make_unique<acl_op::CpuElementwiseDivision>();
            op->configure(s0, s1, d);
            binary_op_ = std::move(op);
            break;
        }
        case alg_kind::binary_mul: {
            auto op = std::make_unique<acl_op::CpuMul>();
            op->configure(s0, s1, d, 1.0f,
                    arm_compute::ConvertPolicy::SATURATE,
                    arm_compute::RoundingPolicy::TO_ZERO);
            binary_op_ = std::move(op);
            break;
        }
        case alg_kind::binary_min: {
            auto op = std::make_unique<acl_op::CpuElementwiseMin>();
            op->configure(s0, s1, d);
            binary_op_ = std::move(op);
            break;
        }
        case alg_kind::binary_max: {
            auto op = std::make_unique<acl_op::CpuElementwiseMax>();
            op->configure(s0, s1, d);
            binary_op_ = std::move(op);
            break;
        }
        // pd_t::init already rejected every other alg kind.
        default: return status::runtime_error;
    }

    return status::success;
}
188+
24189
// Run the configured stateless ACL binary operator on the given raw
// src0/src1/dst buffers. Thread-safe: all mutable state is on the stack.
status_t acl_binary_t::execute_forward(const exec_ctx_t &ctx, const void *src0,
        const void *src1, void *dst) const {

    // Reference, not copy: the conf is only read here (TensorAllocator::init
    // takes a const TensorInfo &), so the per-call deep copy of three
    // TensorInfos that `auto asp = ...` made is avoidable overhead on a
    // hot execute path.
    const auto &asp = pd()->asp_;

    // Stateless ACL operators receive their tensors per call via an
    // ITensorPack, so these wrappers can be locals — no resource mapper
    // and no locking needed.
    arm_compute::Tensor src0_tensor;
    arm_compute::Tensor src1_tensor;
    arm_compute::Tensor dst_tensor;

    // Wrap the caller's buffers without taking ownership; import_memory
    // does not allocate or copy.
    src0_tensor.allocator()->init(asp.src0_info);
    src0_tensor.allocator()->import_memory(const_cast<void *>(src0));
    src1_tensor.allocator()->init(asp.src1_info);
    src1_tensor.allocator()->import_memory(const_cast<void *>(src1));
    dst_tensor.allocator()->init(asp.dst_info);
    dst_tensor.allocator()->import_memory(dst);

    arm_compute::ITensorPack run_pack {
            {arm_compute::TensorType::ACL_SRC_0, &src0_tensor},
            {arm_compute::TensorType::ACL_SRC_1, &src1_tensor},
            {arm_compute::TensorType::ACL_DST, &dst_tensor}};

    binary_op_->run(run_pack);

    return status::success;
}
@@ -55,6 +221,14 @@ status_t acl_binary_t::execute_forward(const exec_ctx_t &ctx) const {
55221
return execute_forward(ctx, src0, src1, dst);
56222
}
57223

224+
// Binary primitives only have a forward pass; delegate to it.
status_t acl_binary_t::execute(const exec_ctx_t &ctx) const {
    return execute_forward(ctx);
}
227+
228+
// Convenience accessor: downcast the base-class primitive descriptor to
// this implementation's pd_t.
const acl_binary_t::pd_t *acl_binary_t::pd() const {
    return static_cast<const pd_t *>(primitive_t::pd().get());
}
231+
58232
} // namespace aarch64
59233
} // namespace cpu
60234
} // namespace impl

0 commit comments

Comments
 (0)