 /*******************************************************************************
-* Copyright 2023 Arm Ltd. and affiliates
+* Copyright 2023, 2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -21,29 +21,170 @@ namespace impl {
 namespace cpu {
 namespace aarch64 {
 
-status_t acl_layer_normalization_fwd_t::execute_forward(
-        const exec_ctx_t &ctx) const {
+status_t acl_layer_normalization_fwd_t::pd_t::init(engine_t *engine) {
+
+    // dir and flags
+    ACL_CHECK_SUPPORT(!is_fwd(), "ACL lnorm supports forward propagation only");
+    ACL_CHECK_SUPPORT(is_training(), "ACL supports inference only for lnorm");
+    ACL_CHECK_SUPPORT(
+            use_global_stats(), "ACL does not support global stats with lnorm");
+    ACL_CHECK_SUPPORT(use_scale() || use_shift(),
+            "ACL does not support lnorm scale and shift");
+
+    // attr-scales
+    ACL_CHECK_SUPPORT(!attr()->has_default_values(),
+            "ACL does not support scales attribute");
+
+    // tag and stat_tag
+    ACL_CHECK_SUPPORT(src_md()->ndims < 2 || src_md()->ndims > 5,
+            "src tensor must have between 2 and 5 (inclusive) "
+            "dimensions");
+
+    // msdNorm only supports lnorm for src in a channels-last format.
+    // So if channels aren't last (i.e. if they aren't dense),
+    // then reorder into a channels-last format.
+    std::string ref_implementation_guess = "simple:any";
+    if (src_md()->format_desc.blocking.strides[ndims() - 1] != 1) {
+        CHECK(memory_desc_init_by_tag(
+                src_md_, get_channels_last_format(src_md_.ndims)));
+        ref_implementation_guess = "ref:any";
+    }
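+    // If a reorder was needed, the reference implementation that ACL is
+    // competing with would have been ref:any rather than simple:any;
+    // use_acl_heuristic() below picks its thresholds based on this guess.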
+    if (dst_md_ != src_md_)
+        // Make sure dst and src share a format
+        CHECK(memory_desc_init_by_md_and_dt(
+                dst_md_, src_md_, src_md()->data_type));
+    if (!set_default_stat_md_format(src_md_)) return status::unimplemented;
+
+    const memory_desc_wrapper src_d(src_md_);
+    const memory_desc_wrapper dst_d(dst_md_);
+
+    ACL_CHECK_SUPPORT(src_d.has_zero_dim() || dst_d.has_zero_dim(),
+            "data tensor(s) must not have a zero dimension");
 
-    // Lock here is needed because resource_mapper does not support
-    // concurrent access.
-    std::lock_guard<std::mutex> _lock {this->mtx};
+    // data type
+    ACL_CHECK_SUPPORT(
+            src_d.data_type() != data_type::f32, "ACL Lnorm only supports F32");
+    ACL_CHECK_SUPPORT(dst_d.data_type() != src_d.data_type(),
+            "src and dst must share data types");
+
+    // Problem shape
+    int C = norm_axis(); // Channel dim size
+    int X = src_d.nelems() / C; // Non-channel dims size
+
+    ACL_CHECK_SUPPORT(!use_acl_heuristic(X, C, dnnl_get_max_threads(),
+                              is_training(), ref_implementation_guess),
+            "ACL is unoptimal in this case");
+
+    anp.data_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(C, X), 1, arm_compute::DataType::F32);
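+    // msdNorm sees a 2D view of the problem: each of the X rows of length C
+    // is normalized with its own mean and variance, which matches lnorm over
+    // the innermost (dense) channel axis.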
+
+    ACL_CHECK_VALID(arm_compute::NEMeanStdDevNormalizationLayer::validate(
+            &anp.data_info, &anp.data_info, desc()->layer_norm_epsilon));
+
+    return status::success;
+}
 
-    // Retrieve primitive resource and configured Compute Library objects
-    auto *acl_resource
-            = ctx.get_resource_mapper()
-                      ->get<acl_layer_normalization_resource_t>(this);
-    acl_msdnorm_obj_t &acl_obj = acl_resource->get_acl_obj();
+format_tag_t acl_layer_normalization_fwd_t::pd_t::get_channels_last_format(
+        size_t ndim) {
+    assert(ndim > 1 && ndim < 6);
+    switch (ndim) {
+        case 2: return format_tag::nc;
+        case 3: return format_tag::tnc;
+        case 4: return format_tag::ldnc;
+        case 5: return format_tag::abcde;
+        default: return format_tag::undef;
+    }
+}
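+// Note: nc, tnc, ldnc and abcde are all plain (row-major) tags, so the
+// returned format always keeps the innermost (channel) dimension dense, as
+// msdNorm requires.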
+
+bool acl_layer_normalization_fwd_t::pd_t::use_acl_heuristic(int X, int C,
+        int threads, bool ref_has_stats, std::string ref_implementation_guess) {
+    // Above a certain C, acl is always better, and below a certain C,
+    // acl is always worse. For C in between these two, whether acl is
+    // better can be approximated with the workload (X*C) per thread.
+    // The values here were derived empirically and all depend on
+    // threads, whether ref can use provided stats, and which reference
+    // implementation acl is competing with.
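+    //
+    // Illustrative example: with ref_implementation_guess == "simple:any",
+    // ref_has_stats == false and threads == 4, the thresholds below become
+    // acl_competitive_C = 64, acl_better_C = 4096 and
+    // acl_better_XC_per_thread = 4096. A problem with C = 512, X = 256 then
+    // picks ACL, because C > 64 and X*C = 131072 > 4096 * 4.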
+
+    int acl_competitive_C = C;
+    int acl_better_C = C;
+    int acl_better_XC_per_thread = X * C;
+
+    if (ref_implementation_guess == "simple:any") {
+        acl_competitive_C = 64;
+        if (ref_has_stats) {
+            acl_better_C = 4096;
+            acl_better_XC_per_thread = threads == 1 ? 4096 : 8192;
+        } else {
+            acl_better_C = threads <= 2 ? 1024 : 4096;
+            acl_better_XC_per_thread = threads == 1 ? 1024 : 4096;
+        }
+    } else if (ref_implementation_guess == "ref:any") {
+        acl_competitive_C = 0;
+        if (ref_has_stats) {
+            if (threads == 1) {
+                acl_better_C = 64;
+            } else if (threads == 2) {
+                acl_better_C = 256;
+            } else {
+                acl_better_C = 1024;
+            }
+
+            if (threads == 1) {
+                acl_better_XC_per_thread = 256;
+            } else if (threads <= 16) {
+                acl_better_XC_per_thread = 512;
+            } else {
+                acl_better_XC_per_thread = 1024;
+            }
+        } else {
+            if (threads == 1) {
+                acl_better_C = 64;
+            } else if (threads <= 32) {
+                acl_better_C = 256;
+            } else {
+                acl_better_C = 1024;
+            }
+
+            if (threads == 1) {
+                acl_better_XC_per_thread = 128;
+            } else if (threads <= 32) {
+                acl_better_XC_per_thread = 256;
+            } else {
+                acl_better_XC_per_thread = 512;
+            }
+        }
+    }
+
+    return C > acl_competitive_C
+            && (C > acl_better_C || X * C > acl_better_XC_per_thread * threads);
+}
+
+status_t acl_layer_normalization_fwd_t::init(engine_t *engine) {
+    auto aep = pd()->anp;
+    acl_obj_->configure(
+            &aep.data_info, &aep.data_info, pd()->desc()->layer_norm_epsilon);
+    return status::success;
+}
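+
+// Note: the operator is configured once in init() from TensorInfo only; the
+// actual src/dst buffers are bound per call in execute_forward() below via an
+// ITensorPack, so no member tensors or locking are required.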
+
+status_t acl_layer_normalization_fwd_t::execute_forward(
+        const exec_ctx_t &ctx) const {
 
-    auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
-    acl_obj.src_tensor.allocator()->import_memory(const_cast<float *>(src));
+    const auto *src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
+    auto *dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
 
-    auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
-    acl_obj.dst_tensor.allocator()->import_memory(dst);
+    auto aep = pd()->anp;
+    arm_compute::Tensor src_tensor;
+    arm_compute::Tensor dst_tensor;
 
-    acl_obj.msdNorm.run();
+    src_tensor.allocator()->init(aep.data_info);
+    src_tensor.allocator()->import_memory(const_cast<float *>(src));
+    dst_tensor.allocator()->init(aep.data_info);
+    dst_tensor.allocator()->import_memory(dst);
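+    // import_memory() wraps the oneDNN-provided buffers directly (no copy);
+    // since src_tensor and dst_tensor are local to this call, no explicit
+    // free() is needed when they go out of scope.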
 
-    acl_obj.src_tensor.allocator()->free();
-    acl_obj.dst_tensor.allocator()->free();
+    arm_compute::ITensorPack act_pack;
+    act_pack.add_tensor(arm_compute::TensorType::ACL_SRC, &src_tensor);
+    act_pack.add_tensor(arm_compute::TensorType::ACL_DST, &dst_tensor);
+    acl_obj_->run(act_pack);
 
     return status::success;
 }