diff --git a/src/gpu/generic/sycl/matmul_kernels.hpp b/src/gpu/generic/sycl/matmul_kernels.hpp
index 1e415b48b15..3b98e24cee0 100644
--- a/src/gpu/generic/sycl/matmul_kernels.hpp
+++ b/src/gpu/generic/sycl/matmul_kernels.hpp
@@ -18,6 +18,7 @@
 #define GPU_GENERIC_SYCL_MATMUL_KERNELS_HPP
 
 #include "common/primitive_exec_types.hpp"
+#include "gpu/generic/sycl/specialization_constants.hpp"
 #include "gpu/generic/sycl/sycl_io_helper.hpp"
 #include "gpu/generic/sycl/sycl_math_utils.hpp"
 #include "gpu/generic/sycl/sycl_post_ops.hpp"
@@ -358,7 +359,7 @@ struct matmul_kernel_fwd_t {
     matmul_kernel_fwd_t(const sycl_matmul_conf_t &conf, ::sycl::handler &cgh,
             const exec_ctx_t &ctx)
         : conf_(conf)
-        , data_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
+        , src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
         , weights_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_WEIGHTS))
         , bias_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_BIAS))
         , dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
@@ -409,16 +410,23 @@ struct matmul_kernel_fwd_t {
                   CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_ATTR_DROPOUT_PROBABILITY))
         , po_args_(cgh, ctx, conf_.post_ops) {}
 
-    void operator()(::sycl::nd_item<1> item) const {
+    void operator()(::sycl::nd_item<1> item, ::sycl::kernel_handler kh) const {
         using data_block_t
                 = register_block<register_block_M, register_block_K>;
         using weights_block_t
                 = register_block<register_block_K, register_block_N>;
         using dst_block_t = register_block<register_block_M, register_block_N>;
 
-        memory_tensor_t data_mem(data_, conf_.data_md);
-        memory_tensor_t weights_mem(weights_, conf_.weights_md);
+        // Get the value of the spec constant.
+        const auto &md_t_spec_const_pod_val = kh.get_specialization_constant<
+                detail::matmul::md_t_spec_const_id>();
+        const auto &src_md = md_t_spec_const_pod_val.data_md_t;
+        const auto &weights_md = md_t_spec_const_pod_val.weights_md_t;
+        const auto &dst_md = md_t_spec_const_pod_val.dst_md_t;
+
+        memory_tensor_t data_mem(src_, src_md);
+        memory_tensor_t weights_mem(weights_, weights_md);
         memory_tensor_t bias_mem(bias_, conf_.bias_md);
-        memory_tensor_t dst_mem(dst_, conf_.dst_md);
+        memory_tensor_t dst_mem(dst_, dst_md);
         memory_plain_t data_scale_mem(data_scale_, data_scales_dt_);
         memory_plain_t weights_scale_mem(weights_scale_, weights_scales_dt_);
         memory_plain_t dst_scale_mem(dst_scale_, dst_scales_dt_);
@@ -513,7 +521,11 @@ struct matmul_kernel_fwd_t {
         off_dst[matmul_dim_2]
                 *= conf_.transpose_dst ? register_block_M : register_block_N;
         int m = off_dst[conf_.transpose_dst ? matmul_dim_2 : matmul_dim_1];
-        int n = off_dst[conf_.transpose_dst ? matmul_dim_1 : matmul_dim_2];
+        // TODO: the following code is changed due to a correctness bug
+        // specific to PVC; it needs further investigation and a better fix
+        // or explanation.
+        int n = off_dst[matmul_dim_2];
+        if (conf_.transpose_dst) { n = off_dst[matmul_dim_1]; }
 
         dims_t off_src, off_weights, off_bias;
         for (int i = max_supported_ndims - 1; i >= 0; i--) {
@@ -650,7 +662,7 @@ struct matmul_kernel_fwd_t {
 
 private:
     sycl_matmul_conf_t conf_;
-    xpu::sycl::in_memory_arg_t data_;
+    xpu::sycl::in_memory_arg_t src_;
     xpu::sycl::in_memory_arg_t weights_;
     xpu::sycl::in_memory_arg_t bias_;
     xpu::sycl::inout_memory_arg_t dst_;
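For context on the mechanism the kernel above now relies on: in SYCL 2020, a kernel reads specialization constants through an extra `::sycl::kernel_handler` parameter, which is exactly why `operator()` gained one. The following standalone sketch (not part of the patch; `params_t` and `my_spec_id` are hypothetical stand-ins for `md_t_spec_const_pod` and `md_t_spec_const_id`) shows the minimal pattern:

```cpp
#include <cstdint>
#include <sycl/sycl.hpp>

// Illustrative trivially copyable payload, standing in for md_t_spec_const_pod.
struct params_t {
    int32_t ndims;
    int32_t dims[6];
};

// Spec-constant id, analogous to detail::matmul::md_t_spec_const_id.
static constexpr ::sycl::specialization_id<params_t> my_spec_id;

int main() {
    ::sycl::queue q;
    q.submit([&](::sycl::handler &cgh) {
        // The value is baked in when the kernel is built, not at each launch.
        cgh.set_specialization_constant<my_spec_id>(params_t {2, {16, 16}});
        cgh.parallel_for(::sycl::range<1> {16},
                [=](::sycl::id<1> i, ::sycl::kernel_handler kh) {
                    // Reads back the host-set value; with an ahead-of-built
                    // bundle the compiler can treat it as a constant.
                    const params_t p
                            = kh.get_specialization_constant<my_spec_id>();
                    (void)p; // use p.dims instead of a kernel argument
                    (void)i;
                });
    });
    q.wait();
    return 0;
}
```

Passing the three memory descriptors this way removes them from the kernel-argument block, which is the parameter-size reduction the patch is after.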
diff --git a/src/gpu/generic/sycl/ref_matmul.cpp b/src/gpu/generic/sycl/ref_matmul.cpp
index d79f80b3cf0..89bc2a78cd9 100644
--- a/src/gpu/generic/sycl/ref_matmul.cpp
+++ b/src/gpu/generic/sycl/ref_matmul.cpp
@@ -15,7 +15,14 @@
 *******************************************************************************/
 
 #include "gpu/generic/sycl/ref_matmul.hpp"
+#include "common/c_types_map.hpp"
 #include "gpu/generic/sycl/matmul_kernels.hpp"
+#include "gpu/generic/sycl/specialization_constants.hpp"
+#include "xpu/sycl/types.hpp"
+
+#define VCHECK_MATMUL(cond, msg, ...) \
+    VCONDCHECK(primitive, create, check, matmul, (cond), \
+            status::unimplemented, msg, ##__VA_ARGS__);
 
 namespace dnnl {
 namespace impl {
@@ -23,7 +30,7 @@ namespace gpu {
 namespace generic {
 namespace sycl {
 
-void ref_matmul_t::pd_t::init_conf() {
+status_t ref_matmul_t::pd_t::init_conf() {
     conf_ = sycl_matmul_conf_t();
 
     conf_.do_scale_data
@@ -49,19 +56,61 @@ void ref_matmul_t::pd_t::init_conf() {
     memory_desc_wrapper weights_d = weights_md();
     memory_desc_wrapper dst_d = dst_md();
     memory_desc_wrapper bias_d = weights_md(1);
-    for (const auto &mdw : {src_d, weights_d, dst_d, bias_d}) {
-        if (mdw.has_runtime_dims()) {
-            any_runtime_params_ = true;
-            return;
-        }
-    }
-    init_rt_conf(conf_, src_d, weights_d, dst_d, bias_d);
+    VCHECK_MATMUL(!utils::one_of(true, src_d.has_runtime_dims(),
+                          weights_d.has_runtime_dims(),
+                          dst_d.has_runtime_dims(), bias_d.has_runtime_dims()),
+            VERBOSE_RUNTIMEDIM_UNSUPPORTED);
+
+    return init_rt_conf(conf_, data_md_t, dst_md_t, weights_md_t, src_d,
+            weights_d, dst_d, bias_d);
 }
 
-void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
+status_t ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
+        xpu::sycl::md_t_spec_const &data_md_t_,
+        xpu::sycl::md_t_spec_const &dst_md_t_,
+        xpu::sycl::md_t_spec_const &weights_md_t_,
         const memory_desc_wrapper src_d, const memory_desc_wrapper weights_d,
         const memory_desc_wrapper dst_d, const memory_desc_wrapper bias_d) const {
+
+    // Implemented as a lambda because this helper is not used anywhere else.
+    auto init_md_t_sc_from_md = [=](xpu::sycl::md_t_spec_const &md_t_sc,
+                                        const memory_desc_t *md) -> status_t {
+        constexpr int max_dims = 6;
+        using dim32_t = int32_t;
+
+        memory_desc_wrapper mdw(md);
+
+        VCHECK_MATMUL(mdw.format_kind() == format_kind::blocked,
+                VERBOSE_UNSUPPORTED_FORMAT_KIND);
+        VCHECK_MATMUL(
+                mdw.ndims() <= max_dims, VERBOSE_BAD_NDIMS, mdw, mdw.ndims());
+
+        const auto &blk = mdw.blocking_desc();
+
+        md_t_sc.data_type_ = mdw.data_type();
+#define CHECK_AND_ASSIGN(lhs, rhs) \
+    VCHECK_MATMUL((rhs) <= INT32_MAX, VERBOSE_BAD_PARAM, rhs); \
+    (lhs) = static_cast<dim32_t>(rhs)
+
+        CHECK_AND_ASSIGN(md_t_sc.ndims_, mdw.ndims());
+        CHECK_AND_ASSIGN(md_t_sc.offset0_, mdw.offset0());
+        CHECK_AND_ASSIGN(md_t_sc.inner_nblks_, blk.inner_nblks);
+
+        for (int d = 0; d < mdw.ndims(); d++) {
+            CHECK_AND_ASSIGN(md_t_sc.dims_[d], mdw.dims()[d]);
+            CHECK_AND_ASSIGN(md_t_sc.padded_dims_[d], mdw.padded_dims()[d]);
+            CHECK_AND_ASSIGN(
+                    md_t_sc.padded_offsets_[d], mdw.padded_offsets()[d]);
+            CHECK_AND_ASSIGN(md_t_sc.strides_[d], blk.strides[d]);
+            CHECK_AND_ASSIGN(md_t_sc.inner_blks_[d], blk.inner_blks[d]);
+            CHECK_AND_ASSIGN(md_t_sc.inner_idxs_[d], blk.inner_idxs[d]);
+        }
+#undef CHECK_AND_ASSIGN
+
+        return status::success;
+    };
+
     int matmul_dim_1 = ndims() - 2;
     int matmul_dim_2 = ndims() - 1;
@@ -73,7 +122,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                 data_md_copy.dims[matmul_dim_2]);
         conf.transpose_data = true;
     }
-    conf.data_md = xpu::sycl::md_t(&data_md_copy);
+    init_md_t_sc_from_md(data_md_t_, &data_md_copy);
 
     memory_desc_t weights_md_copy = *weights_d.md_;
     auto &weights_strides = weights_md_copy.format_desc.blocking.strides;
@@ -83,7 +132,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                 weights_md_copy.dims[matmul_dim_2]);
         conf.transpose_weights = true;
     }
-    conf.weights_md = xpu::sycl::md_t(&weights_md_copy);
+    init_md_t_sc_from_md(weights_md_t_, &weights_md_copy);
 
     memory_desc_t dst_md_copy = *dst_d.md_;
     auto &dst_strides = dst_md_copy.format_desc.blocking.strides;
@@ -93,7 +142,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                 dst_md_copy.dims[matmul_dim_1], dst_md_copy.dims[matmul_dim_2]);
         conf.transpose_dst = true;
     }
-    conf.dst_md = xpu::sycl::md_t(&dst_md_copy);
+    init_md_t_sc_from_md(dst_md_t_, &dst_md_copy);
 
     if (with_bias()) {
         memory_desc_t bias_md_copy = *bias_d.md_;
@@ -109,8 +158,8 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
 
     dims_t dst_blocks;
     for (int i = 0; i < matmul_kernel_fwd_t::max_supported_ndims; i++) {
-        if (i < conf.dst_md.ndims()) {
-            dst_blocks[i] = conf.dst_md.dims()[i];
+        if (i < dst_md_t.ndims_) {
+            dst_blocks[i] = dst_md_t.dims_[i];
         } else {
             dst_blocks[i] = 1;
         }
@@ -133,34 +182,44 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
             = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims())
             | high_two_bits;
     conf.bias_mask = utils::get_dims_mask(dst_d.dims(), bias_d.dims(), ndims());
+
+    return status::success;
 }
 
 status_t ref_matmul_t::init(impl::engine_t *engine) {
     const auto kid = ::sycl::get_kernel_id<matmul_kernel_fwd_t>();
-    CHECK(create_kernel(engine, kid, &kernel_));
+    CHECK(create_matmul_kernel(engine, kid, &kernel_,
+            {pd()->data_md_t, pd()->dst_md_t, pd()->weights_md_t}));
+    return status::success;
+}
+
+status_t ref_matmul_t::create_matmul_kernel(impl::engine_t *engine,
+        ::sycl::kernel_id kid, kernel_t *kernel,
+        xpu::sycl::md_t_spec_const_pod pod) {
+
+    auto ctx = utils::downcast<const xpu::sycl::engine_impl_t *>(engine->impl())
+                       ->context();
+    auto input_bundle = ::sycl::get_kernel_bundle<::sycl::bundle_state::input>(
+            ctx, {kid});
+
+    input_bundle.template set_specialization_constant<
+            detail::matmul::md_t_spec_const_id>(pod);
+    try {
+        (*kernel) = kernel_t(::sycl::build(input_bundle));
+    } catch (const ::sycl::exception &e) { return status::runtime_error; }
     return status::success;
 }
 
 status_t ref_matmul_t::execute(const exec_ctx_t &ctx) const {
     if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
 
-    sycl_matmul_conf_t conf = pd()->conf_;
-    if (pd()->any_runtime_params_) {
-        const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
-        const auto weights_d
-                = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
-        const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
-        const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
-        pd()->init_rt_conf(conf, src_d, weights_d, dst_d, bias_d);
-    }
-
     parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
-        matmul_kernel_fwd_t matmul_kernel(conf, cgh, ctx);
+        matmul_kernel_fwd_t matmul_kernel(pd()->conf_, cgh, ctx);
 
         const int block_size = 32;
         const int wg_size = 32;
 
-        const int t_work = conf.wk_size;
+        const int t_work = pd()->conf_.wk_size;
         const int wg_work = wg_size * block_size;
         const int wg_cnt = utils::div_up(t_work, wg_work);
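The host side in `create_matmul_kernel()` follows the standard SYCL 2020 kernel-bundle flow: fetch the kernel's input-state bundle, bake the constant into it, and JIT it once with `::sycl::build()`. A minimal sketch of that flow, with hypothetical `my_kernel`/`params_id` names rather than oneDNN code:

```cpp
#include <cstdint>
#include <sycl/sycl.hpp>

struct params_t {
    int32_t n;
};
static constexpr ::sycl::specialization_id<params_t> params_id;

class my_kernel; // forward-declared kernel name

int main() {
    ::sycl::queue q;
    auto ctx = q.get_context();

    auto kid = ::sycl::get_kernel_id<my_kernel>();
    auto input_bundle
            = ::sycl::get_kernel_bundle<::sycl::bundle_state::input>(
                    ctx, {kid});
    // The value becomes fixed for every kernel built from this bundle.
    input_bundle.set_specialization_constant<params_id>(params_t {42});
    auto exec_bundle = ::sycl::build(input_bundle);

    q.submit([&](::sycl::handler &cgh) {
        // Launch against the pre-built bundle so no rebuild happens here.
        cgh.use_kernel_bundle(exec_bundle);
        cgh.single_task<my_kernel>([=](::sycl::kernel_handler kh) {
            const params_t p = kh.get_specialization_constant<params_id>();
            (void)p;
        });
    });
    q.wait();
    return 0;
}
```

Building the bundle in `init()` is what makes the runtime-dims rejection above necessary: the descriptor values must be known at primitive-creation time so they can be baked into the kernel.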
diff --git a/src/gpu/generic/sycl/ref_matmul.hpp b/src/gpu/generic/sycl/ref_matmul.hpp
index 7b6f3bd6806..652eaaf1f35 100644
--- a/src/gpu/generic/sycl/ref_matmul.hpp
+++ b/src/gpu/generic/sycl/ref_matmul.hpp
@@ -17,6 +17,8 @@
 #ifndef GPU_GENERIC_SYCL_REF_MATMUL_HPP
 #define GPU_GENERIC_SYCL_REF_MATMUL_HPP
 
+#include "common/c_types_map.hpp"
+#include "gpu/generic/sycl/specialization_constants.hpp"
 #include "gpu/generic/sycl/sycl_gpu_primitive.hpp"
 #include "gpu/generic/sycl/sycl_io_helper.hpp"
 #include "gpu/generic/sycl/sycl_post_ops.hpp"
@@ -64,21 +66,28 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
                     && md_dims_in_range(weights_md());
             if (!ok) return status::unimplemented;
 
-            init_conf();
-            return status::success;
+            return init_conf();
         }
 
         sycl_matmul_conf_t conf_;
+
+        xpu::sycl::md_t_spec_const data_md_t;
+        xpu::sycl::md_t_spec_const dst_md_t;
+        xpu::sycl::md_t_spec_const weights_md_t;
+
         bool any_runtime_params_ = false;
 
-        void init_rt_conf(sycl_matmul_conf_t &conf,
+        status_t init_rt_conf(sycl_matmul_conf_t &conf,
+                xpu::sycl::md_t_spec_const &data_md_t_,
+                xpu::sycl::md_t_spec_const &dst_md_t_,
+                xpu::sycl::md_t_spec_const &weights_md_t_,
                 const memory_desc_wrapper src_d,
                 const memory_desc_wrapper weights_d,
                 const memory_desc_wrapper dst_d,
                 const memory_desc_wrapper bias_d) const;
 
     private:
-        void init_conf();
+        status_t init_conf();
 
         status_t set_default_params() {
             if (src_md_.format_kind == format_kind::any) {
@@ -153,6 +162,8 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
     status_t execute(const exec_ctx_t &ctx) const override;
 
 private:
+    status_t create_matmul_kernel(impl::engine_t *engine, ::sycl::kernel_id kid,
+            kernel_t *kernel, xpu::sycl::md_t_spec_const_pod pod);
     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
     kernel_t kernel_;
 };
diff --git a/src/gpu/generic/sycl/specialization_constants.hpp b/src/gpu/generic/sycl/specialization_constants.hpp
new file mode 100644
index 00000000000..aa95ee8967b
--- /dev/null
+++ b/src/gpu/generic/sycl/specialization_constants.hpp
@@ -0,0 +1,34 @@
+/*******************************************************************************
+* Copyright 2024 Intel Corporation
+* Copyright 2024 Codeplay Software
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GPU_GENERIC_SYCL_SPECIALIZATION_CONSTANTS_HPP
+#define GPU_GENERIC_SYCL_SPECIALIZATION_CONSTANTS_HPP
+
+#include <sycl/sycl.hpp>
+
+#include "xpu/sycl/types.hpp"
+
+namespace dnnl::impl::gpu::generic::sycl {
+namespace detail {
+namespace matmul {
+static constexpr ::sycl::specialization_id<xpu::sycl::md_t_spec_const_pod>
+        md_t_spec_const_id;
+}
+} // namespace detail
+} // namespace dnnl::impl::gpu::generic::sycl
+
+#endif
diff --git a/src/gpu/generic/sycl/sycl_primitive_conf.hpp b/src/gpu/generic/sycl/sycl_primitive_conf.hpp
index ec4e812cebb..cfcd23e9cdc 100644
--- a/src/gpu/generic/sycl/sycl_primitive_conf.hpp
+++ b/src/gpu/generic/sycl/sycl_primitive_conf.hpp
@@ -112,9 +112,6 @@ struct sycl_eltwise_conf_t {
 };
 
 struct sycl_matmul_conf_t {
-    xpu::sycl::md_t data_md;
-    xpu::sycl::md_t dst_md;
-    xpu::sycl::md_t weights_md;
     xpu::sycl::md_t bias_md;
     alg_kind_t alg_kind;
     bool transpose_data; //TODO can we remove?
diff --git a/src/xpu/sycl/types.hpp b/src/xpu/sycl/types.hpp
index 78b518aee68..ad36e367969 100644
--- a/src/xpu/sycl/types.hpp
+++ b/src/xpu/sycl/types.hpp
@@ -93,6 +93,37 @@ using in_memory_arg_t = memory_arg_t<::sycl::access::mode::read>;
 using out_memory_arg_t = memory_arg_t<::sycl::access::mode::write>;
 using inout_memory_arg_t = memory_arg_t<::sycl::access::mode::read_write>;
 
+// TODO: This is a work-around to reduce the size of the kernel parameters
+// passed to the matmul kernel. It is to be removed when we shift to sycl-RTC.
+struct md_t_spec_const {
+    static constexpr int max_dims = 6;
+
+    using dim32_t = int32_t;
+    using dims32_t = dim32_t[max_dims];
+
+    // The ordering of members is important during initialization.
+    // This struct cannot have a non-trivial constructor or any non-trivial
+    // member types.
+    data_type_t data_type_;
+
+    dim32_t ndims_;
+
+    dims32_t dims_;
+    dims32_t padded_dims_;
+    dims32_t padded_offsets_;
+    dim32_t offset0_;
+
+    dims32_t strides_;
+    dim32_t inner_nblks_;
+    dims32_t inner_blks_;
+    dims32_t inner_idxs_;
+};
+
+struct md_t_spec_const_pod {
+    struct md_t_spec_const data_md_t;
+    struct md_t_spec_const dst_md_t;
+    struct md_t_spec_const weights_md_t;
+};
+
 // TODO: this class mimics memory_desc_t and makes sure it can be passed
 // to SYCL kernels as a kernel argument. SYCL puts restrictions on kernel
 // arguments, e.g. those cannot contain unions.
@@ -146,6 +177,21 @@ struct md_t {
 #undef CHECK_AND_ASSIGN
     }
 
+    md_t(const md_t_spec_const &other)
+        : data_type_(other.data_type_)
+        , ndims_(other.ndims_)
+        , offset0_(other.offset0_)
+        , inner_nblks_(other.inner_nblks_) {
+        for (dim32_t i = 0; i < ndims_; i++) {
+            dims_[i] = other.dims_[i];
+            padded_dims_[i] = other.padded_dims_[i];
+            padded_offsets_[i] = other.padded_offsets_[i];
+            strides_[i] = other.strides_[i];
+            inner_blks_[i] = other.inner_blks_[i];
+            inner_idxs_[i] = other.inner_idxs_[i];
+        }
+    }
+
     template <typename... Args>
     dim_t off(Args... args) const {
         dims_t pos = {args...};
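Because specialization-constant values are copied into the device image, the payload type must stay trivially copyable and default-constructible, which is what the comment on `md_t_spec_const` is guarding. A compile-time check along these lines (a sketch, not part of the patch, assuming the `dnnl::impl::xpu::sycl` namespace of this header) would catch regressions:

```cpp
#include <type_traits>

#include "xpu/sycl/types.hpp"

// Both checks fail to compile if someone later adds a member (e.g. a
// std::string or a type with a user-provided constructor) that would make
// the struct unusable as a specialization-constant value.
static_assert(std::is_trivially_copyable_v<
                      dnnl::impl::xpu::sycl::md_t_spec_const_pod>,
        "md_t_spec_const_pod must stay trivially copyable");
static_assert(std::is_trivially_default_constructible_v<
                      dnnl::impl::xpu::sycl::md_t_spec_const_pod>,
        "md_t_spec_const_pod must not gain a non-trivial constructor");
```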
diff --git a/tests/benchdnn/dnnl_common.hpp b/tests/benchdnn/dnnl_common.hpp
index e2dbb0bc711..8d516fe93a0 100644
--- a/tests/benchdnn/dnnl_common.hpp
+++ b/tests/benchdnn/dnnl_common.hpp
@@ -310,7 +310,13 @@ int check_dnnl_status(dnnl_status_t status, const prb_t *prb, res_t *res) {
         case dnnl_unimplemented: {
             // Unconditionally set all Nvidia backend unimplemented cases as
             // not supported.
-            if (is_nvidia_gpu() || is_amd_gpu()) {
+            if (is_nvidia_gpu()
+                    || is_amd_gpu()
+#ifdef DNNL_SYCL_GENERIC
+                    // skip unimplemented configs for sycl impl
+                    || is_gpu()
+#endif
+            ) {
                 res->state = SKIPPED;
                 res->reason = skip_reason::case_not_supported;
                 return OK;