*******************************************************************************/

#include "gpu/generic/sycl/ref_matmul.hpp"
+ #include "common/c_types_map.hpp"

#include "gpu/generic/sycl/matmul_kernels.hpp"
+ #include "gpu/generic/sycl/specialization_constants.hpp"
+ #include "xpu/sycl/types.hpp"
+
+ #define VCHECK_MATMUL(cond, msg, ...) \
+     VCONDCHECK(primitive, create, check, matmul, (cond), \
+             status::unimplemented, msg, ##__VA_ARGS__);
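VCONDCHECK is oneDNN's guard-and-log macro. As a sketch (an assumption about the definition in common/verbose.hpp, not its literal text), a VCHECK_MATMUL(cond, msg, ...) call reduces to an early-return guard:

    // Approximate expansion of VCHECK_MATMUL(cond, msg, ...) -- a sketch only.
    if (!(cond)) {
        // logs "primitive,create:check:matmul,<msg>" when verbose mode is on
        return status::unimplemented;
    }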
namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {
namespace sycl {

- void ref_matmul_t::pd_t::init_conf() {
+ status_t ref_matmul_t::pd_t::init_conf() {
    conf_ = sycl_matmul_conf_t();

    conf_.do_scale_data
@@ -49,19 +56,61 @@ void ref_matmul_t::pd_t::init_conf() {
    memory_desc_wrapper weights_d = weights_md();
    memory_desc_wrapper dst_d = dst_md();
    memory_desc_wrapper bias_d = weights_md(1);
-     for (const auto &mdw : {src_d, weights_d, dst_d, bias_d}) {
-         if (mdw.has_runtime_dims()) {
-             any_runtime_params_ = true;
-             return;
-         }
-     }
-     init_rt_conf(conf_, src_d, weights_d, dst_d, bias_d);
+     VCHECK_MATMUL(!utils::one_of(true, src_d.has_runtime_dims(),
+                           weights_d.has_runtime_dims(),
+                           dst_d.has_runtime_dims(), bias_d.has_runtime_dims()),
+             VERBOSE_RUNTIMEDIM_UNSUPPORTED);
+
+     return init_rt_conf(conf_, data_md_t, dst_md_t, weights_md_t, src_d,
+             weights_d, dst_d, bias_d);
}
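Runtime dimensions are now rejected at primitive creation rather than deferred to execute(), so the any_runtime_params_ fallback removed further down becomes dead code. For illustration only (a hypothetical caller-side snippet, not part of this patch), a descriptor like this would now make init_conf() return status::unimplemented:

    // A src md carrying a runtime dimension is rejected up front.
    dnnl::memory::desc src_md({DNNL_RUNTIME_DIM_VAL, 128},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);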
- void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
+ status_t ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
+         xpu::sycl::md_t_spec_const &data_md_t_,
+         xpu::sycl::md_t_spec_const &dst_md_t_,
+         xpu::sycl::md_t_spec_const &weights_md_t_,
        const memory_desc_wrapper src_d, const memory_desc_wrapper weights_d,
        const memory_desc_wrapper dst_d,
        const memory_desc_wrapper bias_d) const {
+
+     // Local lambda rather than a named helper: it is not used anywhere else.
+     auto init_md_t_sc_from_md = [=](xpu::sycl::md_t_spec_const &md_t_sc,
+                                         const memory_desc_t *md) -> status_t {
+         constexpr int max_dims = 6;
+         using dim32_t = int32_t;
+
+         memory_desc_wrapper mdw(md);
+
+         VCHECK_MATMUL(mdw.format_kind() == format_kind::blocked,
+                 VERBOSE_UNSUPPORTED_FORMAT_KIND);
+         VCHECK_MATMUL(mdw.ndims() <= max_dims, VERBOSE_BAD_NDIMS, "mdw",
+                 mdw.ndims());
+
+         const auto &blk = mdw.blocking_desc();
+
+         md_t_sc.data_type_ = mdw.data_type();
+ #define CHECK_AND_ASSIGN(lhs, rhs) \
+     VCHECK_MATMUL((rhs) <= INT32_MAX, VERBOSE_BAD_PARAM, #rhs); \
+     (lhs) = static_cast<dim32_t>(rhs)
+
+         CHECK_AND_ASSIGN(md_t_sc.ndims_, mdw.ndims());
+         CHECK_AND_ASSIGN(md_t_sc.offset0_, mdw.offset0());
+         CHECK_AND_ASSIGN(md_t_sc.inner_nblks_, blk.inner_nblks);
+
+         for (int d = 0; d < mdw.ndims(); d++) {
+             CHECK_AND_ASSIGN(md_t_sc.dims_[d], mdw.dims()[d]);
+             CHECK_AND_ASSIGN(md_t_sc.padded_dims_[d], mdw.padded_dims()[d]);
+             CHECK_AND_ASSIGN(
+                     md_t_sc.padded_offsets_[d], mdw.padded_offsets()[d]);
+             CHECK_AND_ASSIGN(md_t_sc.strides_[d], blk.strides[d]);
+             CHECK_AND_ASSIGN(md_t_sc.inner_blks_[d], blk.inner_blks[d]);
+             CHECK_AND_ASSIGN(md_t_sc.inner_idxs_[d], blk.inner_idxs[d]);
+         }
+ #undef CHECK_AND_ASSIGN
+
+         return status::success;
+     };
+
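The lambda flattens a memory_desc_t into a plain POD so the descriptor can travel as a SYCL specialization constant. A sketch of the layout it assumes follows; the authoritative md_t_spec_const lives in xpu/sycl/types.hpp and may differ in detail:

    // Sketch (assumption): 32-bit POD mirror of memory_desc_t with exactly
    // the fields assigned above.
    struct md_t_spec_const {
        static constexpr int max_dims = 6;
        data_type_t data_type_;
        int32_t ndims_;
        int32_t offset0_;
        int32_t inner_nblks_;
        int32_t dims_[max_dims];
        int32_t padded_dims_[max_dims];
        int32_t padded_offsets_[max_dims];
        int32_t strides_[max_dims];
        int32_t inner_blks_[max_dims];
        int32_t inner_idxs_[max_dims];
    };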
    int matmul_dim_1 = ndims() - 2;
    int matmul_dim_2 = ndims() - 1;
@@ -73,7 +122,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                data_md_copy.dims[matmul_dim_2]);
        conf.transpose_data = true;
    }
-     conf.data_md = xpu::sycl::md_t(&data_md_copy);
+     CHECK(init_md_t_sc_from_md(data_md_t_, &data_md_copy));

    memory_desc_t weights_md_copy = *weights_d.md_;
    auto &weights_strides = weights_md_copy.format_desc.blocking.strides;
@@ -83,7 +132,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                weights_md_copy.dims[matmul_dim_2]);
        conf.transpose_weights = true;
    }
-     conf.weights_md = xpu::sycl::md_t(&weights_md_copy);
+     CHECK(init_md_t_sc_from_md(weights_md_t_, &weights_md_copy));

    memory_desc_t dst_md_copy = *dst_d.md_;
    auto &dst_strides = dst_md_copy.format_desc.blocking.strides;
@@ -93,7 +142,7 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
                dst_md_copy.dims[matmul_dim_1], dst_md_copy.dims[matmul_dim_2]);
        conf.transpose_dst = true;
    }
-     conf.dst_md = xpu::sycl::md_t(&dst_md_copy);
+     CHECK(init_md_t_sc_from_md(dst_md_t_, &dst_md_copy));

    if (with_bias()) {
        memory_desc_t bias_md_copy = *bias_d.md_;
@@ -109,8 +158,8 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,

    dims_t dst_blocks;
    for (int i = 0; i < matmul_kernel_fwd_t::max_supported_ndims; i++) {
-         if (i < conf.dst_md.ndims()) {
-             dst_blocks[i] = conf.dst_md.dims()[i];
+         if (i < dst_md_t_.ndims_) {
+             dst_blocks[i] = dst_md_t_.dims_[i];
        } else {
            dst_blocks[i] = 1;
        }
@@ -133,34 +182,44 @@ void ref_matmul_t::pd_t::init_rt_conf(sycl_matmul_conf_t &conf,
            = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims())
            | high_two_bits;
    conf.bias_mask = utils::get_dims_mask(dst_d.dims(), bias_d.dims(), ndims());
+
+     return status::success;
}
status_t ref_matmul_t::init(impl::engine_t *engine) {
    const auto kid = ::sycl::get_kernel_id<matmul_kernel_fwd_t>();
-     CHECK(create_kernel(engine, kid, &kernel_));
+     CHECK(create_matmul_kernel(engine, kid, &kernel_,
+             {pd()->data_md_t, pd()->dst_md_t, pd()->weights_md_t}));
+     return status::success;
+ }
+
+ status_t ref_matmul_t::create_matmul_kernel(impl::engine_t *engine,
+         ::sycl::kernel_id kid, kernel_t *kernel,
+         xpu::sycl::md_t_spec_const_pod pod) {
+
+     auto ctx = utils::downcast<const xpu::sycl::engine_impl_t *>(engine->impl())
+                        ->context();
+     auto input_bundle = ::sycl::get_kernel_bundle<::sycl::bundle_state::input>(
+             ctx, {kid});
+
+     input_bundle.template set_specialization_constant<
+             detail::matmul::md_t_spec_const_id>(pod);
+     try {
+         (*kernel) = kernel_t(::sycl::build(input_bundle));
+     } catch (const ::sycl::exception &e) { return status::runtime_error; }
    return status::success;
}
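create_matmul_kernel follows the standard SYCL 2020 recipe: fetch the kernel bundle in input state, set the specialization constant, then JIT with ::sycl::build so the descriptors become compile-time constants inside the kernel. A minimal self-contained sketch of that mechanism, with all names (pod_t, pod_id, pod_kernel) illustrative rather than oneDNN's:

    #include <sycl/sycl.hpp>
    #include <cstdint>

    // POD payload baked into the kernel as a specialization constant.
    struct pod_t {
        int32_t ndims;
        int32_t dims[6];
    };

    constexpr sycl::specialization_id<pod_t> pod_id;

    class pod_kernel; // forward-declared kernel name

    int main() {
        sycl::queue q;
        auto ctx = q.get_context();

        // Input-state bundle for this kernel only; set the constant, then JIT.
        auto kid = sycl::get_kernel_id<pod_kernel>();
        auto input
                = sycl::get_kernel_bundle<sycl::bundle_state::input>(ctx, {kid});
        input.set_specialization_constant<pod_id>(pod_t {2, {4, 8, 1, 1, 1, 1}});
        auto exec = sycl::build(input);

        int32_t out = 0;
        {
            sycl::buffer<int32_t, 1> buf(&out, 1);
            q.submit([&](sycl::handler &cgh) {
                cgh.use_kernel_bundle(exec);
                sycl::accessor acc(buf, cgh, sycl::write_only);
                cgh.single_task<pod_kernel>([=](sycl::kernel_handler kh) {
                    // Read the specialized value back inside the kernel.
                    auto pod = kh.get_specialization_constant<pod_id>();
                    acc[0] = pod.dims[0] * pod.dims[1];
                });
            });
        } // buffer destructor synchronizes; out == 32
        return 0;
    }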
status_t ref_matmul_t::execute(const exec_ctx_t &ctx) const {
    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;

-     sycl_matmul_conf_t conf = pd()->conf_;
-     if (pd()->any_runtime_params_) {
-         const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
-         const auto weights_d
-                 = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
-         const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
-         const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
-         pd()->init_rt_conf(conf, src_d, weights_d, dst_d, bias_d);
-     }
-
    parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
-         matmul_kernel_fwd_t matmul_kernel(conf, cgh, ctx);
+         matmul_kernel_fwd_t matmul_kernel(pd()->conf_, cgh, ctx);

        const int block_size = 32;
        const int wg_size = 32;

-         const int t_work = conf.wk_size;
+         const int t_work = pd()->conf_.wk_size;
        const int wg_work = wg_size * block_size;
        const int wg_cnt = utils::div_up(t_work, wg_work);
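The dispatch math above is easier to see with concrete numbers. Assuming, purely for illustration, wk_size = 10'000:

    // wg_work = wg_size * block_size = 32 * 32 = 1024 elements per work-group
    // wg_cnt  = div_up(10'000, 1024) = 10 work-groups
    // i.e. 10 groups of 32 items, each item striding over up to 32 elements.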