Commit fcaa982

generic: sycl: Introduce spec constants for sycl matmul

Address comments. Aaand it's back! Addressed comments.

1 parent 63184fe commit fcaa982

7 files changed: +207 -42 lines changed


src/gpu/generic/sycl/matmul_kernels.hpp (+19 -7)
@@ -18,6 +18,7 @@
 #define GPU_GENERIC_SYCL_MATMUL_KERNELS_HPP
 
 #include "common/primitive_exec_types.hpp"
+#include "gpu/generic/sycl/specialization_constants.hpp"
 #include "gpu/generic/sycl/sycl_io_helper.hpp"
 #include "gpu/generic/sycl/sycl_math_utils.hpp"
 #include "gpu/generic/sycl/sycl_post_ops.hpp"
@@ -358,7 +359,7 @@ struct matmul_kernel_fwd_t {
     matmul_kernel_fwd_t(const sycl_matmul_conf_t &conf, ::sycl::handler &cgh,
             const exec_ctx_t &ctx)
         : conf_(conf)
-        , data_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
+        , src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
         , weights_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_WEIGHTS))
         , bias_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_BIAS))
         , dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
@@ -409,16 +410,23 @@ struct matmul_kernel_fwd_t {
             CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_ATTR_DROPOUT_PROBABILITY))
         , po_args_(cgh, ctx, conf_.post_ops) {}
 
-    void operator()(::sycl::nd_item<1> item) const {
+    void operator()(::sycl::nd_item<1> item, ::sycl::kernel_handler kh) const {
         using data_block_t = register_block<register_block_M, register_block_K>;
         using weights_block_t
                 = register_block<register_block_K, register_block_N>;
         using dst_block_t = register_block<register_block_M, register_block_N>;
 
-        memory_tensor_t data_mem(data_, conf_.data_md);
-        memory_tensor_t weights_mem(weights_, conf_.weights_md);
+        // Get the value of the spec constant;
+        const auto &md_t_spec_const_pod_val = kh.get_specialization_constant<
+                detail::matmul::md_t_spec_const_id>();
+        const auto &src_md = md_t_spec_const_pod_val.data_md_t;
+        const auto &weights_md = md_t_spec_const_pod_val.weights_md_t;
+        const auto &dst_md = md_t_spec_const_pod_val.dst_md_t;
+
+        memory_tensor_t data_mem(src_, src_md);
+        memory_tensor_t weights_mem(weights_, weights_md);
         memory_tensor_t bias_mem(bias_, conf_.bias_md);
-        memory_tensor_t dst_mem(dst_, conf_.dst_md);
+        memory_tensor_t dst_mem(dst_, dst_md);
         memory_plain_t data_scale_mem(data_scale_, data_scales_dt_);
         memory_plain_t weights_scale_mem(weights_scale_, weights_scales_dt_);
         memory_plain_t dst_scale_mem(dst_scale_, dst_scales_dt_);
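
Note: with this hunk the kernel reads its memory descriptors from a SYCL 2020 specialization constant instead of from fields of conf_. The value is bound before JIT compilation, so inside the kernel the layouts behave as compile-time constants rather than data loaded from a kernel argument. A minimal standalone sketch of the read side, with hypothetical names (pod_t, pod_id) that are not part of oneDNN:

#include <cstdint>
#include <sycl/sycl.hpp>

struct pod_t { int32_t ndims; int32_t dims[6]; };
constexpr sycl::specialization_id<pod_t> pod_id;

int main() {
    sycl::queue q;
    int32_t *out = sycl::malloc_shared<int32_t>(1, q);
    q.submit([&](sycl::handler &cgh) {
        // Bind the value before this submission is JIT-compiled.
        cgh.set_specialization_constant<pod_id>(pod_t {2, {4, 8, 1, 1, 1, 1}});
        cgh.single_task([=](sycl::kernel_handler kh) {
            // Read the bound value; the JIT can fold it to a constant.
            const pod_t p = kh.get_specialization_constant<pod_id>();
            out[0] = p.ndims * p.dims[0];
        });
    }).wait();
    sycl::free(out, q);
    return 0;
}

(The commit binds the constant on a kernel bundle at primitive creation rather than per submission; see create_matmul_kernel in ref_matmul.cpp below.)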
@@ -513,7 +521,11 @@ struct matmul_kernel_fwd_t {
             off_dst[matmul_dim_2] *= conf_.transpose_dst ? register_block_M
                                                          : register_block_N;
             int m = off_dst[conf_.transpose_dst ? matmul_dim_2 : matmul_dim_1];
-            int n = off_dst[conf_.transpose_dst ? matmul_dim_1 : matmul_dim_2];
+            // TODO: the following code is changed due to a correctness bug
+            // specific for PVC, needs further investigation and a better fix
+            // or explanation.
+            int n = off_dst[matmul_dim_2];
+            if (conf_.transpose_dst) { n = off_dst[matmul_dim_1]; }
 
             dims_t off_src, off_weights, off_bias;
             for (int i = max_supported_ndims - 1; i >= 0; i--) {
@@ -650,7 +662,7 @@ struct matmul_kernel_fwd_t {
 private:
     sycl_matmul_conf_t conf_;
 
-    xpu::sycl::in_memory_arg_t data_;
+    xpu::sycl::in_memory_arg_t src_;
     xpu::sycl::in_memory_arg_t weights_;
     xpu::sycl::in_memory_arg_t bias_;
     xpu::sycl::inout_memory_arg_t dst_;

src/gpu/generic/sycl/ref_matmul.cpp (+86 -27)
@@ -15,15 +15,22 @@
 *******************************************************************************/
 
 #include "gpu/generic/sycl/ref_matmul.hpp"
+#include "common/c_types_map.hpp"
 #include "gpu/generic/sycl/matmul_kernels.hpp"
+#include "gpu/generic/sycl/specialization_constants.hpp"
+#include "xpu/sycl/types.hpp"
+
+#define VCHECK_MATMUL(cond, msg, ...) \
+    VCONDCHECK(primitive, create, check, matmul, (cond), \
+            status::unimplemented, msg, ##__VA_ARGS__);
 
 namespace dnnl {
 namespace impl {
 namespace gpu {
 namespace generic {
 namespace sycl {
 
-void ref_matmul_t::pd_t::init_conf() {
+status_t ref_matmul_t::pd_t::init_conf() {
     conf_ = sycl_matmul_conf_t();
 
     conf_.do_scale_data
@@ -49,19 +56,61 @@ void ref_matmul_t::pd_t::init_conf() {
     memory_desc_wrapper weights_d = weights_md();
     memory_desc_wrapper dst_d = dst_md();
     memory_desc_wrapper bias_d = weights_md(1);
-    for (const auto &mdw : {src_d, weights_d, dst_d, bias_d}) {
-        if (mdw.has_runtime_dims()) {
-            any_runtime_params_ = true;
-            return;
-        }
-    }
-    init_rt_conf(conf_, src_d, weights_d, dst_d, bias_d);
+    VCHECK_MATMUL(!utils::one_of(true, src_d.has_runtime_dims(),
+                          weights_d.has_runtime_dims(),
+                          dst_d.has_runtime_dims(), bias_d.has_runtime_dims()),
+            VERBOSE_RUNTIMEDIM_UNSUPPORTED);
+
+    return init_rt_conf(conf_, data_md_t, dst_md_t, weights_md_t, src_d,
+            weights_d, dst_d, bias_d);
 }
 
-void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
+status_t ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
+        xpu::sycl::md_t_spec_const &data_md_t_,
+        xpu::sycl::md_t_spec_const &dst_md_t_,
+        xpu::sycl::md_t_spec_const &weights_md_t_,
         const memory_desc_wrapper src_d, const memory_desc_wrapper weights_d,
         const memory_desc_wrapper dst_d,
         const memory_desc_wrapper bias_d) const {
+
+    // Lambda because this function will not be used anywhere else
+    auto init_md_t_sc_from_md = [=](xpu::sycl::md_t_spec_const &md_t_sc,
+                                        const memory_desc_t *md) -> status_t {
+        constexpr int max_dims = 6;
+        using dim32_t = int32_t;
+
+        memory_desc_wrapper mdw(md);
+
+        VCHECK_MATMUL(mdw.format_kind() == format_kind::blocked,
+                VERBOSE_UNSUPPORTED_FORMAT_KIND);
+        VCHECK_MATMUL(
+                mdw.ndims() <= max_dims, VERBOSE_BAD_NDIMS, mdw, mdw.ndims());
+
+        const auto &blk = mdw.blocking_desc();
+
+        md_t_sc.data_type_ = mdw.data_type();
+#define CHECK_AND_ASSIGN(lhs, rhs) \
+    VCHECK_MATMUL((rhs) <= INT32_MAX, VERBOSE_BAD_PARAM, rhs); \
+    (lhs) = static_cast<dim32_t>(rhs)
+
+        CHECK_AND_ASSIGN(md_t_sc.ndims_, mdw.ndims());
+        CHECK_AND_ASSIGN(md_t_sc.offset0_, mdw.offset0());
+        CHECK_AND_ASSIGN(md_t_sc.inner_nblks_, blk.inner_nblks);
+
+        for (int d = 0; d < mdw.ndims(); d++) {
+            CHECK_AND_ASSIGN(md_t_sc.dims_[d], mdw.dims()[d]);
+            CHECK_AND_ASSIGN(md_t_sc.padded_dims_[d], mdw.padded_dims()[d]);
+            CHECK_AND_ASSIGN(
+                    md_t_sc.padded_offsets_[d], mdw.padded_offsets()[d]);
+            CHECK_AND_ASSIGN(md_t_sc.strides_[d], blk.strides[d]);
+            CHECK_AND_ASSIGN(md_t_sc.inner_blks_[d], blk.inner_blks[d]);
+            CHECK_AND_ASSIGN(md_t_sc.inner_idxs_[d], blk.inner_idxs[d]);
+        }
+#undef CHECK_AND_ASSIGN
+
+        return status::success;
+    };
+
     int matmul_dim_1 = ndims() - 2;
     int matmul_dim_2 = ndims() - 1;
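
Note: the CHECK_AND_ASSIGN macro exists because oneDNN dimensions and strides are 64-bit, while md_t_spec_const stores 32-bit fields to keep the spec-constant POD compact; each copy is range-checked before the narrowing cast, failing primitive creation instead of silently truncating. A standalone sketch of the same check (assign_dim32 is a hypothetical name):

#include <cstdint>

// Range-checked 64-bit -> 32-bit narrowing, mirroring CHECK_AND_ASSIGN:
// report failure rather than truncate a value that does not fit.
inline bool assign_dim32(int32_t &lhs, int64_t rhs) {
    if (rhs > INT32_MAX) return false; // value would not fit in 32 bits
    lhs = static_cast<int32_t>(rhs);
    return true;
}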

@@ -73,7 +122,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                 data_md_copy.dims[matmul_dim_2]);
         conf.transpose_data = true;
     }
-    conf.data_md = xpu::sycl::md_t(&data_md_copy);
+    init_md_t_sc_from_md(data_md_t_, &data_md_copy);
 
     memory_desc_t weights_md_copy = *weights_d.md_;
     auto &weights_strides = weights_md_copy.format_desc.blocking.strides;
@@ -83,7 +132,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                 weights_md_copy.dims[matmul_dim_2]);
         conf.transpose_weights = true;
     }
-    conf.weights_md = xpu::sycl::md_t(&weights_md_copy);
+    init_md_t_sc_from_md(weights_md_t_, &weights_md_copy);
 
     memory_desc_t dst_md_copy = *dst_d.md_;
     auto &dst_strides = dst_md_copy.format_desc.blocking.strides;
@@ -93,7 +142,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
         dst_md_copy.dims[matmul_dim_1], dst_md_copy.dims[matmul_dim_2]);
         conf.transpose_dst = true;
     }
-    conf.dst_md = xpu::sycl::md_t(&dst_md_copy);
+    init_md_t_sc_from_md(dst_md_t_, &dst_md_copy);
 
     if (with_bias()) {
         memory_desc_t bias_md_copy = *bias_d.md_;
@@ -109,8 +158,8 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
 
     dims_t dst_blocks;
     for (int i = 0; i < matmul_kernel_fwd_t::max_supported_ndims; i++) {
-        if (i < conf.dst_md.ndims()) {
-            dst_blocks[i] = conf.dst_md.dims()[i];
+        if (i < dst_md_t.ndims_) {
+            dst_blocks[i] = dst_md_t.dims_[i];
         } else {
             dst_blocks[i] = 1;
         }
@@ -133,34 +182,44 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
             = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims())
             | high_two_bits;
     conf.bias_mask = utils::get_dims_mask(dst_d.dims(), bias_d.dims(), ndims());
+
+    return status::success;
 }
 
 status_t ref_matmul_t::init(impl::engine_t *engine) {
     const auto kid = ::sycl::get_kernel_id<matmul_kernel_fwd_t>();
-    CHECK(create_kernel(engine, kid, &kernel_));
+    CHECK(create_matmul_kernel(engine, kid, &kernel_,
+            {pd()->data_md_t, pd()->dst_md_t, pd()->weights_md_t}));
+    return status::success;
+}
+
+status_t ref_matmul_t::create_matmul_kernel(impl::engine_t *engine,
+        ::sycl::kernel_id kid, kernel_t *kernel,
+        xpu::sycl::md_t_spec_const_pod pod) {
+
+    auto ctx = utils::downcast<const xpu::sycl::engine_impl_t *>(engine->impl())
+                       ->context();
+    auto input_bundle = ::sycl::get_kernel_bundle<::sycl::bundle_state::input>(
+            ctx, {kid});
+
+    input_bundle.template set_specialization_constant<
+            detail::matmul::md_t_spec_const_id>(pod);
+    try {
+        (*kernel) = kernel_t(::sycl::build(input_bundle));
+    } catch (const ::sycl::exception &e) { return status::runtime_error; }
     return status::success;
 }
 
 status_t ref_matmul_t::execute(const exec_ctx_t &ctx) const {
     if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
 
-    sycl_matmul_conf_t conf = pd()->conf_;
-    if (pd()->any_runtime_params_) {
-        const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
-        const auto weights_d
-                = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
-        const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
-        const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
-        pd()->init_rt_conf(conf, src_d, weights_d, dst_d, bias_d);
-    }
-
     parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
-        matmul_kernel_fwd_t matmul_kernel(conf, cgh, ctx);
+        matmul_kernel_fwd_t matmul_kernel(pd()->conf_, cgh, ctx);
 
         const int block_size = 32;
         const int wg_size = 32;
 
-        const int t_work = conf.wk_size;
+        const int t_work = pd()->conf_.wk_size;
         const int wg_work = wg_size * block_size;
         const int wg_cnt = utils::div_up(t_work, wg_work);
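
Note: create_matmul_kernel is the host half of the specialization-constant flow: fetch the kernel's bundle in input (pre-JIT) state, bind the POD value on the bundle, then ::sycl::build() it once so every later launch reuses the compiled binary. A self-contained sketch of that pattern outside oneDNN, with hypothetical names (pod_t, pod_id, spec_kernel):

#include <cstdint>
#include <sycl/sycl.hpp>

struct pod_t { int32_t ndims; };
constexpr sycl::specialization_id<pod_t> pod_id;
class spec_kernel; // forward-declared kernel name

int main() {
    sycl::queue q;
    auto ctx = q.get_context();

    // 1) Bundle containing just this kernel, in pre-JIT "input" state.
    auto kid = sycl::get_kernel_id<spec_kernel>();
    auto ib = sycl::get_kernel_bundle<sycl::bundle_state::input>(ctx, {kid});

    // 2) Bind the constant, then JIT-compile once (may throw sycl::exception,
    //    which the commit maps to status::runtime_error).
    ib.set_specialization_constant<pod_id>(pod_t {3});
    auto exec_bundle = sycl::build(ib);

    int32_t *out = sycl::malloc_shared<int32_t>(1, q);
    q.submit([&](sycl::handler &cgh) {
        cgh.use_kernel_bundle(exec_bundle); // launch from the prebuilt bundle
        cgh.single_task<spec_kernel>([=](sycl::kernel_handler kh) {
            out[0] = kh.get_specialization_constant<pod_id>().ndims;
        });
    }).wait();
    sycl::free(out, q);
    return 0;
}

Building the bundle eagerly in init() moves the JIT cost out of the execute() hot path, which is presumably why the build happens at primitive creation rather than at first launch.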

src/gpu/generic/sycl/ref_matmul.hpp (+15 -4)
@@ -17,6 +17,8 @@
 #ifndef GPU_GENERIC_SYCL_REF_MATMUL_HPP
 #define GPU_GENERIC_SYCL_REF_MATMUL_HPP
 
+#include "common/c_types_map.hpp"
+#include "gpu/generic/sycl/specialization_constants.hpp"
 #include "gpu/generic/sycl/sycl_gpu_primitive.hpp"
 #include "gpu/generic/sycl/sycl_io_helper.hpp"
 #include "gpu/generic/sycl/sycl_post_ops.hpp"
@@ -64,21 +66,28 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
                 && md_dims_in_range(weights_md());
         if (!ok) return status::unimplemented;
 
-        init_conf();
-        return status::success;
+        return init_conf();
     }
 
     sycl_matmul_conf_t conf_;
+
+    xpu::sycl::md_t_spec_const data_md_t;
+    xpu::sycl::md_t_spec_const dst_md_t;
+    xpu::sycl::md_t_spec_const weights_md_t;
+
     bool any_runtime_params_ = false;
 
-    void init_rt_conf(sycl_matmul_conf_t &conf,
+    status_t init_rt_conf(sycl_matmul_conf_t &conf,
+            xpu::sycl::md_t_spec_const &data_md_t_,
+            xpu::sycl::md_t_spec_const &dst_md_t_,
+            xpu::sycl::md_t_spec_const &weights_md_t_,
             const memory_desc_wrapper src_d,
             const memory_desc_wrapper weights_d,
             const memory_desc_wrapper dst_d,
             const memory_desc_wrapper bias_d) const;
 
 private:
-    void init_conf();
+    status_t init_conf();
 
     status_t set_default_params() {
         if (src_md_.format_kind == format_kind::any) {
@@ -153,6 +162,8 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
     status_t execute(const exec_ctx_t &ctx) const override;
 
 private:
+    status_t create_matmul_kernel(impl::engine_t *engine, ::sycl::kernel_id kid,
+            kernel_t *kernel, xpu::sycl::md_t_spec_const_pod pod);
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    kernel_t kernel_;
 };
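
Note: md_t_spec_const and md_t_spec_const_pod come from xpu/sycl/types.hpp and are not shown in this diff. A specialization-constant payload has to be a fixed-size, trivially copyable type, so a plausible shape, consistent with the fields assigned in init_md_t_sc_from_md (hypothetical sketch, not the actual definition), is:

#include <cstdint>

// Hypothetical layout only; the real type lives in xpu/sycl/types.hpp.
// Fixed-size arrays and 32-bit scalars keep the POD trivially copyable
// and small enough to embed in the kernel binary at JIT time.
struct md_t_spec_const_sketch {
    int32_t data_type_;
    int32_t ndims_, offset0_, inner_nblks_;
    int32_t dims_[6], padded_dims_[6], padded_offsets_[6];
    int32_t strides_[6], inner_blks_[6], inner_idxs_[6];
};

struct md_t_spec_const_pod_sketch {
    md_t_spec_const_sketch data_md_t, dst_md_t, weights_md_t;
};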
src/gpu/generic/sycl/specialization_constants.hpp (new file, +34 -0)

@@ -0,0 +1,34 @@
+/*******************************************************************************
+* Copyright 2024 Intel Corporation
+* Copyright 2024 Codeplay Software
+
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GPU_GENERIC_SYCL_SPECIALIZATION_CONSTANTS_HPP
+#define GPU_GENERIC_SYCL_SPECIALIZATION_CONSTANTS_HPP
+
+#include <sycl/sycl.hpp>
+
+#include "xpu/sycl/types.hpp"
+
+namespace dnnl::impl::gpu::generic::sycl {
+namespace detail {
+namespace matmul {
+static constexpr ::sycl::specialization_id<xpu::sycl::md_t_spec_const_pod>
+        md_t_spec_const_id;
+}
+} // namespace detail
+} // namespace dnnl::impl::gpu::generic::sycl
+
+#endif
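
Note: the specialization_id must be the same entity on both sides of the JIT boundary: ref_matmul.cpp names it in set_specialization_constant<...>() and the kernel in matmul_kernels.hpp names it in kh.get_specialization_constant<...>(), which is why the id lives in this shared header rather than in either source file. A minimal analogue with hypothetical names:

// shared_spec_ids.hpp -- included by both host and kernel code so that
// set_specialization_constant<pod_id>() and get_specialization_constant<pod_id>()
// name the same ::sycl::specialization_id object.
#ifndef SHARED_SPEC_IDS_HPP
#define SHARED_SPEC_IDS_HPP

#include <cstdint>
#include <sycl/sycl.hpp>

struct pod_t { int32_t ndims; };

namespace detail {
static constexpr ::sycl::specialization_id<pod_t> pod_id;
}

#endif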

src/gpu/generic/sycl/sycl_primitive_conf.hpp (-3)
@@ -112,9 +112,6 @@ struct sycl_eltwise_conf_t {
 };
 
 struct sycl_matmul_conf_t {
-    xpu::sycl::md_t data_md;
-    xpu::sycl::md_t dst_md;
-    xpu::sycl::md_t weights_md;
     xpu::sycl::md_t bias_md;
     alg_kind_t alg_kind;
     bool transpose_data; //TODO can we remove?
