Skip to content

Commit 513f882

Browse files
taoye9mgouicem
authored andcommitted
cpu: aarch64: hot fix for aux tensor management of stateless gemm-conv and winograd conv without lock.
Signed-off-by: Ye Tao <ye.tao@arm.com> Change-Id: Ifb30292a8bfc5219c44515eb4d29b277a0f0b24a Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/oncpuml/oneDNN/+/680780 Tested-by: svc_mongoosetron <svc_mongoosetron@arm.com> Reviewed-by: Hamza Butt <hamza.butt@arm.com> IP-Review: Hamza Butt <hamza.butt@arm.com>
1 parent b9c3c4f commit 513f882

4 files changed

+74
-18
lines changed

src/cpu/aarch64/acl_gemm_convolution.cpp

+26-4
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,44 @@ template <data_type_t src_t, data_type_t wei_t, data_type_t dst_t,
8989
data_type_t bia_t>
9090
status_t acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t, bia_t>::init(
9191
engine_t *engine) {
92+
// commented due to hot fix solution for stateless API which should be replaced soon.
93+
// auto acp_ = pd()->acp_;
94+
// acl_obj_->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info,
95+
// acp_.with_bias ? &acp_.bia_tensor_info : nullptr,
96+
// &acp_.dst_tensor_info, acp_.padstride_info, acp_.weights_info,
97+
// acp_.dilation_info, acp_.act_info, acp_.fast_math);
98+
// acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
99+
return status::success;
100+
}
101+
102+
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
103+
data_type_t bia_type>
104+
std::unique_ptr<acl_obj_t<typename acl_gemm_convolution_fwd_t<src_type,
105+
wei_type, dst_type, bia_type>::Op>>
106+
acl_gemm_convolution_fwd_t<src_type, wei_type, dst_type,
107+
bia_type>::reinitialize_acl_obj() const {
92108
auto acp_ = pd()->acp_;
93-
acl_obj_->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info,
109+
std::unique_ptr<acl_obj_t<Op>> acl_obj = std::make_unique<acl_obj_t<Op>>();
110+
acl_obj->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info,
94111
acp_.with_bias ? &acp_.bia_tensor_info : nullptr,
95112
&acp_.dst_tensor_info, acp_.padstride_info, acp_.weights_info,
96113
acp_.dilation_info, acp_.act_info, acp_.fast_math);
97-
acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
98-
return status::success;
114+
acl_obj->aux_mem_req = acl_obj->conv.workspace();
115+
return acl_obj;
99116
}
100117

101118
template <data_type_t src_t, data_type_t wei_t, data_type_t dst_t,
102119
data_type_t bia_t>
103120
status_t
104121
acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t, bia_t>::execute_forward(
105122
const exec_ctx_t &ctx) const {
123+
// Temporary hotfix: We're using a local acl_obj instance in this method
124+
// instead of the class member acl_obj_. This hotfix is to bypass persistent aux mem requirements but is not the ideal solution.
125+
// It should be refactored or removed in the future when a more permanent fix is implemented.
126+
auto acl_obj = reinitialize_acl_obj();
127+
106128
return execute_forward_conv_acl<acl_obj_t<Op>, pd_t, src_data_t, wei_data_t,
107-
dst_data_t, bia_data_t>(ctx, acl_obj_.get(), pd(), gemm_conv_keys);
129+
dst_data_t, bia_data_t>(ctx, acl_obj.get(), pd(), gemm_conv_keys);
108130
}
109131

110132
using namespace data_type;

src/cpu/aarch64/acl_gemm_convolution.hpp

+12-3
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ struct acl_gemm_convolution_fwd_t : public primitive_t {
4949
acl_post_ops_t post_ops;
5050
};
5151

52-
acl_gemm_convolution_fwd_t(const pd_t *apd)
53-
: primitive_t(apd), acl_obj_(std::make_unique<acl_obj_t<Op>>()) {}
52+
// hot fix solution for stateless API which should be replaced soon.
53+
// acl_gemm_convolution_fwd_t(const pd_t *apd)
54+
// : primitive_t(apd), acl_obj_(std::make_unique<acl_obj_t<Op>>()) {}
55+
acl_gemm_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
5456

5557
status_t init(engine_t *engine) override;
5658

@@ -65,8 +67,15 @@ struct acl_gemm_convolution_fwd_t : public primitive_t {
6567

6668
private:
6769
status_t execute_forward(const exec_ctx_t &ctx) const;
70+
71+
// hot fix solution for stateless API which should be replaced soon.
72+
std::unique_ptr<acl_obj_t<Op>> reinitialize_acl_obj() const;
73+
6874
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
69-
std::unique_ptr<acl_obj_t<Op>> acl_obj_;
75+
76+
// commented due to hot fix solution for stateless API which should be replaced soon.
77+
// std::unique_ptr<acl_obj_t<Op>> acl_obj_;
78+
7079
}; // acl_gemm_convolution_fwd_t
7180

