
Commit b03603e

cpu: aarch64: Enable stateless ACL LayerNorm
1 parent 5238fef commit b03603e

2 files changed: +177 −217 lines changed


src/cpu/aarch64/acl_layer_normalization.cpp

+159 −18
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2023 Arm Ltd. and affiliates
+* Copyright 2023, 2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -21,29 +21,170 @@ namespace impl {
 namespace cpu {
 namespace aarch64 {
 
-status_t acl_layer_normalization_fwd_t::execute_forward(
-        const exec_ctx_t &ctx) const {
+status_t acl_layer_normalization_fwd_t::pd_t::init(engine_t *engine) {
+
+    // dir and flags
+    ACL_CHECK_SUPPORT(!is_fwd(), "ACL lnorm supports forward propagation only");
+    ACL_CHECK_SUPPORT(is_training(), "ACL supports inference only for lnorm");
+    ACL_CHECK_SUPPORT(
+            use_global_stats(), "ACL does not support global stats with lnorm");
+    ACL_CHECK_SUPPORT(use_scale() || use_shift(),
+            "ACL does not support lnorm scale and shift");
+
+    // attr-scales
+    ACL_CHECK_SUPPORT(!attr()->has_default_values(),
+            "ACL does not support scales attribute");
+
+    // tag and stat_tag
+    ACL_CHECK_SUPPORT(src_md()->ndims < 2 || src_md()->ndims > 5,
+            "src tensor must have between 2 and 5 (inclusive) "
+            "dimensions");
+
+    // msdNorm only supports lnorm for src in a channels last format.
+    // So if channels aren't last (ie. if they aren't dense),
+    // then reorder into a channels last format
+    std::string ref_implementation_guess = "simple:any";
+    if (src_md()->format_desc.blocking.strides[ndims() - 1] != 1) {
+        CHECK(memory_desc_init_by_tag(
+                src_md_, get_channels_last_format(src_md_.ndims)));
+        ref_implementation_guess = "ref:any";
+    }
+    if (dst_md_ != src_md_)
+        // Make sure dst and src share a format
+        CHECK(memory_desc_init_by_md_and_dt(
+                dst_md_, src_md_, src_md()->data_type));
+    if (!set_default_stat_md_format(src_md_)) return status::unimplemented;
+
+    const memory_desc_wrapper src_d(src_md_);
+    const memory_desc_wrapper dst_d(dst_md_);
+
+    ACL_CHECK_SUPPORT(src_d.has_zero_dim() || dst_d.has_zero_dim(),
+            "data tensor(s) must not have a zero dimension");
 
-    // Lock here is needed because resource_mapper does not support
-    // concurrent access.
-    std::lock_guard<std::mutex> _lock {this->mtx};
+    // data type
+    ACL_CHECK_SUPPORT(
+            src_d.data_type() != data_type::f32, "ACL Lnorm only supports F32");
+    ACL_CHECK_SUPPORT(dst_d.data_type() != src_d.data_type(),
+            "src and dst must share data types");
+
+    // Problem shape
+    int C = norm_axis(); // Channel dim size
+    int X = src_d.nelems() / C; // Non-channel dims size
+
+    ACL_CHECK_SUPPORT(!use_acl_heuristic(X, C, dnnl_get_max_threads(),
+                              is_training(), ref_implementation_guess),
+            "ACL is unoptimal in this case");
+
+    anp.data_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(C, X), 1, arm_compute::DataType::F32);
+
+    ACL_CHECK_VALID(arm_compute::NEMeanStdDevNormalizationLayer::validate(
+            &anp.data_info, &anp.data_info, desc()->layer_norm_epsilon));
+
+    return status::success;
+}
 
-    // Retrieve primitive resource and configured Compute Library objects
-    auto *acl_resource
-            = ctx.get_resource_mapper()
-                      ->get<acl_layer_normalization_resource_t>(this);
-    acl_msdnorm_obj_t &acl_obj = acl_resource->get_acl_obj();
+format_tag_t acl_layer_normalization_fwd_t::pd_t::get_channels_last_format(
+        size_t ndim) {
+    assert(ndim > 1 && ndim < 6);
+    switch (ndim) {
+        case 2: return format_tag::nc;
+        case 3: return format_tag::tnc;
+        case 4: return format_tag::ldnc;
+        case 5: return format_tag::abcde;
+        default: return format_tag::undef;
+    }
+}
+
+bool acl_layer_normalization_fwd_t::pd_t::use_acl_heuristic(int X, int C,
+        int threads, bool ref_has_stats, std::string ref_implementation_guess) {
+    // Above a certain C, acl is always better, and below a certain C,
+    // acl is always worse. for C in between these two, whether acl is
+    // better can be approximated with the workload (X*C) per thread.
+    // The values here were derived empirically and all depend on
+    // threads, whether ref can use provided stats, and which reference
+    // implementation acl is competing with.
+
+    int acl_competitive_C = C;
+    int acl_better_C = C;
+    int acl_better_XC_per_thread = X * C;
+
+    if (ref_implementation_guess == "simple:any") {
+        acl_competitive_C = 64;
+        if (ref_has_stats) {
+            acl_better_C = 4096;
+            acl_better_XC_per_thread = threads == 1 ? 4096 : 8192;
+        } else {
+            acl_better_C = threads <= 2 ? 1024 : 4096;
+            acl_better_XC_per_thread = threads == 1 ? 1024 : 4096;
+        }
+    } else if (ref_implementation_guess == "ref:any") {
+        acl_competitive_C = 0;
+        if (ref_has_stats) {
+            if (threads == 1) {
+                acl_better_C = 64;
+            } else if (threads == 2) {
+                acl_better_C = 256;
+            } else {
+                acl_better_C = 1024;
+            }
+
+            if (threads == 1) {
+                acl_better_XC_per_thread = 256;
+            } else if (threads <= 16) {
+                acl_better_XC_per_thread = 512;
+            } else {
+                acl_better_XC_per_thread = 1024;
+            }
+        } else {
+            if (threads == 1) {
+                acl_better_C = 64;
+            } else if (threads <= 32) {
+                acl_better_C = 256;
+            } else {
+                acl_better_C = 1024;
+            }
+
+            if (threads == 1) {
+                acl_better_XC_per_thread = 128;
+            } else if (threads <= 32) {
+                acl_better_XC_per_thread = 256;
+            } else {
+                acl_better_XC_per_thread = 512;
+            }
+        }
+    }
+
+    return C > acl_competitive_C
+            && (C > acl_better_C || X * C > acl_better_XC_per_thread * threads);
+}
+
+status_t acl_layer_normalization_fwd_t::init(engine_t *engine) {
+    auto aep = pd()->anp;
+    acl_obj_->configure(
+            &aep.data_info, &aep.data_info, pd()->desc()->layer_norm_epsilon);
+    return status::success;
+}
+
+status_t acl_layer_normalization_fwd_t::execute_forward(
+        const exec_ctx_t &ctx) const {
 
-    auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
-    acl_obj.src_tensor.allocator()->import_memory(const_cast<float *>(src));
+    const auto *src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
+    auto *dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
 
-    auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
-    acl_obj.dst_tensor.allocator()->import_memory(dst);
+    auto aep = pd()->anp;
+    arm_compute::Tensor src_tensor;
+    arm_compute::Tensor dst_tensor;
 
-    acl_obj.msdNorm.run();
+    src_tensor.allocator()->init(aep.data_info);
+    src_tensor.allocator()->import_memory(const_cast<float *>(src));
+    dst_tensor.allocator()->init(aep.data_info);
+    dst_tensor.allocator()->import_memory(dst);
 
-    acl_obj.src_tensor.allocator()->free();
-    acl_obj.dst_tensor.allocator()->free();
+    arm_compute::ITensorPack act_pack;
+    act_pack.add_tensor(arm_compute::TensorType::ACL_SRC, &src_tensor);
+    act_pack.add_tensor(arm_compute::TensorType::ACL_DST, &dst_tensor);
+    acl_obj_->run(act_pack);
 
     return status::success;
 }
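To make the heuristic thresholds concrete, here is a minimal standalone sketch (not part of the commit) of the "simple:any" branch of use_acl_heuristic() above. The constants are copied from the diff; the function name pick_acl_simple_any and the sample shapes are hypothetical.

#include <cstdio>

// Illustrative restatement of the "simple:any" branch of use_acl_heuristic();
// thresholds match the diff above, everything else is made up for the example.
static bool pick_acl_simple_any(int X, int C, int threads, bool ref_has_stats) {
    const int acl_competitive_C = 64;
    const int acl_better_C = ref_has_stats ? 4096 : (threads <= 2 ? 1024 : 4096);
    const int acl_better_XC_per_thread = ref_has_stats
            ? (threads == 1 ? 4096 : 8192)
            : (threads == 1 ? 1024 : 4096);
    return C > acl_competitive_C
            && (C > acl_better_C || X * C > acl_better_XC_per_thread * threads);
}

int main() {
    // Example: C = 768 channels, X = 128 non-channel elements, 4 threads,
    // reference cannot reuse provided stats. C > 64 and
    // X*C = 98304 > 4096 * 4, so the heuristic would pick ACL (prints 1).
    std::printf("%d\n", pick_acl_simple_any(128, 768, 4, /*ref_has_stats=*/false));
    return 0;
}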

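For context, a hedged usage sketch (also not part of the commit) of the kind of layer normalization call that satisfies the checks in pd_t::init() above (f32, forward inference, no scale/shift or global-stats flags, channels-last layout) and could therefore dispatch to this ACL implementation on AArch64, assuming oneDNN's 3.x C++ API.

#include <oneapi/dnnl/dnnl.hpp>

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // X = 128 non-channel elements, C = 768 channels, channels dense and last.
    memory::desc data_md(
            {128, 768}, memory::data_type::f32, memory::format_tag::nc);

    // Inference-only, no scale/shift, no provided stats: the configuration
    // accepted by the ACL lnorm pd_t::init() checks in the diff above.
    auto pd = layer_normalization_forward::primitive_desc(eng,
            prop_kind::forward_inference, data_md, data_md, 1e-5f,
            normalization_flags::none);
    auto lnorm = layer_normalization_forward(pd);

    memory src_mem(data_md, eng), dst_mem(data_md, eng);
    lnorm.execute(strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
    strm.wait();
    return 0;
}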