Skip to content

Commit c5c10ad

Browse files
taoye9avmanerikar
authored andcommitted
cpu: aarch64: hot fix for aux tensor management of stateless gemm and winograd conv
Signed-off-by: Ye Tao <ye.tao@arm.com>
1 parent 6ff1ab1 commit c5c10ad

4 files changed

+34
-4
lines changed

src/cpu/aarch64/acl_gemm_convolution.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,26 @@ status_t acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t, bia_t>::init(
9898
return status::success;
9999
}
100100

101+
template <data_type_t src_t, data_type_t wei_t, data_type_t dst_t,
102+
data_type_t bia_t>
103+
void acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t,
104+
bia_t>::reinitialize_acl_obj() const {
105+
auto acp = pd()->acp_;
106+
std::lock_guard<std::mutex> _lock {this->mtx};
107+
acl_obj_ = std::make_unique<acl_obj_t<Op>>();
108+
acl_obj_->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info,
109+
acp.with_bias ? &acp.bia_tensor_info : nullptr,
110+
&acp.dst_tensor_info, acp.padstride_info, acp.weights_info,
111+
acp.dilation_info, acp.act_info, acp.fast_math);
112+
acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
113+
}
114+
101115
template <data_type_t src_t, data_type_t wei_t, data_type_t dst_t,
102116
data_type_t bia_t>
103117
status_t
104118
acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t, bia_t>::execute_forward(
105119
const exec_ctx_t &ctx) const {
120+
reinitialize_acl_obj();
106121
return execute_forward_conv_acl<acl_obj_t<Op>, pd_t, src_data_t, wei_data_t,
107122
dst_data_t, bia_data_t>(ctx, acl_obj_.get(), pd(), gemm_conv_keys);
108123
}

src/cpu/aarch64/acl_gemm_convolution.hpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,14 @@ struct acl_gemm_convolution_fwd_t : public primitive_t {
6565

6666
private:
6767
status_t execute_forward(const exec_ctx_t &ctx) const;
68+
69+
// hot fix solution for stateless API which should be replaced soon.
70+
mutable std::mutex mtx;
71+
void reinitialize_acl_obj() const;
72+
6873
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
69-
std::unique_ptr<acl_obj_t<Op>> acl_obj_;
74+
mutable std::unique_ptr<acl_obj_t<Op>> acl_obj_;
75+
7076
}; // acl_gemm_convolution_fwd_t
7177

7278
} // namespace aarch64

src/cpu/aarch64/acl_winograd_convolution.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -129,17 +129,21 @@ status_t acl_wino_convolution_fwd_t::pd_t::init_conf() {
129129
return status::success;
130130
}
131131

132-
status_t acl_wino_convolution_fwd_t::execute_forward(
133-
const exec_ctx_t &ctx) const {
134-
acl_obj_ = std::make_unique<acl_obj_t<Op>>();
132+
void acl_wino_convolution_fwd_t::reinitialize_acl_obj() const {
135133
auto acp = pd()->acp_;
134+
std::lock_guard<std::mutex> _lock {this->mtx};
135+
acl_obj_ = std::make_unique<acl_obj_t<Op>>();
136136
acl_obj_->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info,
137137
acp.with_bias ? &acp.bia_tensor_info : nullptr,
138138
&acp.dst_tensor_info, acp.padstride_info, acp.act_info,
139139
true); // to support 5x5, 7x7 filter shapes in addition to 3x3
140140

141141
acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
142+
}
142143

144+
status_t acl_wino_convolution_fwd_t::execute_forward(
145+
const exec_ctx_t &ctx) const {
146+
reinitialize_acl_obj();
143147
return execute_forward_conv_acl<acl_obj_t<Op>, pd_t, data_t>(
144148
ctx, acl_obj_.get(), pd(), wino_conv_keys);
145149
}

src/cpu/aarch64/acl_winograd_convolution.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ struct acl_wino_convolution_fwd_t : public primitive_t {
5858

5959
private:
6060
status_t execute_forward(const exec_ctx_t &ctx) const;
61+
62+
// hot fix solution for stateless API which should be replaced soon.
63+
mutable std::mutex mtx;
64+
void reinitialize_acl_obj() const;
65+
6166
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
6267
mutable std::unique_ptr<acl_obj_t<Op>> acl_obj_;
6368
}; // acl_wino_convolution_fwd_t

0 commit comments

Comments
 (0)