@@ -89,37 +89,44 @@ template <data_type_t src_t, data_type_t wei_t, data_type_t dst_t,
        data_type_t bia_t>
status_t acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t, bia_t>::init(
        engine_t *engine) {
+    // Commented out as part of the stateless-API hotfix; to be replaced soon.
+    // auto acp_ = pd()->acp_;
+    // acl_obj_->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info,
+    //         acp_.with_bias ? &acp_.bia_tensor_info : nullptr,
+    //         &acp_.dst_tensor_info, acp_.padstride_info, acp_.weights_info,
+    //         acp_.dilation_info, acp_.act_info, acp_.fast_math);
+    // acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
+    return status::success;
+}
+
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
+        data_type_t bia_type>
+std::unique_ptr<acl_obj_t<typename acl_gemm_convolution_fwd_t<src_type,
+        wei_type, dst_type, bia_type>::Op>>
+acl_gemm_convolution_fwd_t<src_type, wei_type, dst_type,
+        bia_type>::reinitialize_acl_obj() const {
    auto acp_ = pd()->acp_;
-    acl_obj_->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info,
+    std::unique_ptr<acl_obj_t<Op>> acl_obj = std::make_unique<acl_obj_t<Op>>();
+    acl_obj->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info,
            acp_.with_bias ? &acp_.bia_tensor_info : nullptr,
            &acp_.dst_tensor_info, acp_.padstride_info, acp_.weights_info,
            acp_.dilation_info, acp_.act_info, acp_.fast_math);
-    acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
-    return status::success;
-}
-
-template <data_type_t src_t, data_type_t wei_t, data_type_t dst_t,
-        data_type_t bia_t>
-void acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t,
-        bia_t>::reinitialize_acl_obj() const {
-    auto acp = pd()->acp_;
-    std::lock_guard<std::mutex> _lock {this->mtx};
-    acl_obj_ = std::make_unique<acl_obj_t<Op>>();
-    acl_obj_->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info,
-            acp.with_bias ? &acp.bia_tensor_info : nullptr,
-            &acp.dst_tensor_info, acp.padstride_info, acp.weights_info,
-            acp.dilation_info, acp.act_info, acp.fast_math);
-    acl_obj_->aux_mem_req = acl_obj_->conv.workspace();
+    acl_obj->aux_mem_req = acl_obj->conv.workspace();
+    return acl_obj;
}

template <data_type_t src_t, data_type_t wei_t, data_type_t dst_t,
        data_type_t bia_t>
status_t
acl_gemm_convolution_fwd_t<src_t, wei_t, dst_t, bia_t>::execute_forward(
        const exec_ctx_t &ctx) const {
-    reinitialize_acl_obj();
+    // Temporary hotfix: use a local acl_obj instance here instead of the
+    // class member acl_obj_. This bypasses the persistent aux-mem requirements
+    // but is not ideal; refactor or remove it once a permanent fix lands.
+    auto acl_obj = reinitialize_acl_obj();
+
    return execute_forward_conv_acl<acl_obj_t<Op>, pd_t, src_data_t, wei_data_t,
-            dst_data_t, bia_data_t>(ctx, acl_obj_.get(), pd(), gemm_conv_keys);
+            dst_data_t, bia_data_t>(ctx, acl_obj.get(), pd(), gemm_conv_keys);
}

using namespace data_type;
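
For context, the net effect of this hunk is to turn reinitialize_acl_obj() from a mutex-guarded mutation of the mutable acl_obj_ member into a const factory that returns a fresh object per call. Below is a minimal, self-contained sketch of that pattern; conv_op_t, primitive_stateless_t, and the int parameter are illustrative stand-ins, not names from this codebase.

#include <memory>

// Stand-in for the configurable ACL convolution object (hypothetical name).
struct conv_op_t {
    void configure(int params) { (void)params; } // mirrors acl_obj->conv.configure(...)
};

// Hypothetical primitive using the pattern this diff introduces.
struct primitive_stateless_t {
    // Per-call construction: no shared mutable member, so no std::mutex is
    // needed, unlike the removed lock_guard-protected version above.
    std::unique_ptr<conv_op_t> reinitialize_obj() const {
        auto obj = std::make_unique<conv_op_t>();
        obj->configure(/*params=*/0);
        return obj;
    }

    void execute() const {
        auto obj = reinitialize_obj(); // local lifetime, destroyed on return
        // ... run the convolution through obj ...
    }
};

The trade-off, as the in-code comments note, is that configure() now runs on every execute() call, which is why the change is labeled a temporary hotfix.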