Skip to content

Commit 15ca799

Browse files
committed
generic: sycl: fix accessor types
1 parent 23914f0 commit 15ca799

11 files changed

+35
-25
lines changed

src/gpu/generic/sycl/binary_kernels.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct binary_kernel_vec_t {
3939
: conf_(conf)
4040
, src0_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
4141
, src1_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_1))
42-
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
42+
, dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
4343
, src0_scale_(CTX_IN_SYCL_KERNEL_MEMORY(
4444
DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0))
4545
, src1_scale_(CTX_IN_SYCL_KERNEL_MEMORY(
@@ -187,7 +187,7 @@ struct binary_kernel_vec_t {
187187

188188
xpu::sycl::in_memory_arg_t src0_;
189189
xpu::sycl::in_memory_arg_t src1_;
190-
xpu::sycl::out_memory_arg_t dst_;
190+
xpu::sycl::inout_memory_arg_t dst_;
191191
xpu::sycl::in_memory_arg_t src0_scale_;
192192
xpu::sycl::in_memory_arg_t src1_scale_;
193193
data_type_t scales_dt_;

src/gpu/generic/sycl/convolution_kernels.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct convolution_kernel_fwd_t {
3939
, data_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
4040
, weights_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_WEIGHTS))
4141
, bias_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_BIAS))
42-
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
42+
, dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
4343
, data_scale_(CTX_IN_SYCL_KERNEL_MEMORY(
4444
DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0))
4545
, weights_scale_(CTX_IN_SYCL_KERNEL_MEMORY(
@@ -232,7 +232,7 @@ struct convolution_kernel_fwd_t {
232232
xpu::sycl::in_memory_arg_t data_;
233233
xpu::sycl::in_memory_arg_t weights_;
234234
xpu::sycl::in_memory_arg_t bias_;
235-
xpu::sycl::out_memory_arg_t dst_;
235+
xpu::sycl::inout_memory_arg_t dst_;
236236
xpu::sycl::in_memory_arg_t data_scale_;
237237
xpu::sycl::in_memory_arg_t weights_scale_;
238238
xpu::sycl::in_memory_arg_t dst_scale_;
@@ -250,7 +250,7 @@ struct convolution_kernel_bwd_data_t {
250250
convolution_kernel_bwd_data_t(const sycl_convolution_conf_t &conf,
251251
::sycl::handler &cgh, const exec_ctx_t &ctx)
252252
: conf_(conf)
253-
, diff_data_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DIFF_SRC))
253+
, diff_data_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DIFF_SRC))
254254
, weights_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_WEIGHTS))
255255
, bias_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_BIAS))
256256
, diff_dst_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_DIFF_DST))
@@ -461,7 +461,7 @@ struct convolution_kernel_bwd_data_t {
461461

462462
sycl_convolution_conf_t conf_;
463463

464-
xpu::sycl::out_memory_arg_t diff_data_;
464+
xpu::sycl::inout_memory_arg_t diff_data_;
465465
xpu::sycl::in_memory_arg_t weights_;
466466
xpu::sycl::in_memory_arg_t bias_;
467467
xpu::sycl::in_memory_arg_t diff_dst_;

src/gpu/generic/sycl/eltwise_kernels.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct eltwise_fwd_kernel_vec_t {
3737
: conf_(conf)
3838
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
3939
, po_args_(cgh, ctx, conf_.post_ops)
40-
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST)) {}
40+
, dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST)) {}
4141

