
Commit 9442aa4

gyhintel authored and TaoLv committed
graph: backend: dnnl: enable genindex ref implementation for gpu
1 parent 6817bce commit 9442aa4

File tree

- src/gpu/intel/ocl/graph/gen_index.cl
- src/graph/backend/dnnl/kernels/gen_index.cpp
- src/graph/backend/dnnl/kernels/gen_index.hpp
- src/graph/backend/dnnl/op_executable.hpp
- src/graph/backend/dnnl/patterns/single_op_pattern.cpp

5 files changed: +229 −37 lines changed

src/gpu/intel/ocl/graph/gen_index.cl

+53
@@ -0,0 +1,53 @@
+/*******************************************************************************
+* Copyright 2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+__kernel void gen_index(__global int *dst, int axis) {
+    long id = get_global_id(0);
+    long result, offset = 0;
+    long idx;
+
+    idx = id % D0;
+    id = id / D0;
+    offset += idx * S0;
+    if (axis == 0) result = idx;
+
+    idx = id % D1;
+    id = id / D1;
+    offset += idx * S1;
+    if (axis == 1) result = idx;
+
+    idx = id % D2;
+    id = id / D2;
+    offset += idx * S2;
+    if (axis == 2) result = idx;
+
+    idx = id % D3;
+    id = id / D3;
+    offset += idx * S3;
+    if (axis == 3) result = idx;
+
+    idx = id % D4;
+    id = id / D4;
+    offset += idx * S4;
+    if (axis == 4) result = idx;
+
+    idx = id % D5;
+    id = id / D5;
+    offset += idx * S5;
+    if (axis == 5) result = idx;
+
+    dst[offset] = result;
+}
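How the kernel works: each work-item takes its flattened global id, peels off one coordinate per axis with successive modulo/divide steps (D0 … D5 are the destination dims, S0 … S5 the strides, both baked in as build-time macros), accumulates the strided destination offset, and stores the coordinate of the requested axis. A minimal host-side C++ sketch of the same arithmetic, with made-up dims/strides for illustration:

#include <cstdint>
#include <vector>

// Reference model of gen_index: for every element of a tensor with the
// given dims/strides, write that element's coordinate along `axis`.
void gen_index_ref(std::vector<int32_t> &dst, const std::vector<int64_t> &dims,
        const std::vector<int64_t> &strides, int axis) {
    int64_t nelems = 1;
    for (int64_t d : dims)
        nelems *= d;
    for (int64_t id = 0; id < nelems; ++id) {
        int64_t rem = id, offset = 0, result = 0;
        for (size_t d = 0; d < dims.size(); ++d) {
            int64_t idx = rem % dims[d]; // coordinate along axis d
            rem /= dims[d];
            offset += idx * strides[d]; // strided destination offset
            if (static_cast<int>(d) == axis) result = idx;
        }
        dst[offset] = static_cast<int32_t>(result);
    }
}

// Example: gen_index_ref(dst, {2, 3, 4}, {12, 4, 1}, /*axis=*/1) fills a
// dense 2x3x4 buffer with each element's index along dimension 1.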

src/graph/backend/dnnl/kernels/gen_index.cpp

+58-19
@@ -25,6 +25,11 @@
 #include "graph/backend/dnnl/passes/utils.hpp"

 #include "graph/backend/dnnl/op_executable.hpp"
+
+#define VCHECK_GENINDEX(cond, status, msg, ...) \
+    VCONDCHECK(graph, create, check, genindex_t, (cond), status, msg, \
+            ##__VA_ARGS__);
+
 namespace dnnl {
 namespace impl {
 namespace graph {
@@ -41,6 +46,16 @@ status_t genindex_t::compile_impl(const dnnl_partition_impl_t *part,
             part->get_fpmath_mode(), part->get_use_blocked_layout(), true);
     BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

+#if DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE
+    if (p_engine_.get_kind() == engine::kind::gpu) {
+        int ndims = inputs[0].ndims;
+        VCHECK_GENINDEX(ndims <= MAX_NDIMS, status::invalid_arguments,
+                "only tensors of 6 or fewer dimensions are supported for "
+                "genindex GPU, but got %dD",
+                ndims);
+    }
+#endif
+
     subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
         return this->memory_planner_.get_memory_info(val);
     });
@@ -84,7 +99,7 @@ status_t genindex_t::compile_impl(const dnnl_partition_impl_t *part,

 void genindex_t::prepare_args_set(const execution_args_set_t *res,
         const std::vector<tensor_t> &inputs,
-        const std::vector<tensor_t> &outputs, const scratchpad_t &scratchpad) {
+        const std::vector<tensor_t> &outputs) {
     // update the data of partition in/outputs args
     for (const auto &mem_idx : res->get_mems_use_external_inputs()) {
         mem_idx.first.set_data_handle(inputs[mem_idx.second].get_data_handle());
@@ -93,13 +108,6 @@ void genindex_t::prepare_args_set(const execution_args_set_t *res,
         mem_idx.first.set_data_handle(
                 outputs[mem_idx.second].get_data_handle());
     }
-
-    grantor_t var_grantor = memory_planner_.internal_temporary_grantor(
-            scratchpad.get_buffer());
-
-    for (auto &mem_offkey : res->get_mems_use_internal_temporary()) {
-        mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second));
-    }
 }

 status_t genindex_t::execute_impl(const stream_t *g_stream,
@@ -111,14 +119,7 @@ status_t genindex_t::execute_impl(const stream_t *g_stream,
     thread_local_cache_t<execution_args_set_t> res_cache;
     execution_args_set_t *res = res_cache.get_or_add(
             reinterpret_cast<size_t>(this), resource_ctor_);
-
-    temporary_scratchpad_t scratchpad(
-            memory_planner_.total_internal_temporary_size(), p_engine_,
-            *g_alloc_);
-    assertm(scratchpad.size()
-                    >= memory_planner_.total_internal_temporary_size(),
-            "no enough scratchpad memory");
-    prepare_args_set(res, inputs, outputs, scratchpad);
+    prepare_args_set(res, inputs, outputs);

     constant_cache_t::cached_t c_buffer;

@@ -135,7 +136,26 @@ status_t genindex_t::sycl_execute_impl(const stream_t *g_stream,
         const std::vector<tensor_t> &outputs,
         const std::vector<::sycl::event> &sycl_deps,
         ::sycl::event *sycl_event) {
-    if (p_engine_.get_kind() == engine::kind::gpu) return status::unimplemented;
+    if (p_engine_.get_kind() == engine::kind::gpu) {
+        auto deps = sycl_deps;
+        ::sycl::event returned_event;
+        dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
+
+        thread_local_cache_t<execution_args_set_t> res_cache;
+        execution_args_set_t *res = res_cache.get_or_add(
+                reinterpret_cast<size_t>(this), resource_ctor_);
+        prepare_args_set(res, inputs, outputs);
+        for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
+            if (subgraph_->is_constant_[i]) continue;
+            returned_event = subgraph_->execs_[i]->execute_sycl(
+                    p_stream, res->get_exec_args()[i], deps);
+            deps = {returned_event};
+        }
+
+        if (sycl_event) *sycl_event = returned_event;
+
+        return status::success;
+    }
     return execute_impl(g_stream, inputs, outputs);
 }
 #endif
@@ -144,8 +164,27 @@ status_t genindex_t::ocl_execute_impl(const stream_t *g_stream,
         const std::vector<tensor_t> &inputs,
         const std::vector<tensor_t> &outputs,
         const std::vector<cl_event> &ocl_deps, cl_event *ocl_event) {
-    // TODO: add support
-    return status::unimplemented;
+    auto deps = ocl_deps;
+    cl_event returned_event {};
+    dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
+
+    // each thread's own local resource
+    thread_local_cache_t<execution_args_set_t> res_cache;
+    execution_args_set_t *res = res_cache.get_or_add(
+            reinterpret_cast<size_t>(this), resource_ctor_);
+
+    prepare_args_set(res, inputs, outputs);
+
+    for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
+        if (subgraph_->is_constant_[i]) continue;
+        returned_event = subgraph_->execs_[i]->execute_ocl(
+                p_stream, res->get_exec_args()[i], deps);
+        deps = {returned_event};
+    }
+
+    if (ocl_event) *ocl_event = returned_event;
+
+    return status::success;
 }
 #endif
 } // namespace dnnl_impl
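Both new GPU paths follow the same pattern: resolve this thread's cached argument set, then execute the subgraph ops in order, threading each op's output event into the next op's dependency list so the chain serializes correctly even on an out-of-order queue. A generic sketch of that chaining idiom (Exec, Event, and submit() are stand-ins for illustration, not dnnl types):

#include <vector>

// Each submit() enqueues one op and returns its completion event.
template <typename Exec, typename Event>
Event run_chain(const std::vector<Exec> &execs, std::vector<Event> deps) {
    Event last {};
    for (const auto &e : execs) {
        last = e.submit(deps); // this op waits on `deps`
        deps = {last};         // the next op waits only on this op
    }
    return last; // exposed to the caller as the partition's output event
}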

src/graph/backend/dnnl/kernels/gen_index.hpp

+1-2
@@ -57,8 +57,7 @@ struct genindex_t : public kernel_base_t {
     }
     void prepare_args_set(const execution_args_set_t *res,
             const std::vector<tensor_t> &inputs,
-            const std::vector<tensor_t> &outputs,
-            const scratchpad_t &scratchpad);
+            const std::vector<tensor_t> &outputs);
     status_t compile_impl(const dnnl_partition_impl_t *part,
             const engine_t *g_engine,
             const std::vector<logical_tensor_t> &inputs,

src/graph/backend/dnnl/op_executable.hpp

+117-14
@@ -46,6 +46,20 @@
 #include "graph/backend/dnnl/fusion_info.hpp"
 #include "graph/backend/dnnl/internal_attrs.hpp"

+#if (DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE) \
+        && (DNNL_GPU_VENDOR == DNNL_VENDOR_INTEL)
+
+#include "gpu/intel/compute/compute_engine.hpp"
+#include "gpu/intel/compute/compute_stream.hpp"
+#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
+#include "gpu/intel/ocl/stream.hpp"
+#endif
+
+#ifdef DNNL_WITH_SYCL
+#include "gpu/intel/sycl/stream.hpp"
+#endif
+
+#endif
 namespace dnnl {
 namespace impl {
 namespace graph {
@@ -2467,19 +2481,17 @@ struct groupnorm_executable_t : public op_executable_t {
     dnnl::group_normalization_forward prim_;
 };

+#if DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE
+using namespace dnnl::impl::gpu::intel;
+#define MAX_NDIMS 6
+#endif
 struct genindex_executable_t : public op_executable_t {
     DECLARE_ARG_INDICES_GETTER;

     genindex_executable_t(std::shared_ptr<op_t> &op,
             const dnnl::engine &p_engine, fusion_info_mgr_t &mgr,
             pd_cache_t &pd_cache) {
-        if (p_engine.get_kind() == engine::kind::gpu) {
-            assertm(false,
-                    "genindex opexcutable is unimplemented "
-                    "under SYCL and OCL "
-                    "runtime!");
-            throw std::runtime_error("Unimplement");
-        }
+
         using ltw = logical_tensor_wrapper_t;
         const auto &input_lt = op->get_input_value(0)->get_logical_tensor();
         nelems_ = ltw(input_lt).nelems();
@@ -2490,6 +2502,26 @@ struct genindex_executable_t : public op_executable_t {
             output_dims_[i] = output_lt.dims[i];
             output_strides_[i] = output_lt.layout.strides[i];
         }
+#if DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE \
+        && DNNL_GPU_VENDOR == DNNL_VENDOR_INTEL
+        if (p_engine.get_kind() == engine::kind::gpu) {
+            compute::kernel_ctx_t kernel_ctx;
+            kernel_ctx.define_int("NDIMS", ndims_);
+            for (int d = 0; d < MAX_NDIMS; ++d) {
+                dim_t dim = (d < ndims_) ? output_dims_[d] : 1;
+                dim_t stride = (d < ndims_) ? output_strides_[d] : 0;
+                kernel_ctx.define_int(dnnl::impl::utils::format("D%d", d), dim);
+                kernel_ctx.define_int(
+                        dnnl::impl::utils::format("S%d", d), stride);
+            }
+            auto *compute_engine
+                    = dnnl::impl::utils::downcast<compute::compute_engine_t *>(
+                            p_engine.get());
+            std::vector<compute::kernel_t> kernels(1);
+            compute_engine->create_kernels(&kernels, {"gen_index"}, kernel_ctx);
+            kernel_ = kernels[0];
+        }
+#endif
     }
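To make the build-time specialization concrete: for a hypothetical dense 2x3x4 s32 output (strides {12, 4, 1}), the loop above pads the unused axes with dim 1 and stride 0, so the kernel would be compiled with options equivalent to:

// Hypothetical example, not taken from the commit:
// -DNDIMS=3
// -DD0=2 -DS0=12 -DD1=3 -DS1=4 -DD2=4 -DS2=1
// -DD3=1 -DS3=0 -DD4=1 -DS4=0 -DD5=1 -DS5=0
//
// Padding with D=1 makes `id % D` always 0 and leaves `id / D` unchanged,
// and S=0 keeps the unused axes from contributing to the destination offset.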
@@ -2498,26 +2530,97 @@ struct genindex_executable_t : public op_executable_t {
 #ifdef DNNL_WITH_SYCL
     ::sycl::event execute_sycl(const stream &stream,
             const std::unordered_map<int, memory> &args,
-            const std::vector<::sycl::event> &deps) const override {
-        execute(stream, args);
-        return {};
+            const std::vector<::sycl::event> &deps = {}) const override {
+#if (DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE) \
+        && (DNNL_GPU_VENDOR == DNNL_VENDOR_INTEL)
+        auto compute_stream
+                = dnnl::impl::utils::downcast<compute::compute_stream_t *>(
+                        stream.get());
+        compute::range_t gws = {static_cast<size_t>(nelems_)};
+        auto nd_range = compute::nd_range_t(gws);
+        compute::kernel_arg_list_t arg_list;
+        const auto &dst = *(args.at(DNNL_ARG_DST).get()->memory_storage());
+        arg_list.set(0, dst);
+        arg_list.set(1, axis_);
+        auto *sycl_stream
+                = dnnl::impl::utils::downcast<sycl::stream_t *>(compute_stream);
+        sycl_stream->before_exec_hook();
+        if (!deps.empty()) sycl_stream->sycl_ctx().set_deps(deps);
+
+        kernel_.parallel_for(*compute_stream, nd_range, arg_list,
+                sycl_stream->sycl_ctx().get_deps(),
+                sycl_stream->sycl_ctx().get_deps());
+        auto return_event = sycl_stream->get_output_event();
+
+        sycl_stream->after_exec_hook();
+        return return_event;
+#else
+        assertm(false,
+                "genindex op executable is only implemented for the Intel "
+                "vendor under the SYCL runtime");
+        throw std::runtime_error("Unimplemented");
+#endif
     }
 #endif

 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
     cl_event execute_ocl(const stream &stream,
             const std::unordered_map<int, memory> &args,
-            const std::vector<cl_event> &deps) const override {
+            const std::vector<cl_event> &deps = {}) const override {
+#if DNNL_GPU_VENDOR == DNNL_VENDOR_INTEL
+        auto compute_stream
+                = dnnl::impl::utils::downcast<compute::compute_stream_t *>(
+                        stream.get());
+
+        compute::range_t gws = {static_cast<size_t>(nelems_)};
+
+        auto nd_range = compute::nd_range_t(gws);
+        compute::kernel_arg_list_t arg_list;
+        const auto &dst = *(args.at(DNNL_ARG_DST).get()->memory_storage());
+        arg_list.set(0, dst);
+        arg_list.set(1, axis_);
+        auto *ocl_stream
+                = dnnl::impl::utils::downcast<gpu::intel::ocl::stream_t *>(
+                        compute_stream);
+
+        ocl_stream->before_exec_hook();
+
+        if (!deps.empty()) {
+            std::vector<xpu::ocl::wrapper_t<cl_event>> events(deps.size());
+            for (size_t i = 0; i < deps.size(); i++)
+                events[i] = xpu::ocl::wrapper_t<cl_event>(deps[i], true);
+            ocl_stream->ocl_ctx().set_deps(events);
+        }
+
+        kernel_.parallel_for(*compute_stream, nd_range, arg_list,
+                compute_stream->ctx().get_deps(),
+                compute_stream->ctx().get_deps());
+
+        cl_event return_event = nullptr;
+        if ((ocl_stream->flags() & stream_flags::in_order) == 0) {
+            auto last = ocl_stream->get_output_event();
+            return_event = last.release();
+        }
+
+        ocl_stream->after_exec_hook();
+        return return_event;
+#else
         assertm(false,
-                "genindex op excutable is unimplemented "
-                "under OCL runtime!");
-        return {};
+                "genindex op executable is only implemented for the Intel "
+                "vendor under the OCL runtime");
+        throw std::runtime_error("Unimplemented");
+#endif
     }
 #endif

 private:
     int axis_, nelems_, ndims_;
     dims_t output_dims_, output_strides_;
+
+#if (DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE) \
+        && (DNNL_GPU_VENDOR == DNNL_VENDOR_INTEL)
+    compute::kernel_t kernel_;
+#endif
 };

 } // namespace dnnl_impl
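For readers more familiar with raw OpenCL than with oneDNN's internal compute abstraction: the execute_ocl() body above amounts to a one-dimensional launch with one work-item per destination element. A rough plain-OpenCL equivalent, assuming `kernel`, `queue`, `dst_buf`, `axis`, and `nelems` already exist (sketch only, error handling omitted):

size_t gws = (size_t)nelems; // one work-item per tensor element
clSetKernelArg(kernel, 0, sizeof(cl_mem), &dst_buf);
clSetKernelArg(kernel, 1, sizeof(int), &axis);
cl_event ev;
clEnqueueNDRangeKernel(queue, kernel, /*work_dim=*/1, /*global_offset=*/NULL,
        &gws, /*local_size=*/NULL, 0, NULL, &ev);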

src/graph/backend/dnnl/patterns/single_op_pattern.cpp

-2
@@ -425,10 +425,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, reduce_pass)
             return std::make_shared<float_reduction>();
         });

-// GenIndex currently is CPU only
 DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gen_index_pass)
         .set_priority(DEFAULT_P)
-        .set_engine_kind(engine_kind::cpu)
         .set_kind(partition_kind_t::misc_post_ops)
         .set_attr<FCreatePattern>("FCreatePattern",
                 [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
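With the engine-kind restriction dropped, GenIndex single-op partitions can now be offered on GPU engines as well. A hedged sketch of how a user might exercise this through the public graph API (ids, shapes, and the axis value are arbitrary examples; exact API details may vary by oneDNN version):

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

// Build a one-op graph containing GenIndex on a GPU engine.
graph g(dnnl::engine::kind::gpu);
logical_tensor src {0, logical_tensor::data_type::f32, {2, 3, 4},
        logical_tensor::layout_type::strided};
logical_tensor dst {1, logical_tensor::data_type::s32, {2, 3, 4},
        logical_tensor::layout_type::strided};
op genindex(0, op::kind::GenIndex, {src}, {dst}, "genindex");
genindex.set_attr<int64_t>(op::attr::axis, 1); // emit indices along dim 1
g.add_op(genindex);
g.finalize();
auto parts = g.get_partitions(); // should now yield a supported GPU partition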
