Skip to content

Commit 19f0862

Browse files
authored
generic: sycl: implement prelu post-op (#2131)
1 parent c19ea38 commit 19f0862

17 files changed

+139
-65
lines changed

src/gpu/generic/sycl/binary_kernels.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct binary_kernel_vec_t {
4848
| DNNL_ARG_SRC_0)
4949
.data_type()
5050
: data_type_t::dnnl_f32)
51-
, po_args_(cgh, ctx) {}
51+
, po_args_(cgh, ctx, conf_.post_ops) {}
5252

5353
void operator()(::sycl::nd_item<1> item) const {
5454
memory_tensor_t src0_mem(src0_, conf_.src0_md);

src/gpu/generic/sycl/eltwise_kernels.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct eltwise_fwd_kernel_vec_t {
3636
::sycl::handler &cgh, const exec_ctx_t &ctx)
3737
: conf_(conf)
3838
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
39-
, po_args_(cgh, ctx)
39+
, po_args_(cgh, ctx, conf_.post_ops)
4040
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST)) {}
4141

4242
void operator()(::sycl::nd_item<1> item) const {

src/gpu/generic/sycl/matmul_kernels.hpp

+38-3
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,36 @@ struct matmul_kernel_fwd_t {
323323
}
324324
}
325325
}
326+
327+
// Applies the post-op chain to a partially filled block: only the first
// `rows` rows and the first `cols` scalar columns of `data` are processed,
// one element at a time.
//
// post_ops  post-op chain; `post_ops.apply` is invoked once per element.
// prev_dst  block whose `data[row][col][v_el]` element is passed to `apply`
//           alongside the current value.
// off_po    offset array handed to `apply`; entries `dim1` and `dim1 + 1`
//           are shifted by the element's row and column for each call and
//           restored right after, so the array is unchanged on return.
// dim1      index of the row entry in `off_po`; the column entry is
//           `dim1 + 1`.
// kernel    supplies `po_args_` for `apply`.
// rows      number of rows to process.
// cols      number of scalar columns to process.
//
// NOTE(review): the call site visible in this diff passes
// `remaining_m` / `remaining_n` when `is_dst_edge_block` is set; presumably
// these are the in-bounds extents of the edge block -- confirm against the
// full kernel source.
void apply_post_ops_edge(sycl_post_ops_t post_ops,
328+
register_block<Rows, Cols> prev_dst, dims_t off_po, int dim1,
329+
const matmul_kernel_fwd_t *kernel, int rows, int cols) {
330+
for (int row = 0; row < rows; row++) {
331+
// Declared outside the loop below because the tail handling further
// down reuses the final value of `col`.
int col;
332+
// First pass: every vector that lies completely within `cols`.
for (col = 0; col < cols / vec_len; col++) {
333+
for (int v_el = 0; v_el < vec_len; v_el++) {
334+
// Shift the offsets to this element for the duration of the call.
off_po[dim1] += row;
335+
off_po[dim1 + 1] += col * vec_len + v_el;
336+
data[row][col][v_el]
337+
= post_ops.apply(data[row][col][v_el],
338+
prev_dst.data[row][col][v_el],
339+
kernel->po_args_, off_po);
340+
// Undo the shift so the next element starts from the same base.
off_po[dim1] -= row;
341+
off_po[dim1 + 1] -= col * vec_len + v_el;
342+
}
343+
}
344+
// Second pass: the leading `n_remaining` lanes of the trailing, partially
// covered vector (zero iterations when `cols` is a multiple of `vec_len`).
int n_remaining = cols - col * vec_len;
345+
for (int v_el = 0; v_el < n_remaining; v_el++) {
346+
off_po[dim1] += row;
347+
off_po[dim1 + 1] += col * vec_len + v_el;
348+
data[row][col][v_el] = post_ops.apply(data[row][col][v_el],
349+
prev_dst.data[row][col][v_el], kernel->po_args_,
350+
off_po);
351+
off_po[dim1] -= row;
352+
off_po[dim1 + 1] -= col * vec_len + v_el;
353+
}
354+
}
355+
}
326356
};
327357

328358
matmul_kernel_fwd_t(const sycl_matmul_conf_t &conf, ::sycl::handler &cgh,
@@ -377,7 +407,7 @@ struct matmul_kernel_fwd_t {
377407
, dropout_seed_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_ATTR_DROPOUT_SEED))
378408
, dropout_probability_(
379409
CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_ATTR_DROPOUT_PROBABILITY))
380-
, po_args_(cgh, ctx) {}
410+
, po_args_(cgh, ctx, conf_.post_ops) {}
381411