4242
void operator()(::sycl::nd_item<1> item) const {
4343
memory_tensor_t src_mem(src_, conf_.src_md);
@@ -194,7 +194,7 @@ struct eltwise_fwd_kernel_vec_t {
194194
sycl_eltwise_conf_t conf_;
195195
xpu::sycl::in_memory_arg_t src_;
196196
post_op_input_args po_args_;
197-
xpu::sycl::out_memory_arg_t dst_;
197+
xpu::sycl::inout_memory_arg_t dst_;
198198
};
199199

200200
struct eltwise_bwd_kernel_vec_t {

src/gpu/generic/sycl/matmul_kernels.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct matmul_kernel_fwd_t {
8383
}
8484

8585
static void store_vec_helper(
86-
out_memory_tensor_t &output, Vec data, int offset) {
86+
inout_memory_tensor_t &output, Vec data, int offset) {
8787
data_type_t type = output.md().data_type();
8888
char *offset_ptr = static_cast<char *>(output.ptr())
8989
+ data_type_size(type) * offset;
@@ -189,7 +189,7 @@ struct matmul_kernel_fwd_t {
189189
}
190190
}
191191

192-
void store(out_memory_tensor_t &output, int offset, int row_stride) {
192+
void store(inout_memory_tensor_t &output, int offset, int row_stride) {
193193
for (int row = 0; row < Rows; row++) {
194194
for (int col = 0; col < Cols / vec_len; col++) {
195195
store_vec_helper(output, data[row][col],
@@ -198,7 +198,7 @@ struct matmul_kernel_fwd_t {
198198
}
199199
}
200200

201-
void store_edge(out_memory_tensor_t &output, int offset, int row_stride,
201+
void store_edge(inout_memory_tensor_t &output, int offset, int row_stride,
202202
int rows, int cols) {
203203
for (int row = 0; row < rows; row++) {
204204
int col;
@@ -215,7 +215,7 @@ struct matmul_kernel_fwd_t {
215215
}
216216
}
217217

218-
void store_generic(out_memory_tensor_t &output, int offset,
218+
void store_generic(inout_memory_tensor_t &output, int offset,
219219
int row_stride, bool transpose, bool is_edge_block, int rows,
220220
int cols) {
221221
if (is_edge_block) {
@@ -361,7 +361,7 @@ struct matmul_kernel_fwd_t {
361361
, data_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
362362
, weights_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_WEIGHTS))
363363
, bias_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_BIAS))
364-
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
364+
, dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
365365
, data_scale_(CTX_IN_SYCL_KERNEL_MEMORY(
366366
DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0))
367367
, data_scales_dt_((conf_.do_scale_data)
@@ -653,7 +653,7 @@ struct matmul_kernel_fwd_t {
653653
xpu::sycl::in_memory_arg_t data_;
654654
xpu::sycl::in_memory_arg_t weights_;
655655
xpu::sycl::in_memory_arg_t bias_;
656-
xpu::sycl::out_memory_arg_t dst_;
656+
xpu::sycl::inout_memory_arg_t dst_;
657657
xpu::sycl::in_memory_arg_t data_scale_;
658658
data_type_t data_scales_dt_;
659659
xpu::sycl::in_memory_arg_t weights_scale_;

src/gpu/generic/sycl/ref_pooling.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ struct ref_pooling_fwd_t : public gpu::generic::sycl::primitive_t {
6666
src_md(0)->data_type != dst_md(0)->data_type,
6767
desc()->prop_kind == forward_inference))
6868
&& attr()->has_default_values(sm::post_ops)
69+
&& sycl_post_ops_t::post_ops_ok(attr(), true, false)
6970
&& attr_.set_default_formats(dst_md(0)) == status::success
7071
&& md_dims_in_range(src_md());
7172
if (!ok) return status::unimplemented;

src/gpu/generic/sycl/ref_resampling.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct ref_resampling_fwd_t : public gpu::generic::sycl::primitive_t {
5050
const bool ok = is_fwd() && is_supported_type(src_md(0)->data_type)
5151
&& is_supported_type(dst_md(0)->data_type)
5252
&& attr()->has_default_values(sm::post_ops)
53+
&& sycl_post_ops_t::post_ops_ok(attr())
5354
&& set_default_params() == status::success
5455
&& attr_.set_default_formats(dst_md(0)) == status::success
5556
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)

src/gpu/generic/sycl/reorder_kernels.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct reorder_kernel_t {
3838
const exec_ctx_t &ctx)
3939
: conf_(conf)
4040
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
41-
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
41+
, dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
4242
, src_scale_(CTX_IN_SYCL_KERNEL_MEMORY(
4343
DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0))
4444
, dst_scale_(CTX_IN_SYCL_KERNEL_MEMORY(
@@ -143,7 +143,7 @@ struct reorder_kernel_t {
143143
sycl_reorder_conf_t conf_;
144144

145145
xpu::sycl::in_memory_arg_t src_;
146-
xpu::sycl::out_memory_arg_t dst_;
146+
xpu::sycl::inout_memory_arg_t dst_;
147147
xpu::sycl::in_memory_arg_t src_scale_;
148148
xpu::sycl::in_memory_arg_t dst_scale_;
149149
data_type_t scales_src_dt_;

src/gpu/generic/sycl/resampling_kernels.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct resampling_kernel_fwd_vec_t {
3838
::sycl::handler &cgh, const exec_ctx_t &ctx)
3939
: conf_(conf)
4040
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
41-
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
41+
, dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
4242
, po_args_(cgh, ctx, conf_.post_ops) {}
4343

4444
void operator()(::sycl::nd_item<1> item) const {
@@ -142,7 +142,7 @@ struct resampling_kernel_fwd_vec_t {
142142
sycl_resampling_conf_t conf_;
143143

144144
xpu::sycl::in_memory_arg_t src_;
145-
xpu::sycl::out_memory_arg_t dst_;
145+
xpu::sycl::inout_memory_arg_t dst_;
146146
post_op_input_args po_args_;
147147
};
148148

src/gpu/generic/sycl/softmax_kernels.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct softmax_fwd_kernel_vec_t {
4141
DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC))
4242
, scale_dst_(CTX_IN_SYCL_KERNEL_MEMORY(
4343
DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST))
44-
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
44+
, dst_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
4545
, po_args_(cgh, ctx, conf_.post_ops) {}
4646

4747
void operator()(::sycl::nd_item<1> item) const {
@@ -140,7 +140,7 @@ struct softmax_fwd_kernel_vec_t {
140140
xpu::sycl::in_memory_arg_t src_;
141141
xpu::sycl::in_memory_arg_t scale_src_;
142142
xpu::sycl::in_memory_arg_t scale_dst_;
143-
xpu::sycl::out_memory_arg_t dst_;
143+
xpu::sycl::inout_memory_arg_t dst_;
144144
post_op_input_args po_args_;
145145
};
146146

src/gpu/generic/sycl/sycl_post_ops.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ struct ref_sum_op_t {
224224
ref_sum_op_t(float scale, float zeropoint)
225225
: scale_(scale), zeropoint_(zeropoint) {}
226226

227-
float load_and_compute(float acc, const xpu::sycl::out_memory_arg_t &dst,
227+
float load_and_compute(float acc, const xpu::sycl::inout_memory_arg_t &dst,
228228
dnnl::impl::data_type_t sum_dt_,
229229
dim_t offset) const { // TODO dims32_t
230230
memory_plain_t dst_mem(dst, sum_dt_);
@@ -321,14 +321,14 @@ struct sycl_post_ops_t {
321321
n_post_ops_ = attr_po.len();
322322
}
323323

324-
inline float apply(float acc, const xpu::sycl::out_memory_arg_t &dst,
324+
inline float apply(float acc, const xpu::sycl::inout_memory_arg_t &dst,
325325
dim_t dst_offset, const post_op_input_args &po_args,
326326
dims_t src_offset) const;
327327
inline float apply(float acc, float dst, const post_op_input_args &po_args,
328328
dims_t src_offset) const;
329329
inline float apply(float acc, const post_op_input_args &po_args,
330330
dims_t src_offset) const;
331-
inline float apply(float acc, const xpu::sycl::out_memory_arg_t &dst,
331+
inline float apply(float acc, const xpu::sycl::inout_memory_arg_t &dst,
332332
dim_t dst_offset) const;
333333

334334
inline int get_post_op() const { return n_post_ops_; }
@@ -369,7 +369,7 @@ struct post_op_input_args {
369369
xpu::sycl::in_memory_arg_t args_[sycl_post_ops_t::max_post_ops];
370370
};
371371

372-
float sycl_post_ops_t::apply(float acc, const xpu::sycl::out_memory_arg_t &dst,
372+
float sycl_post_ops_t::apply(float acc, const xpu::sycl::inout_memory_arg_t &dst,
373373
dim_t dst_offset, const post_op_input_args &po_args,
374374
dims_t src_offset) const {
375375
using namespace primitive_kind;
@@ -438,7 +438,7 @@ float sycl_post_ops_t::apply(
438438
return acc;
439439
}
440440

441-
float sycl_post_ops_t::apply(float acc, const xpu::sycl::out_memory_arg_t &dst,
441+
float sycl_post_ops_t::apply(float acc, const xpu::sycl::inout_memory_arg_t &dst,
442442
dim_t dst_offset) const {
443443
using namespace primitive_kind;
444444

src/xpu/sycl/types.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ namespace sycl {
4848
&CTX_OUT_STORAGE(arg)) \
4949
->get_out_memory_arg(ctx.stream(), cgh)
5050

51+
#define CTX_INOUT_SYCL_KERNEL_MEMORY(arg) \
52+
CTX_OUT_STORAGE(arg).is_null() \
53+
? xpu::sycl::memory_storage_base_t::empty_inout_memory_arg( \
54+
ctx.stream(), cgh) \
55+
: utils::downcast<const xpu::sycl::memory_storage_base_t *>( \
56+
&CTX_OUT_STORAGE(arg)) \
57+
->get_inout_memory_arg(ctx.stream(), cgh)
58+
5159
#define CHECK_SYCL_KERNEL_ARG_TYPE(type) \
5260
static_assert(::sycl::is_device_copyable_v<type>)
5361

0 commit comments

Comments
 (0)