7281
} // namespace aarch64

src/cpu/aarch64/acl_winograd_convolution.cpp

+26-8
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ status_t acl_wino_convolution_fwd_t::pd_t::init(engine_t *engine) {
7979
}
8080

8181
status_t acl_wino_convolution_fwd_t::init(engine_t *engine) {
82-
auto acp = pd()->acp_;
83-
acl_obj_->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info,
84-
acp.with_bias ? &acp.bia_tensor_info : nullptr,
85-
&acp.dst_tensor_info, acp.padstride_info, acp.act_info,
86-
true); // to support 5x5, 7x7 filter shapes in addition to 3x3
87-
88-
acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
82+
// commented due to hot fix solution for stateless API which should be replaced soon.
83+
// auto acp = pd()->acp_;
84+
// acl_obj_->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info,
85+
// acp.with_bias ? &acp.bia_tensor_info : nullptr,
86+
// &acp.dst_tensor_info, acp.padstride_info, acp.act_info,
87+
// true); // to support 5x5, 7x7 filter shapes in addition to 3x3
88+
89+
// acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
8990
return status::success;
9091
}
9192

@@ -129,10 +130,27 @@ status_t acl_wino_convolution_fwd_t::pd_t::init_conf() {
129130
return status::success;
130131
}
131132

133+
std::unique_ptr<acl_obj_t<acl_wino_convolution_fwd_t::Op>>
134+
acl_wino_convolution_fwd_t::reinitialize_acl_obj() const {
135+
auto acp = pd()->acp_;
136+
std::unique_ptr<acl_obj_t<Op>> acl_obj = std::make_unique<acl_obj_t<Op>>();
137+
acl_obj->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info,
138+
acp.with_bias ? &acp.bia_tensor_info : nullptr,
139+
&acp.dst_tensor_info, acp.padstride_info, acp.act_info,
140+
true); // to support 5x5, 7x7 filter shapes in addition to 3x3
141+
142+
acl_obj->aux_mem_req = acl_obj->conv.workspace();
143+
return acl_obj;
144+
}
145+
132146
status_t acl_wino_convolution_fwd_t::execute_forward(
133147
const exec_ctx_t &ctx) const {
148+
// Temporary hotfix: We're using a local acl_obj instance in this method
149+
// instead of the class member acl_obj_. This hotfix is to bypass persistent aux mem requirements but is not the ideal solution.
150+
// It should be refactored or removed in the future when a more permanent fix is implemented.
151+
const auto acl_obj = reinitialize_acl_obj();
134152
return execute_forward_conv_acl<acl_obj_t<Op>, pd_t, data_t>(
135-
ctx, acl_obj_.get(), pd(), wino_conv_keys);
153+
ctx, acl_obj.get(), pd(), wino_conv_keys);
136154
}
137155
} // namespace aarch64
138156
} // namespace cpu

src/cpu/aarch64/acl_winograd_convolution.hpp

+10-3
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ struct acl_wino_convolution_fwd_t : public primitive_t {
4747
status_t init_conf();
4848
};
4949

50-
acl_wino_convolution_fwd_t(const pd_t *apd)
51-
: primitive_t(apd), acl_obj_(std::make_unique<acl_obj_t<Op>>()) {}
50+
// hot fix solution for stateless API which should be replaced soon.
51+
// acl_wino_convolution_fwd_t(const pd_t *apd)
52+
// : primitive_t(apd), acl_obj_(std::make_unique<acl_obj_t<Op>>()) {}
53+
acl_wino_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
5254

5355
status_t init(engine_t *engine) override;
5456

@@ -58,8 +60,13 @@ struct acl_wino_convolution_fwd_t : public primitive_t {
5860

5961
private:
6062
status_t execute_forward(const exec_ctx_t &ctx) const;
63+
64+
// hot fix solution for stateless API which should be replaced soon.
65+
std::unique_ptr<acl_obj_t<Op>> reinitialize_acl_obj() const;
66+
6167
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
62-
std::unique_ptr<acl_obj_t<Op>> acl_obj_;
68+
// commented due to hot fix solution for stateless API which should be replaced soon.
69+
// std::unique_ptr<acl_obj_t<Op>> acl_obj_;
6370
}; // acl_wino_convolution_fwd_t
6471

6572
} // namespace aarch64

0 commit comments

Comments
 (0)