Skip to content

Commit b80c233

Browse files
committed
gpu: update nested_hit to gpu refactor
1 parent c0644a3 commit b80c233

File tree

5 files changed

+13
-16
lines changed

5 files changed

+13
-16
lines changed

src/common/concat_pd.hpp

+2 -1
Original file line number | Diff line number | Diff line change
@@ -280,7 +280,8 @@ struct concat_pd_t : public primitive_desc_t {
280280
return safe_ptr_assign(*concat_pd, _pd.release()); \
281281
} \
282282
status_t create_primitive( \
283-
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> &primitive, \
283+
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> \
284+
&primitive, \
284285
dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \
285286
const override { \
286287
return primitive_t::create_primitive_common<__VA_ARGS__, pd_t>( \

src/common/primitive_desc.hpp

+2 -1
Original file line number | Diff line number | Diff line change
@@ -487,7 +487,8 @@ struct primitive_desc_t : public c_compatible {
487487
return new_pd.release(); \
488488
} \
489489
status_t create_primitive( \
490-
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> &primitive, \
490+
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> \
491+
&primitive, \
491492
dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \
492493
const override { \
493494
return primitive_t::create_primitive_common<impl_type, pd_t>( \

src/common/sum_pd.hpp

+2 -1
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,8 @@ struct sum_pd_t : public primitive_desc_t {
208208
return safe_ptr_assign(*sum_pd, _pd.release()); \
209209
} \
210210
status_t create_primitive( \
211-
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> &primitive, \
211+
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> \
212+
&primitive, \
212213
dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \
213214
const override { \
214215
return primitive_t::create_primitive_common<__VA_ARGS__, pd_t>( \

src/gpu/gpu_primitive.hpp

+7 -1
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,13 @@ struct primitive_t : public impl::primitive_t {
7373
std::shared_ptr<impl::primitive_t> &primitive,
7474
const std::shared_ptr<primitive_desc_t> &pd,
7575
impl::engine_t *engine) {
76-
CHECK(pd->create_primitive(primitive, engine, cache_blob()));
76+
std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> p;
77+
CHECK(pd->create_primitive_nested(p, engine, cache_blob()));
78+
79+
if (p.second == cache_state_t::kernel_hit) {
80+
creation_cached_state_ = cache_state_t::nested_primitive_hit;
81+
}
82+
primitive = p.first;
7783
register_compute_block(new compute_block_t(primitive.get()));
7884
return status::success;
7985
}

src/gpu/intel/gpu_primitive.hpp

-12
Original file line number | Diff line number | Diff line change
@@ -155,18 +155,6 @@ struct gpu_primitive_t : public gpu::primitive_t {
155155
return status::success;
156156
}
157157

158-
status_t create_nested_primitive(std::shared_ptr<primitive_t> &primitive,
159-
const std::shared_ptr<primitive_desc_t> &pd, engine_t *engine) {
160-
std::pair<std::shared_ptr<primitive_t>, cache_state_t> p;
161-
CHECK(pd->create_primitive_nested(p, engine, cache_blob()));
162-
if (p.second == cache_state_t::kernel_hit) {
163-
creation_cached_state_ = cache_state_t::nested_primitive_hit;
164-
}
165-
primitive = p.first;
166-
register_primitive(primitive.get());
167-
return status::success;
168-
}
169-
170158
// TODO: use inheritance for exec_ctx_t to get rid of such places...
171159
static status_t parallel_for(const gemm_exec_ctx_t &ctx,
172160
const compute::nd_range_t &range, const compute::kernel_t &kernel,

0 commit comments

Comments
 (0)