Skip to content

Commit 4b84d79

Browse files
AD2605 authored and Rbiessy committed
generic:sycl: Inner Product FWD
Co-authored-by: Atharva Dubey <atharvadubey26@gmail.com>
1 parent 7ba14f6 commit 4b84d79

File tree

5 files changed

+274
-7
lines changed

5 files changed

+274
-7
lines changed

src/gpu/generic/sycl/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ The implementation supports both forward and backward directions.
9494
* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`, `N`
9595
* Supported data types: `f32`, `bf16`, `f16`, `s32`, `s8`, `u8`
9696

97+
## Inner Product
98+
99+
The implementation supports the forward direction only.
100+
101+
* Supported formats: All plain formats are supported.
102+
* Supported data types: All possible data combinations listed in the oneDNN specification are supported.
103+
* Supported post-ops: All post operations listed in the oneDNN specification are supported.
104+
97105
## Layer Normalization
98106

99107
The implementation supports both forward and backward directions.
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*******************************************************************************
2+
* Copyright 2024 Intel Corporation
3+
* Copyright 2024 Codeplay Software Limited
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*******************************************************************************/
17+
18+
#include "gpu/generic/sycl/ref_inner_product.hpp"
19+
#include "common/primitive_desc_iterator.hpp"
20+
21+
namespace dnnl::impl::gpu::generic::sycl {
22+
23+
status_t ref_inner_product_fwd_t::pd_t::init_matmul(impl::engine_t *engine) {
    // Describe the equivalent 2D matmul over the reshaped IP tensors
    // (src_md_reshaped x weights_md_reshaped + bias_md_reshaped -> dst).
    matmul_desc_t matmul_desc;
    CHECK(matmul_desc_init(&matmul_desc, &src_md_reshaped, &weights_md_reshaped,
            &bias_md_reshaped, arg_md(DNNL_ARG_DST)));

    // Forward the inner-product attributes (post-ops etc.) to the matmul.
    primitive_attr_t matmul_attr(*attr());

    primitive_desc_iterator_t it(engine,
            reinterpret_cast<op_desc_t *>(&matmul_desc), &matmul_attr, nullptr);
    if (!it.is_initialized()) return status::out_of_memory;

    // Pick the first matmul implementation the engine offers.
    for (++it; it != it.end(); ++it) {
        matmul_pd = *it;
        if (matmul_pd) break;
    }
    return matmul_pd ? status::success : status::invalid_arguments;
}
39+
40+
status_t ref_inner_product_fwd_t::init(impl::engine_t *engine) {
    // Instantiate the nested matmul primitive from the pd selected in
    // pd_t::init_matmul().
    std::pair<std::shared_ptr<impl::primitive_t>, cache_state_t> nested;
    CHECK(pd()->matmul_pd->create_primitive_nested(nested, engine));
    matmul_primitive = nested.first;
    return status::success;
}
46+
47+
status_t ref_inner_product_fwd_t::execute(const exec_ctx_t &ctx) const {
    // Route the nested matmul's scratchpad through the space booked under
    // key_nested by pd_t::init(), then delegate execution to the matmul.
    nested_scratchpad_t matmul_scratchpad(
            ctx, memory_tracking::names::key_nested, matmul_primitive);
    exec_ctx_t matmul_ctx(ctx);
    matmul_ctx.set_scratchpad_grantor(matmul_scratchpad.grantor());
    return matmul_primitive->execute(matmul_ctx);
}
54+
55+
} // namespace dnnl::impl::gpu::generic::sycl
+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*******************************************************************************
2+
* Copyright 2023-2024 Intel Corporation
3+
* Copyright 2024-2025 Codeplay Software Limited
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*******************************************************************************/
17+
18+
#ifndef GPU_GENERIC_SYCL_REF_INNER_PRODUCT_HPP
19+
#define GPU_GENERIC_SYCL_REF_INNER_PRODUCT_HPP
20+
21+
#include "gpu/generic/sycl/ref_matmul.hpp"
22+
#include "gpu/generic/sycl/sycl_gpu_primitive.hpp"
23+
#include "gpu/generic/sycl/sycl_post_ops.hpp"
24+
#include "gpu/generic/sycl/sycl_primitive_conf.hpp"
25+
#include "gpu/generic/sycl/sycl_utils.hpp"
26+
#include "gpu/gpu_inner_product_pd.hpp"
27+
#include "gpu/gpu_primitive.hpp"
28+
29+
namespace dnnl::impl::gpu::generic::sycl {
30+
struct ref_inner_product_fwd_t : public gpu::generic::sycl::primitive_t {
31+
using gpu::generic::sycl::primitive_t::primitive_t;
32+
33+
struct pd_t : public gpu_inner_product_fwd_pd_t {
34+
using gpu_inner_product_fwd_pd_t::gpu_inner_product_fwd_pd_t;
35+
using sm = primitive_attr_t::skip_mask_t;
36+
37+
DECLARE_COMMON_PD_T("dpcpp:ref:any", ref_inner_product_fwd_t);
38+
39+
status_t init(impl::engine_t *engine) {
40+
auto src_dt = arg_md(DNNL_ARG_SRC)->data_type;
41+
auto weights_dt = arg_md(DNNL_ARG_WEIGHTS)->data_type;
42+
auto dst_dt = arg_md(DNNL_ARG_DST)->data_type;
43+
auto bias_dt = with_bias() ? arg_md(DNNL_ARG_BIAS)->data_type
44+
: data_type::undef;
45+
46+
const bool ok = (set_default_params() == status::success)
47+
&& is_fwd()
48+
&& check_if_dtypes_valid(
49+
src_dt, dst_dt, bias_dt, weights_dt)
50+
&& sycl_post_ops_t::post_ops_ok(attr())
51+
&& (attr_.set_default_formats(dst_md()) == status::success)
52+
// Blocked memory formats are not supported
53+
&& memory_desc_wrapper(src_md()).is_plain()
54+
&& memory_desc_wrapper(dst_md()).is_plain()
55+
&& memory_desc_wrapper(weights_md()).is_plain();
56+
57+
if (!ok) { return status::unimplemented; }
58+
CHECK(create_ip_mds());
59+
CHECK(init_matmul(engine));
60+
61+
// book scratchpad for the matmul
62+
auto scratchpad = scratchpad_registry().registrar();
63+
scratchpad.book(memory_tracking::names::key_nested,
64+
matmul_pd->scratchpad_registry());
65+
return status::success;
66+
}
67+
68+
std::shared_ptr<primitive_desc_t> matmul_pd;
69+
70+
private:
71+
bool check_if_dtypes_valid(const data_type_t &src_dt,
72+
const data_type_t &dst_dt, const data_type_t &bias_dt,
73+
const data_type_t &weight_dt) const {
74+
using namespace data_type;
75+
return (utils::one_of(src_dt, f32) && utils::one_of(weight_dt, f32)
76+
&& utils::one_of(dst_dt, f32)
77+
&& utils::one_of(bias_dt, f32, undef))
78+
|| (utils::one_of(src_dt, f16)
79+
&& utils::one_of(weight_dt, f16)
80+
&& utils::one_of(dst_dt, f16, f32, s8, u8)
81+
&& utils::one_of(bias_dt, f16, f32, undef))
82+
|| (utils::one_of(src_dt, u8, s8)
83+
&& utils::one_of(weight_dt, s8)
84+
&& utils::one_of(dst_dt, u8, s8, s32, bf16, f32)
85+
&& utils::one_of(
86+
bias_dt, u8, s8, s32, bf16, f32, undef))
87+
|| (utils::one_of(src_dt, bf16)
88+
&& utils::one_of(weight_dt, bf16)
89+
&& utils::one_of(dst_dt, f32, bf16)
90+
&& utils::one_of(bias_dt, f32, bf16, undef));
91+
}
92+
93+
std::vector<int> get_dim_order(int ndims, const dims_t strides) {
94+
std::vector<int> order(ndims);
95+
for (int i = 0; i < ndims; ++i) {
96+
order[i] = i;
97+
}
98+
99+
std::sort(
100+
order.begin(), order.end(), [&strides](size_t i, size_t j) {
101+
return strides[i] < strides[j];
102+
});
103+
104+
return order;
105+
}
106+
107+
status_t create_ip_mds() {
108+
auto accumulate_dimensions = [](const dims_t dimensions, int start,
109+
int end) -> int64_t {
110+
int64_t accum = 1;
111+
for (int i = start; i < end; i++) {
112+
accum *= dimensions[i];
113+
}
114+
return accum;
115+
};
116+
117+
const auto src_md_ = arg_md(DNNL_ARG_SRC);
118+
const auto weights_md_ = arg_md(DNNL_ARG_WEIGHTS);
119+
const auto bias_md_ = arg_md(DNNL_ARG_BIAS);
120+
auto src_wrap = memory_desc_wrapper(src_md_);
121+
auto w_wrap = memory_desc_wrapper(weights_md_);
122+
123+
// src and weights dims need to be in the same order
124+
if (get_dim_order(src_wrap.ndims(), src_wrap.strides())
125+
!= get_dim_order(w_wrap.ndims(), w_wrap.strides())) {
126+
return status::unimplemented;
127+
}
128+
129+
// Reshape input into the form of Batch x (\prod_{dim_{n-1}}^dim_0)
130+
if (src_md_->ndims == 2) {
131+
src_md_reshaped = *src_md_;
132+
} else {
133+
int64_t src_flattened_dimension = accumulate_dimensions(
134+
src_md_->dims, 1, src_md_->ndims);
135+
dims_t src_reshaped_dims {
136+
src_md_->dims[0], src_flattened_dimension};
137+
CHECK(memory_desc_init_by_tag(src_md_reshaped, 2,
138+
src_reshaped_dims, src_md_->data_type, format_tag::ab));
139+
}
140+
141+
// Reshape weights as (OC x (\prod_{dim_{n-1}}^dim_0))^T
142+
int weights_flattened_dimensions = accumulate_dimensions(
143+
weights_md_->dims, 1, weights_md_->ndims);
144+
dims_t weights_reshaped_dims {
145+
weights_flattened_dimensions, weights_md_->dims[0]};
146+
CHECK(memory_desc_init_by_tag(weights_md_reshaped, 2,
147+
weights_reshaped_dims, weights_md_->data_type,
148+
format_tag::ba));
149+
if (with_bias()) {
150+
dims_t bias_reshaped_dims {1, bias_md_->dims[0]};
151+
CHECK(memory_desc_init_by_tag(bias_md_reshaped, 2,
152+
bias_reshaped_dims, bias_md_->data_type,
153+
format_tag::ab));
154+
}
155+
return status::success;
156+
}
157+
158+
status_t init_matmul(impl::engine_t *engine);
159+
// Memory descriptors to contain reshaped tensors from nD to 2D for IP
160+
memory_desc_t src_md_reshaped;
161+
memory_desc_t weights_md_reshaped;
162+
memory_desc_t bias_md_reshaped;
163+
};
164+
165+
status_t init(impl::engine_t *engine) override;
166+
status_t execute(const exec_ctx_t &ctx) const override;
167+
168+
private:
169+
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
170+
kernel_t kernel_;
171+
std::shared_ptr<impl::primitive_t> matmul_primitive;
172+
};
173+
} // namespace dnnl::impl::gpu::generic::sycl
174+
175+
#endif

src/gpu/gpu_inner_product_list.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
#include "gpu/amd/miopen_gemm_inner_product.hpp"
3333
#endif
3434

35+
#ifdef GENERIC_SYCL_KERNELS_ENABLED
36+
#include "gpu/generic/sycl/ref_inner_product.hpp"
37+
#endif
38+
3539
namespace dnnl {
3640
namespace impl {
3741
namespace gpu {
@@ -49,6 +53,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>>
4953
GPU_INSTANCE_NVIDIA(nvidia::cudnn_gemm_inner_product_fwd_t)
5054
GPU_INSTANCE_NVIDIA(nvidia::cudnn_conv_inner_product_fwd_t)
5155
GPU_INSTANCE_AMD(amd::miopen_gemm_inner_product_fwd_t)
56+
GPU_INSTANCE_GENERIC_SYCL(generic::sycl::ref_inner_product_fwd_t)
5257
nullptr,
5358
}},
5459
{{backward}, REG_BWD_PK({

tests/gtests/test_inner_product_forward.cpp

+31-7
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,18 @@ class inner_product_test_t
8888
protected:
8989
void SetUp() override {
9090
auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
91-
SKIP_IF_CUDA(!cuda_check_format_tags(p.src_format, p.weights_format,
92-
p.bias_format, p.dst_format),
91+
SKIP_IF_CUDA(!cuda_generic_check_format_tags(p.src_format,
92+
p.weights_format, p.bias_format, p.dst_format),
93+
"Unsupported format tag");
94+
SKIP_IF_GENERIC(!cuda_generic_check_format_tags(p.src_format,
95+
p.weights_format, p.bias_format, p.dst_format),
9396
"Unsupported format tag");
9497
SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
95-
SKIP_IF_GENERIC(true, "Primitive not implemented");
9698
catch_expected_failures(
9799
[&]() { Test(); }, p.expect_to_fail, p.expected_status);
98100
}
99101

100-
bool cuda_check_format_tags(memory::format_tag src_format,
102+
bool cuda_generic_check_format_tags(memory::format_tag src_format,
101103
memory::format_tag wei_format, memory::format_tag bia_format,
102104
memory::format_tag dst_format) {
103105
bool src_ok = src_format == memory::format_tag::ncdhw
@@ -130,6 +132,20 @@ class inner_product_test_t
130132
return src_ok && wei_ok && bia_ok && dst_ok;
131133
}
132134

135+
std::vector<int> get_dim_order(const memory::dims &strides) {
136+
size_t ndims = strides.size();
137+
std::vector<int> order(ndims);
138+
for (size_t i = 0; i < ndims; ++i) {
139+
order[i] = i;
140+
}
141+
142+
std::sort(order.begin(), order.end(), [&strides](size_t i, size_t j) {
143+
return strides[i] < strides[j];
144+
});
145+
146+
return order;
147+
}
148+
133149
void Test() {
134150
auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
135151
test_inner_product_descr_t ipd = p.test_ipd;
@@ -169,18 +185,26 @@ class inner_product_test_t
169185
: create_md({}, data_type, p.bias_format);
170186
auto ip_dst_desc = create_md({ipd.mb, ipd.oc}, data_type, p.dst_format);
171187

188+
SKIP_IF_GENERIC(get_dim_order(ip_src_desc.get_strides())
189+
!= get_dim_order(ip_weights_desc.get_strides()),
190+
"Unsupported case for generic");
191+
172192
auto ip_primitive_desc = with_bias
173193
? pd_t(eng, p.aprop_kind, ip_src_desc, ip_weights_desc,
174194
ip_bias_desc, ip_dst_desc)
175195
: pd_t(eng, p.aprop_kind, ip_src_desc, ip_weights_desc,
176196
ip_dst_desc);
177197

178198
auto aa = allows_attr_t {false};
179-
aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
180199
aa.po_eltwise = true;
181-
aa.po_prelu = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
182200
aa.po_sum = true;
183-
201+
#ifdef DNNL_SYCL_GENERIC
202+
aa.po_binary = true;
203+
aa.po_prelu = true;
204+
#else
205+
aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
206+
aa.po_prelu = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
207+
#endif
184208
test_fwd_pd_constructors<pd_t>(ip_primitive_desc, aa, p.aprop_kind,
185209
ip_src_desc, ip_weights_desc, ip_bias_desc, ip_dst_desc);
186210

0 commit comments

Comments
 (0)