*******************************************************************************/

#include "gpu/generic/sycl/ref_matmul.hpp"
+#include "common/c_types_map.hpp"
#include "gpu/generic/sycl/matmul_kernels.hpp"
+#include "gpu/generic/sycl/specialization_constants.hpp"
+#include "xpu/sycl/types.hpp"
namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {
namespace sycl {

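+// init_conf() reports a status so that shapes with runtime dims, which the
+// specialization-constant kernels cannot handle, are rejected as unimplemented.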
-void ref_matmul_t::pd_t::init_conf() {
+status_t ref_matmul_t::pd_t::init_conf() {
    conf_ = sycl_matmul_conf_t();

    conf_.do_scale_data
@@ -52,16 +55,56 @@ void ref_matmul_t::pd_t::init_conf() {
    for (const auto &mdw : {src_d, weights_d, dst_d, bias_d}) {
        if (mdw.has_runtime_dims()) {
            any_runtime_params_ = true;
-            return;
+            return status::unimplemented;
        }
    }
-    init_rt_conf(conf_, src_d, weights_d, dst_d, bias_d);
+    init_rt_conf(conf_, data_md_t, dst_md_t, weights_md_t, src_d,
+            weights_d, dst_d, bias_d);
+    return status::success;
}

void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
+        xpu::sycl::md_t_spec_const &data_md_t_,
+        xpu::sycl::md_t_spec_const &dst_md_t_,
+        xpu::sycl::md_t_spec_const &weights_md_t_,
        const memory_desc_wrapper src_d, const memory_desc_wrapper weights_d,
        const memory_desc_wrapper dst_d,
        const memory_desc_wrapper bias_d) const {
+
+    // Lambda because this function will not be used anywhere else
+    auto init_md_t_sc_from_md = [=](xpu::sycl::md_t_spec_const &md_t_sc,
+                                        const memory_desc_t *md) -> void {
+        constexpr int max_dims = 6;
+        using dim32_t = int32_t;
+
+        memory_desc_wrapper mdw(md);
+
+        assert(mdw.format_kind() == format_kind::blocked);
+        assert(mdw.ndims() <= max_dims);
+
+        const auto &blk = mdw.blocking_desc();
+
+        md_t_sc.data_type_ = mdw.data_type();
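+// Checked narrowing: assert each 64-bit value fits before casting it down to
+// the 32-bit fields of the spec-constant POD.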
+#define CHECK_AND_ASSIGN(lhs, rhs) \
+    assert((rhs) <= INT32_MAX); \
+    (lhs) = static_cast<dim32_t>(rhs)
+
+        CHECK_AND_ASSIGN(md_t_sc.ndims_, mdw.ndims());
+        CHECK_AND_ASSIGN(md_t_sc.offset0_, mdw.offset0());
+        CHECK_AND_ASSIGN(md_t_sc.inner_nblks_, blk.inner_nblks);
+
+        for (int d = 0; d < mdw.ndims(); d++) {
+            CHECK_AND_ASSIGN(md_t_sc.dims_[d], mdw.dims()[d]);
+            CHECK_AND_ASSIGN(md_t_sc.padded_dims_[d], mdw.padded_dims()[d]);
+            CHECK_AND_ASSIGN(
+                    md_t_sc.padded_offsets_[d], mdw.padded_offsets()[d]);
+            CHECK_AND_ASSIGN(md_t_sc.strides_[d], blk.strides[d]);
+            CHECK_AND_ASSIGN(md_t_sc.inner_blks_[d], blk.inner_blks[d]);
+            CHECK_AND_ASSIGN(md_t_sc.inner_idxs_[d], blk.inner_idxs[d]);
+        }
+#undef CHECK_AND_ASSIGN
+    };
+
    int matmul_dim_1 = ndims() - 2;
    int matmul_dim_2 = ndims() - 1;
@@ -73,7 +116,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                data_md_copy.dims[matmul_dim_2]);
        conf.transpose_data = true;
    }
-    conf.data_md = xpu::sycl::md_t(&data_md_copy);
+    init_md_t_sc_from_md(data_md_t_, &data_md_copy);

    memory_desc_t weights_md_copy = *weights_d.md_;
    auto &weights_strides = weights_md_copy.format_desc.blocking.strides;
@@ -83,7 +126,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                weights_md_copy.dims[matmul_dim_2]);
        conf.transpose_weights = true;
    }
-    conf.weights_md = xpu::sycl::md_t(&weights_md_copy);
+    init_md_t_sc_from_md(weights_md_t_, &weights_md_copy);

    memory_desc_t dst_md_copy = *dst_d.md_;
    auto &dst_strides = dst_md_copy.format_desc.blocking.strides;
@@ -93,7 +136,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                dst_md_copy.dims[matmul_dim_1], dst_md_copy.dims[matmul_dim_2]);
        conf.transpose_dst = true;
    }
-    conf.dst_md = xpu::sycl::md_t(&dst_md_copy);
+    init_md_t_sc_from_md(dst_md_t_, &dst_md_copy);

    if (with_bias()) {
        memory_desc_t bias_md_copy = *bias_d.md_;
@@ -109,8 +152,8 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,

    dims_t dst_blocks;
    for (int i = 0; i < matmul_kernel_fwd_t::max_supported_ndims; i++) {
-        if (i < conf.dst_md.ndims()) {
-            dst_blocks[i] = conf.dst_md.dims()[i];
+        if (i < dst_md_t.ndims_) {
+            dst_blocks[i] = dst_md_t.dims_[i];
        } else {
            dst_blocks[i] = 1;
        }
@@ -137,30 +180,38 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,

status_t ref_matmul_t::init(impl::engine_t *engine) {
    const auto kid = ::sycl::get_kernel_id<matmul_kernel_fwd_t>();
-    CHECK(create_kernel(engine, kid, &kernel_));
+    CHECK(create_matmul_kernel(engine, kid, &kernel_,
+            {pd()->data_md_t, pd()->dst_md_t, pd()->weights_md_t}));
+    return status::success;
+}
+
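+// Unlike the generic create_kernel() path used before, this helper builds the
+// kernel bundle itself so the memory-descriptor POD can be set as a
+// specialization constant first.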
+status_t ref_matmul_t::create_matmul_kernel(impl::engine_t *engine,
+        ::sycl::kernel_id kid, kernel_t *kernel,
+        xpu::sycl::md_t_spec_const_pod pod) {
+
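+    // Specialization constants can only be set on a kernel bundle that is
+    // still in input state, i.e. before it is JIT-compiled by ::sycl::build().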
+    auto ctx = utils::downcast<const xpu::sycl::engine_impl_t *>(engine->impl())
+                       ->context();
+    auto input_bundle = ::sycl::get_kernel_bundle<::sycl::bundle_state::input>(
+            ctx, {kid});
+
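+    // Bake the descriptor POD into the kernel, then build; ::sycl::build()
+    // throws on JIT-compilation failure, which is mapped to runtime_error.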
+    input_bundle.template set_specialization_constant<
+            detail::matmul::md_t_spec_const_id>(pod);
+    try {
+        (*kernel) = kernel_t(::sycl::build(input_bundle));
+    } catch (const ::sycl::exception &e) { return status::runtime_error; }
    return status::success;
}

status_t ref_matmul_t::execute(const exec_ctx_t &ctx) const {
    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;

-    sycl_matmul_conf_t conf = pd()->conf_;
-    if (pd()->any_runtime_params_) {
-        const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
-        const auto weights_d
-                = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
-        const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
-        const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
-        pd()->init_rt_conf(conf, src_d, weights_d, dst_d, bias_d);
-    }
-
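+    // Runtime dims are rejected in init_conf(), so pd()->conf_ is final and no
+    // per-execution re-initialization is needed.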
    parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
-        matmul_kernel_fwd_t matmul_kernel(conf, cgh, ctx);
+        matmul_kernel_fwd_t matmul_kernel(pd()->conf_, cgh, ctx);

        const int block_size = 32;
        const int wg_size = 32;

-        const int t_work = conf.wk_size;
+        const int t_work = pd()->conf_.wk_size;
        const int wg_work = wg_size * block_size;
        const int wg_cnt = utils::div_up(t_work, wg_work);
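
For context, a minimal self-contained sketch of the SYCL 2020 specialization-constant mechanism this patch relies on; pod_t, md_spec_id, spec_const_demo, and launch are hypothetical stand-ins for the oneDNN types above, not names from the patch:

#include <sycl/sycl.hpp>
#include <cstdint>

// Hypothetical stand-in for xpu::sycl::md_t_spec_const_pod.
struct pod_t {
    int32_t ndims_;
    int32_t dims_[6];
};

// Spec-constant id visible to both host and device code
// (stand-in for detail::matmul::md_t_spec_const_id).
constexpr sycl::specialization_id<pod_t> md_spec_id;

void launch(sycl::queue &q) {
    // The bundle must still be in input state to accept spec-constant values.
    auto input = sycl::get_kernel_bundle<sycl::bundle_state::input>(
            q.get_context());
    input.set_specialization_constant<md_spec_id>(pod_t {2, {128, 64}});

    // JIT-compile with the POD baked in as a compile-time constant.
    auto exec = sycl::build(input);

    q.submit([&](sycl::handler &cgh) {
        cgh.use_kernel_bundle(exec);
        cgh.single_task<class spec_const_demo>([=](sycl::kernel_handler kh) {
            // Reads resolve to constants after build(), so loops over
            // dims_ can be unrolled by the device compiler.
            pod_t md = kh.get_specialization_constant<md_spec_id>();
            (void)md;
        });
    });
    q.wait();
}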