/*******************************************************************************
- * Copyright 2022 Arm Ltd. and affiliates
+ * Copyright 2022, 2024 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * limitations under the License.
 *******************************************************************************/

- #include "cpu/aarch64/acl_binary.hpp"
+ #include "acl_binary.hpp"
+
+ #include "arm_compute/core/ITensorPack.h"
+ #include "arm_compute/core/experimental/Types.h"
+ #include "arm_compute/runtime/Tensor.h"
+ #include "arm_compute/runtime/experimental/operators/CpuAdd.h"
+ #include "arm_compute/runtime/experimental/operators/CpuElementwise.h"
+ #include "arm_compute/runtime/experimental/operators/CpuMul.h"
+ #include "arm_compute/runtime/experimental/operators/CpuSub.h"

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

+ status_t acl_binary_t::pd_t::init(engine_t *engine) {
+     using namespace acl_utils;
+
+     // Only support f16/f32/s32 for now
+     data_type_t ddt = dst_md(0)->data_type;
+     if (!utils::one_of(ddt, data_type::f16, data_type::f32, data_type::s32))
+         return status::unimplemented;
+
+     // Only support src and dst all matching for now
+     if (ddt != src_md(0)->data_type || src_md(1)->data_type != ddt)
+         return status::unimplemented;
+
+     // Sets the memory format of dst from any to src_md(0) blocking desc
+     CHECK(set_default_params());
+
+     if (!attr()->has_default_values()) return status::unimplemented;
+
+     asp_.alg = desc()->alg_kind;
+
+     // All the algorithms we support
+     if (!utils::one_of(asp_.alg, alg_kind::binary_add, alg_kind::binary_sub,
+                 alg_kind::binary_mul, alg_kind::binary_div,
+                 alg_kind::binary_max, alg_kind::binary_min))
+         return status::unimplemented;
+
+     // s32 div in ACL does not round as oneDNN expects
+     if (ddt == data_type::s32 && asp_.alg == alg_kind::binary_div)
+         return status::unimplemented;
+
+     // ACL pointwise arithmetic operators assume that the innermost
+     // dimensions are dense for src0, src1 and dst. Reordering the
+     // logical dimensions by stride does this (if reordered_dims >= 1)
+     // and also makes memory accesses contiguous in ACL (without any
+     // data reordering).
+     memory_desc_t src_d0_permed, src_d1_permed, dst_d_permed;
+     int reordered_dims = reorder_dimensions_by_stride(
+             {&src_d0_permed, &src_d1_permed, &dst_d_permed},
+             {src_md(0), src_md(1), dst_md()});
+     if (reordered_dims < 1) return status::unimplemented;
+
+     // Create ACL tensor infos with permuted descs
+     CHECK(tensor_info(asp_.src0_info, src_d0_permed));
+     CHECK(tensor_info(asp_.src1_info, src_d1_permed));
+     CHECK(tensor_info(asp_.dst_info, dst_d_permed));
+
+     // In this case ACL tries to treat src0 and src1 as a 1D array, but
+     // fails because the strides aren't equal. TODO: remove when fixed
+     // in ACL
+     if (asp_.alg == alg_kind::binary_add
+             && asp_.src0_info.tensor_shape() == asp_.src1_info.tensor_shape()
+             && asp_.src0_info.strides_in_bytes()
+                     != asp_.src1_info.strides_in_bytes()) {
+         return status::unimplemented;
+     }
+
+     // This forces ACL not to parallelise with small workloads, this is
+     // a temporary fix and should be removed in future versions (TODO)
+     memory_desc_wrapper dst_d(dst_md());
+     if (dst_d.nelems() < 40000) {
+         size_t acl_y_axis_i = 1;
+         CHECK(insert_singleton_dimension(asp_.src0_info, acl_y_axis_i));
+         CHECK(insert_singleton_dimension(asp_.src1_info, acl_y_axis_i));
+         CHECK(insert_singleton_dimension(asp_.dst_info, acl_y_axis_i));
+     }
+
+     // Call operator specific validate function to check support
+     ACL_CHECK_VALID(validate(asp_));
+
+     return status::success;
+ }
+
+ arm_compute::Status acl_binary_t::pd_t::validate(const acl_binary_conf_t &asp) {
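+     // Mirror the configure() calls made in acl_binary_t::init() so that
+     // unsupported configurations are rejected at primitive descriptor
+     // creation time.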
+     switch (asp.alg) {
+         case alg_kind::binary_add:
+             return arm_compute::experimental::op::CpuAdd::validate(
+                     &asp.src0_info, &asp.src1_info, &asp.dst_info,
+                     arm_compute::ConvertPolicy::SATURATE);
+         case alg_kind::binary_sub:
+             return arm_compute::experimental::op::CpuSub::validate(
+                     &asp.src0_info, &asp.src1_info, &asp.dst_info,
+                     arm_compute::ConvertPolicy::SATURATE);
+         case alg_kind::binary_div:
+             return arm_compute::experimental::op::CpuElementwiseDivision::
+                     validate(&asp.src0_info, &asp.src1_info, &asp.dst_info);
+         case alg_kind::binary_mul:
+             return arm_compute::experimental::op::CpuMul::validate(
+                     &asp.src0_info, &asp.src1_info, &asp.dst_info, 1.0f,
+                     arm_compute::ConvertPolicy::SATURATE,
+                     arm_compute::RoundingPolicy::TO_ZERO);
+         case alg_kind::binary_min:
+             return arm_compute::experimental::op::CpuElementwiseMin::validate(
+                     &asp.src0_info, &asp.src1_info, &asp.dst_info);
+         case alg_kind::binary_max:
+             return arm_compute::experimental::op::CpuElementwiseMax::validate(
+                     &asp.src0_info, &asp.src1_info, &asp.dst_info);
+         default:
+             return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
+                     "unsupported alg_kind");
+     }
+ }
+
+ status_t acl_binary_t::init(engine_t *engine) {
+     auto asp = pd()->asp_;
+
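+     // Instantiate and configure the ACL operator once for this primitive;
+     // the input/output buffers are supplied later, at execute time, via an
+     // ITensorPack, so only the tensor infos are needed here.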
+     switch (asp.alg) {
+         case alg_kind::binary_add: {
+             auto add_op
+                     = std::make_unique<arm_compute::experimental::op::CpuAdd>();
+             add_op->configure(&asp.src0_info, &asp.src1_info, &asp.dst_info,
+                     arm_compute::ConvertPolicy::SATURATE);
+             binary_op_ = std::move(add_op);
+             break;
+         }
+         case alg_kind::binary_sub: {
+             auto sub_op
+                     = std::make_unique<arm_compute::experimental::op::CpuSub>();
+             sub_op->configure(&asp.src0_info, &asp.src1_info, &asp.dst_info,
+                     arm_compute::ConvertPolicy::SATURATE);
+             binary_op_ = std::move(sub_op);
+             break;
+         }
+         case alg_kind::binary_div: {
+             auto div_op = std::make_unique<
+                     arm_compute::experimental::op::CpuElementwiseDivision>();
+             div_op->configure(&asp.src0_info, &asp.src1_info, &asp.dst_info);
+             binary_op_ = std::move(div_op);
+             break;
+         }
+         case alg_kind::binary_mul: {
+             auto mul_op
+                     = std::make_unique<arm_compute::experimental::op::CpuMul>();
+             mul_op->configure(&asp.src0_info, &asp.src1_info, &asp.dst_info,
+                     1.0f, arm_compute::ConvertPolicy::SATURATE,
+                     arm_compute::RoundingPolicy::TO_ZERO);
+             binary_op_ = std::move(mul_op);
+             break;
+         }
+         case alg_kind::binary_min: {
+             auto min_op = std::make_unique<
+                     arm_compute::experimental::op::CpuElementwiseMin>();
+             min_op->configure(&asp.src0_info, &asp.src1_info, &asp.dst_info);
+             binary_op_ = std::move(min_op);
+             break;
+         }
+         case alg_kind::binary_max: {
+             auto max_op = std::make_unique<
+                     arm_compute::experimental::op::CpuElementwiseMax>();
+             max_op->configure(&asp.src0_info, &asp.src1_info, &asp.dst_info);
+             binary_op_ = std::move(max_op);
+             break;
+         }
+         default: return status::runtime_error;
+     }
+
+     return status::success;
+ }
+

status_t acl_binary_t::execute_forward(const exec_ctx_t &ctx, const void *src0,
        const void *src1, void *dst) const {

-     // Lock here is needed because resource_mapper does not support
-     // concurrent multithreaded access.
-     std::lock_guard<std::mutex> _lock {this->mtx};
+     auto asp = pd()->asp_;

-     // Retrieve primitive resource and configured Compute Library objects
-     acl_binary_obj_t &acl_obj = ctx.get_resource_mapper()
-                                         ->get<acl_binary_resource_t>(this)
-                                         ->get_acl_obj();
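+     // Wrap the caller's buffers in per-execution ACL tensors; import_memory()
+     // only adopts the pointers, no data is copied.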
+     arm_compute::Tensor src0_tensor;
+     arm_compute::Tensor src1_tensor;
+     arm_compute::Tensor dst_tensor;

-     acl_obj.src0_tensor.allocator()->import_memory(const_cast<void *>(src0));
-     acl_obj.src1_tensor.allocator()->import_memory(const_cast<void *>(src1));
-     acl_obj.dst_tensor.allocator()->import_memory(dst);
+     src0_tensor.allocator()->init(asp.src0_info);
+     src0_tensor.allocator()->import_memory(const_cast<void *>(src0));
+     src1_tensor.allocator()->init(asp.src1_info);
+     src1_tensor.allocator()->import_memory(const_cast<void *>(src1));
+     dst_tensor.allocator()->init(asp.dst_info);
+     dst_tensor.allocator()->import_memory(dst);

-     acl_obj.binary_op->run();
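+     // Gather the tensors into the argument pack expected by the
+     // experimental operator's run() interface.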
+     arm_compute::ITensorPack run_pack {
+             {arm_compute::TensorType::ACL_SRC_0, &src0_tensor},
+             {arm_compute::TensorType::ACL_SRC_1, &src1_tensor},
+             {arm_compute::TensorType::ACL_DST, &dst_tensor}};

-     acl_obj.src0_tensor.allocator()->free();
-     acl_obj.src1_tensor.allocator()->free();
-     acl_obj.dst_tensor.allocator()->free();
+     binary_op_->run(run_pack);

    return status::success;
}
@@ -55,6 +221,14 @@ status_t acl_binary_t::execute_forward(const exec_ctx_t &ctx) const {
    return execute_forward(ctx, src0, src1, dst);
}

+ status_t acl_binary_t::execute(const exec_ctx_t &ctx) const {
+     return execute_forward(ctx);
+ }
+
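+ // Convenience accessor that downcasts the base-class pd() to this
+ // primitive's pd_t.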
+ const acl_binary_t::pd_t *acl_binary_t::pd() const {
+     return static_cast<const pd_t *>(primitive_t::pd().get());
+ }
+
} // namespace aarch64
} // namespace cpu
} // namespace impl