382412
void operator()(::sycl::nd_item<1> item) const {
383413
using data_block_t = register_block<register_block_M, register_block_K>;
@@ -597,8 +627,13 @@ struct matmul_kernel_fwd_t {
597627
if (conf_.transpose_dst) {
598628
std::swap(off_po[matmul_dim_1], off_po[matmul_dim_2]);
599629
}
600-
dst_block.apply_post_ops(
601-
conf_.post_ops, prev_dst, off_po, matmul_dim_1, this);
630+
if (is_dst_edge_block) {
631+
dst_block.apply_post_ops_edge(conf_.post_ops, prev_dst, off_po,
632+
matmul_dim_1, this, remaining_m, remaining_n);
633+
} else {
634+
dst_block.apply_post_ops(
635+
conf_.post_ops, prev_dst, off_po, matmul_dim_1, this);
636+
}
602637

603638
if (conf_.do_scale_dst) {
604639
dst_block.eltwise([=](float &el) { el /= dst_scale; });

src/gpu/generic/sycl/pooling_kernels.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct pooling_fwd_kernel_vec_t {
4242
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
4343
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
4444
, ws_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_WORKSPACE))
45-
, po_args_(cgh, ctx) {}
45+
, po_args_(cgh, ctx, conf_.post_ops) {}
4646

4747
void operator()(::sycl::nd_item<1> item) const {
4848
memory_tensor_t src_mem(src_, conf_.src_md);

src/gpu/generic/sycl/ref_binary.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ status_t ref_binary_t::pd_t::init_conf() {
4949
= conf_.src0_md.dims()[i] != 1 && conf_.src1_md.dims()[i] == 1;
5050
}
5151

52-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
52+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
5353

5454
return status::success;
5555
}

src/gpu/generic/sycl/ref_convolution.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ status_t ref_convolution_fwd_t::pd_t::init_conf() {
5151
conf_.single_data_zeropoint = attr()->zero_points_.common(DNNL_ARG_SRC_0);
5252
conf_.single_dst_zeropoint = attr()->zero_points_.common(DNNL_ARG_DST);
5353

54-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
54+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
5555

5656
conf_.padding[0] = static_cast<int>(desc()->padding[0][0]);
5757
conf_.padding[1] = static_cast<int>(desc()->padding[0][1]);
@@ -111,7 +111,7 @@ status_t ref_convolution_bwd_data_t::pd_t::init_conf() {
111111
conf_.single_data_zeropoint = attr()->zero_points_.common(DNNL_ARG_SRC_0);
112112
conf_.single_dst_zeropoint = attr()->zero_points_.common(DNNL_ARG_DST);
113113

114-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
114+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
115115

116116
conf_.padding[0] = static_cast<int>(desc()->padding[0][0]);
117117
conf_.padding[1] = static_cast<int>(desc()->padding[0][1]);
@@ -173,7 +173,7 @@ status_t ref_convolution_bwd_weights_t::pd_t::init_conf() {
173173
conf_.single_data_zeropoint = attr()->zero_points_.common(DNNL_ARG_SRC_0);
174174
conf_.single_dst_zeropoint = attr()->zero_points_.common(DNNL_ARG_DST);
175175

176-
conf_.post_ops = sycl_post_ops_t(attr());
176+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
177177

178178
conf_.padding[0] = static_cast<int>(desc()->padding[0][0]);
179179
conf_.padding[1] = static_cast<int>(desc()->padding[0][1]);

src/gpu/generic/sycl/ref_eltwise.cpp

+1-5
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@ status_t ref_sycl_eltwise_fwd_t::pd_t::init_conf() {
3838
conf_.h = H();
3939
conf_.w = W();
4040

41-
if (attr()->post_ops_.len() > sycl_post_ops_t::max_post_ops) {
42-
return status::unimplemented;
43-
}
44-
conf_.post_po_len = attr()->post_ops_.len();
45-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
41+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
4642

4743
return status::success;
4844
}

src/gpu/generic/sycl/ref_matmul.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ void ref_matmul_t::pd_t::init_conf() {
4343
= !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
4444

4545
conf_.use_dropout = !attr()->dropout_.has_default_values();
46-
47-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
46+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
4847

4948
memory_desc_wrapper src_d = src_md();
5049
memory_desc_wrapper weights_d = weights_md();

src/gpu/generic/sycl/ref_matmul.hpp

+2-9
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
5959
| sm::zero_points_runtime_data_type)
6060
&& IMPLICATION(
6161
!attr()->scales_.has_default_values(), scales_ok())
62-
&& post_ops_ok() && md_dims_in_range(src_md())
62+
&& sycl_post_ops_t::post_ops_ok(attr())
63+
&& md_dims_in_range(src_md())
6364
&& md_dims_in_range(weights_md());
6465
if (!ok) return status::unimplemented;
6566

@@ -121,14 +122,6 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
121122
return dt_ok && attr_scales_ok(supported_args);
122123
}
123124

124-
bool post_ops_ok() const {
125-
// Dw conv post-ops are not supported.
126-
return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops
127-
&& attr()->post_ops_.has_default_values(
128-
{primitive_kind::eltwise, primitive_kind::binary,
129-
primitive_kind::sum});
130-
}
131-
132125
static bool check_data_types(const memory_desc_wrapper &src,
133126
const memory_desc_wrapper &weights,
134127
const memory_desc_wrapper &dst) {

src/gpu/generic/sycl/ref_pooling.cpp

+1-7
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,7 @@ status_t ref_pooling_fwd_t::pd_t::init_conf() {
6464
conf_.DH = KDH(); //K:kernel D:Dilation H:Height
6565
conf_.DW = KDW(); //K:kernel D:Dilation W:Width
6666

67-
const auto *att = attr();
68-
const auto &attr_po = att->post_ops_;
69-
if (attr_po.len() > sycl_post_ops_t::max_post_ops) {
70-
return dnnl_unimplemented;
71-
}
72-
conf_.po_len = attr_po.len();
73-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
67+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
7468
return status::success;
7569
}
7670

src/gpu/generic/sycl/ref_reorder.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ status_t ref_reorder_t::pd_t::init_conf() {
3838
conf_.do_scale_dst
3939
= !attr()->scales_.get(DNNL_ARG_DST).has_default_values();
4040
conf_.scale_dst_mask = attr()->scales_.get(DNNL_ARG_DST).mask_;
41-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
41+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
4242

4343
return status::success;
4444
}

src/gpu/generic/sycl/ref_resampling.cpp

+1-7
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,8 @@ status_t ref_resampling_fwd_t::pd_t::init_conf() {
4343
conf_.dst_md = xpu::sycl::md_t(dst_md());
4444

4545
conf_.alg = desc()->alg_kind;
46-
const auto *att = attr();
47-
const auto &attr_po = att->post_ops_;
48-
if (attr_po.len() > sycl_post_ops_t::max_post_ops) {
49-
return dnnl_unimplemented;
50-
}
51-
conf_.po_len = attr_po.len();
5246

53-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
47+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
5448
return status::success;
5549
}
5650

src/gpu/generic/sycl/ref_softmax.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ status_t ref_sycl_softmax_fwd_t::pd_t::init_conf() {
4141
conf_.do_scale_dst
4242
= !attr()->scales_.get(DNNL_ARG_DST).has_default_values();
4343

44-
conf_.post_ops = sycl_post_ops_t(attr(), dst_md()->data_type);
45-
conf_.po_len = attr()->post_ops_.len();
44+
conf_.post_ops = sycl_post_ops_t(attr(), dst_md());
4645

4746
return status::success;
4847
}

src/gpu/generic/sycl/resampling_kernels.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct resampling_kernel_fwd_vec_t {
3939
: conf_(conf)
4040
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
4141
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
42-
, po_args_(cgh, ctx) {}
42+
, po_args_(cgh, ctx, conf_.post_ops) {}
4343

4444
void operator()(::sycl::nd_item<1> item) const {
4545
memory_tensor_t src_mem(src_, conf_.src_md);

src/gpu/generic/sycl/softmax_kernels.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct softmax_fwd_kernel_vec_t {
4242
, scale_dst_(CTX_IN_SYCL_KERNEL_MEMORY(
4343
DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST))
4444
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
45-
, po_args_(cgh, ctx) {}
45+
, po_args_(cgh, ctx, conf_.post_ops) {}
4646

4747
void operator()(::sycl::nd_item<1> item) const {
4848
memory_tensor_t src_mem(src_, conf_.src_md);

0 commit comments

Comments
 (0)