File tree 5 files changed +13
-16
lines changed
5 files changed +13
-16
lines changed Original file line number Diff line number Diff line change @@ -280,7 +280,8 @@ struct concat_pd_t : public primitive_desc_t {
280
280
return safe_ptr_assign (*concat_pd, _pd.release ()); \
281
281
} \
282
282
status_t create_primitive ( \
283
- std::pair<std::shared_ptr<impl::primitive_t >, cache_state_t > &primitive, \
283
+ std::pair<std::shared_ptr<impl::primitive_t >, cache_state_t > \
284
+ &primitive, \
284
285
dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \
285
286
const override { \
286
287
return primitive_t ::create_primitive_common<__VA_ARGS__, pd_t >( \
Original file line number Diff line number Diff line change @@ -487,7 +487,8 @@ struct primitive_desc_t : public c_compatible {
487
487
return new_pd.release (); \
488
488
} \
489
489
status_t create_primitive ( \
490
- std::pair<std::shared_ptr<impl::primitive_t >, cache_state_t > &primitive, \
490
+ std::pair<std::shared_ptr<impl::primitive_t >, cache_state_t > \
491
+ &primitive, \
491
492
dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \
492
493
const override { \
493
494
return primitive_t ::create_primitive_common<impl_type, pd_t >( \
Original file line number Diff line number Diff line change @@ -208,7 +208,8 @@ struct sum_pd_t : public primitive_desc_t {
208
208
return safe_ptr_assign (*sum_pd, _pd.release ()); \
209
209
} \
210
210
status_t create_primitive ( \
211
- std::pair<std::shared_ptr<impl::primitive_t >, cache_state_t > &primitive, \
211
+ std::pair<std::shared_ptr<impl::primitive_t >, cache_state_t > \
212
+ &primitive, \
212
213
dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \
213
214
const override { \
214
215
return primitive_t ::create_primitive_common<__VA_ARGS__, pd_t >( \
Original file line number Diff line number Diff line change @@ -73,7 +73,13 @@ struct primitive_t : public impl::primitive_t {
73
73
std::shared_ptr<impl::primitive_t > &primitive,
74
74
const std::shared_ptr<primitive_desc_t > &pd,
75
75
impl::engine_t *engine) {
76
- CHECK (pd->create_primitive (primitive, engine, cache_blob ()));
76
+ std::pair<std::shared_ptr<impl::primitive_t >, cache_state_t > p;
77
+ CHECK (pd->create_primitive_nested (p, engine, cache_blob ()));
78
+
79
+ if (p.second == cache_state_t ::kernel_hit) {
80
+ creation_cached_state_ = cache_state_t ::nested_primitive_hit;
81
+ }
82
+ primitive = p.first ;
77
83
register_compute_block (new compute_block_t (primitive.get ()));
78
84
return status::success;
79
85
}
Original file line number Diff line number Diff line change @@ -155,18 +155,6 @@ struct gpu_primitive_t : public gpu::primitive_t {
155
155
return status::success;
156
156
}
157
157
158
- status_t create_nested_primitive (std::shared_ptr<primitive_t > &primitive,
159
- const std::shared_ptr<primitive_desc_t > &pd, engine_t *engine) {
160
- std::pair<std::shared_ptr<primitive_t >, cache_state_t > p;
161
- CHECK (pd->create_primitive_nested (p, engine, cache_blob ()));
162
- if (p.second == cache_state_t ::kernel_hit) {
163
- creation_cached_state_ = cache_state_t ::nested_primitive_hit;
164
- }
165
- primitive = p.first ;
166
- register_primitive (primitive.get ());
167
- return status::success;
168
- }
169
-
170
158
// TODO: use inheritance for exec_ctx_t to get rid of such places...
171
159
static status_t parallel_for (const gemm_exec_ctx_t &ctx,
172
160
const compute::nd_range_t &range, const compute::kernel_t &kernel,
You can’t perform that action at this time.
0 commit comments