
Commit 8493d2f

cpu: aarch64: Enable stateless ACL LayerNorm
1 parent 70a6801 commit 8493d2f

File tree

3 files changed: +182 -222 lines

src/cpu/aarch64/acl_layer_normalization.cpp (+167 -18)
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2023 Arm Ltd. and affiliates
+* Copyright 2023, 2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -21,29 +21,178 @@ namespace impl {
 namespace cpu {
 namespace aarch64 {
 
-status_t acl_layer_normalization_fwd_t::execute_forward(
-        const exec_ctx_t &ctx) const {
+acl_layer_normalization_fwd_t::acl_layer_normalization_fwd_t(const pd_t *apd)
+    : primitive_t(apd)
+    , acl_obj_(std::make_unique<
+              arm_compute::experimental::op::CpuMeanStdDevNormalization>()) {}
+
+status_t acl_layer_normalization_fwd_t::pd_t::init(engine_t *engine) {
+
+    // dir and flags
+    ACL_CHECK_SUPPORT(!is_fwd(), "ACL lnorm supports forward propagation only");
+    ACL_CHECK_SUPPORT(is_training(), "ACL supports inference only for lnorm");
+    ACL_CHECK_SUPPORT(
+            use_global_stats(), "ACL does not support global stats with lnorm");
+    ACL_CHECK_SUPPORT(use_scale() || use_shift(),
+            "ACL does not support lnorm scale and shift");
+
+    // attr-scales
+    ACL_CHECK_SUPPORT(!attr()->has_default_values(),
+            "ACL does not support scales attribute");
+
+    // tag and stat_tag
+    ACL_CHECK_SUPPORT(src_md()->ndims < 2 || src_md()->ndims > 5,
+            "src tensor must have between 2 and 5 (inclusive) "
+            "dimensions");
+
+    // msdNorm only supports lnorm for src in a channels last format.
+    // So if channels aren't last (ie. if they aren't dense),
+    // then reorder into a channels last format
+    std::string ref_implementation_guess = "simple:any";
+    if (src_md()->format_desc.blocking.strides[ndims() - 1] != 1) {
+        CHECK(memory_desc_init_by_tag(
+                src_md_, get_channels_last_format(src_md_.ndims)));
+        ref_implementation_guess = "ref:any";
+    }
+    if (dst_md_ != src_md_)
+        // Make sure dst and src share a format
+        CHECK(memory_desc_init_by_md_and_dt(
+                dst_md_, src_md_, src_md()->data_type));
+    if (!set_default_stat_md_format(src_md_)) return status::unimplemented;
+
+    const memory_desc_wrapper src_d(src_md_);
+    const memory_desc_wrapper dst_d(dst_md_);
 
-    // Lock here is needed because resource_mapper does not support
-    // concurrent access.
-    std::lock_guard<std::mutex> _lock {this->mtx};
+    ACL_CHECK_SUPPORT(src_d.has_zero_dim() || dst_d.has_zero_dim(),
+            "data tensor(s) must not have a zero dimension");
+
+    // data type
+    ACL_CHECK_SUPPORT(
+            src_d.data_type() != data_type::f32, "ACL Lnorm only supports F32");
+    ACL_CHECK_SUPPORT(dst_d.data_type() != src_d.data_type(),
+            "src and dst must share data types");
+
+    // Problem shape
+    int C = norm_axis(); // Channel dim size
+    int X = src_d.nelems() / C; // Non-channel dims size
+
+    ACL_CHECK_SUPPORT(!use_acl_heuristic(X, C, dnnl_get_max_threads(),
+                              is_training(), ref_implementation_guess),
+            "ACL is unoptimal in this case");
+
+    anp_data_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(C, X), 1, arm_compute::DataType::F32);
+
+    ACL_CHECK_VALID(
+            arm_compute::experimental::op::CpuMeanStdDevNormalization::validate(
+                    &anp_data_info, &anp_data_info,
+                    desc()->layer_norm_epsilon));
+
+    return status::success;
+}
 
-    // Retrieve primitive resource and configured Compute Library objects
-    auto *acl_resource
-            = ctx.get_resource_mapper()
-                      ->get<acl_layer_normalization_resource_t>(this);
-    acl_msdnorm_obj_t &acl_obj = acl_resource->get_acl_obj();
+format_tag_t acl_layer_normalization_fwd_t::pd_t::get_channels_last_format(
+        size_t ndim) const {
+    assert(ndim > 1 && ndim < 6);
+    switch (ndim) {
+        case 2: return format_tag::nc;
+        case 3: return format_tag::tnc;
+        case 4: return format_tag::ldnc;
+        case 5: return format_tag::abcde;
+        default: return format_tag::undef;
+    }
+}
+
+bool acl_layer_normalization_fwd_t::pd_t::use_acl_heuristic(int X, int C,
+        int threads, bool ref_has_stats,
+        const std::string &ref_implementation_guess) const {
+    // Above a certain C, acl is always better, and below a certain C,
+    // acl is always worse. For C in between these two, whether acl is
+    // better can be approximated with the workload (X*C) per thread.
+    // The values here were derived empirically and all depend on
+    // threads, whether ref can use provided stats, and which reference
+    // implementation acl is competing with.
+
+    int acl_competitive_C = C;
+    int acl_better_C = C;
+    int acl_better_XC_per_thread = X * C;
+
+    if (ref_implementation_guess == "simple:any") {
+        acl_competitive_C = 64;
+        if (ref_has_stats) {
+            acl_better_C = 4096;
+            acl_better_XC_per_thread = threads == 1 ? 4096 : 8192;
+        } else {
+            acl_better_C = threads <= 2 ? 1024 : 4096;
+            acl_better_XC_per_thread = threads == 1 ? 1024 : 4096;
+        }
+    } else if (ref_implementation_guess == "ref:any") {
+        acl_competitive_C = 0;
+        if (ref_has_stats) {
+            if (threads == 1) {
+                acl_better_C = 64;
+            } else if (threads == 2) {
+                acl_better_C = 256;
+            } else {
+                acl_better_C = 1024;
+            }
+
+            if (threads == 1) {
+                acl_better_XC_per_thread = 256;
+            } else if (threads <= 16) {
+                acl_better_XC_per_thread = 512;
+            } else {
+                acl_better_XC_per_thread = 1024;
+            }
+        } else {
+            if (threads == 1) {
+                acl_better_C = 64;
+                acl_better_XC_per_thread = 128;
+            } else if (threads <= 32) {
+                acl_better_C = 256;
+                acl_better_XC_per_thread = 256;
+            } else {
+                acl_better_C = 1024;
+                acl_better_XC_per_thread = 512;
+            }
+        }
+    }
+
+    return C > acl_competitive_C
+            && (C > acl_better_C || X * C > acl_better_XC_per_thread * threads);
+}
+
+const acl_layer_normalization_fwd_t::pd_t *
+acl_layer_normalization_fwd_t::pd() const {
+    return (const pd_t *)primitive_t::pd().get();
+}
+
+status_t acl_layer_normalization_fwd_t::init(engine_t *engine) {
+    auto *anp_data_info
+            = const_cast<arm_compute::TensorInfo *>(&pd()->anp_data_info);
+    acl_obj_->configure(
+            anp_data_info, anp_data_info, pd()->desc()->layer_norm_epsilon);
+    return status::success;
+}
+
+status_t acl_layer_normalization_fwd_t::execute_forward(
+        const exec_ctx_t &ctx) const {
 
-    auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
-    acl_obj.src_tensor.allocator()->import_memory(const_cast<float *>(src));
+    const auto *src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
+    auto *dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
 
-    auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
-    acl_obj.dst_tensor.allocator()->import_memory(dst);
+    arm_compute::Tensor src_tensor;
+    arm_compute::Tensor dst_tensor;
 
-    acl_obj.msdNorm.run();
+    src_tensor.allocator()->init(pd()->anp_data_info);
+    src_tensor.allocator()->import_memory(const_cast<float *>(src));
+    dst_tensor.allocator()->init(pd()->anp_data_info);
+    dst_tensor.allocator()->import_memory(dst);
 
-    acl_obj.src_tensor.allocator()->free();
-    acl_obj.dst_tensor.allocator()->free();
+    arm_compute::ITensorPack act_pack;
+    act_pack.add_tensor(arm_compute::TensorType::ACL_SRC, &src_tensor);
+    act_pack.add_tensor(arm_compute::TensorType::ACL_DST, &dst_tensor);
+    acl_obj_->run(act_pack);
 
     return status::success;
 }
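
For reference, below is a minimal sketch (not part of this commit) of how the primitive implemented above is reached from user code. It assumes the oneDNN v3.x C++ API and an AArch64 build with ACL enabled; the tensor shape, format, and epsilon are illustrative only, and whether the ACL kernel is actually selected still depends on the pd_t::init checks and use_acl_heuristic above. Because the implementation is now stateless, each execute_forward call builds its own arm_compute::Tensor objects and ITensorPack instead of locking a shared object in the resource_mapper.

#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // 3D f32 src/dst with the channel dim (C) dense and last, as the ACL path
    // requires. Sizes are made up for illustration.
    const memory::dim T = 128, N = 4, C = 768;
    memory::desc data_md(
            {T, N, C}, memory::data_type::f32, memory::format_tag::tnc);

    // Inference only, no scale/shift or global-stats flags: this is the
    // configuration that pd_t::init above accepts.
    auto pd = layer_normalization_forward::primitive_desc(eng,
            prop_kind::forward_inference, data_md, data_md, 1e-5f,
            normalization_flags::none);

    std::vector<float> src(T * N * C, 1.f), dst(T * N * C);
    memory src_mem(data_md, eng, src.data());
    memory dst_mem(data_md, eng, dst.data());

    // Each execute() call runs the stateless ACL operator with per-call
    // tensors, so no per-primitive lock is needed any more.
    layer_normalization_forward(pd).execute(
            strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
    strm.wait();
    return 0;
}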
