Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 170bfaf

Browse files
committedMay 9, 2024·
gpu: sycl: binary: add support for remaining post ops
1 parent 7bdb0e1 commit 170bfaf

File tree

4 files changed

+137
-17
lines changed

4 files changed

+137
-17
lines changed
 

‎src/gpu/sycl/binary_kernels.hpp

+107-4
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,24 @@ struct binary_kernel_vec_t {
3636
xpu::sycl::in_memory_arg_t &src0, xpu::sycl::in_memory_arg_t &src1,
3737
xpu::sycl::out_memory_arg_t &dst,
3838
xpu::sycl::in_memory_arg_t &src0_scale,
39-
xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt)
39+
xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt,
40+
xpu::sycl::in_memory_arg_t &po1_src,
41+
xpu::sycl::in_memory_arg_t &po2_src,
42+
xpu::sycl::in_memory_arg_t &po3_src,
43+
xpu::sycl::in_memory_arg_t &po4_src,
44+
xpu::sycl::in_memory_arg_t &po5_src)
4045
: conf_(conf)
4146
, src0_(src0)
4247
, src1_(src1)
4348
, dst_(dst)
4449
, src0_scale_(src0_scale)
4550
, src1_scale_(src1_scale)
46-
, scales_dt_(scales_dt) {}
51+
, scales_dt_(scales_dt)
52+
, po1_src_(po1_src)
53+
, po2_src_(po2_src)
54+
, po3_src_(po3_src)
55+
, po4_src_(po4_src)
56+
, po5_src_(po5_src) {}
4757

4858
void operator()(::sycl::nd_item<1> item) const {
4959
auto sg = item.get_sub_group();
@@ -73,7 +83,7 @@ struct binary_kernel_vec_t {
7383
any_broadcast |= conf_.broadcast_dims[i];
7484
}
7585
}
76-
if (!any_broadcast
86+
if (!any_broadcast && conf_.post_ops.get_post_op() == 0
7787
&& sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
7888
< conf_.wk_size) {
7989
for (int i = 0; i < conf_.block_size / vec_len; i++) {
@@ -123,7 +133,8 @@ struct binary_kernel_vec_t {
123133
if (conf_.do_scale_src1) src1 *= sm_1;
124134

125135
auto acc = compute_alg_n(src0, src1, conf_.alg_kind);
126-
acc = conf_.post_ops.apply(acc, dst);
136+
::sycl::vec<float, 16> post_po_sr = post_op_src_val(idx);
137+
acc = conf_.post_ops.apply(acc, dst, post_po_sr);
127138
store_float_value(
128139
dst_md().data_type(), acc, dst_ptr(), idx);
129140
}
@@ -146,6 +157,93 @@ struct binary_kernel_vec_t {
146157
return static_cast<float *>(src1_scale_.get_pointer());
147158
}
148159

160+
inline ::sycl::vec<float, 16> post_op_src_val(dim_t data_l_off) const {
161+
::sycl::vec<float, 16> post_po_sr;
162+
const auto maxPostPo = conf_.post_ops.get_post_op();
163+
164+
for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) {
165+
float res = 0.0f;
166+
if (po_idx == 0)
167+
res = get_post_op_val(po1_src_, po_idx, data_l_off);
168+
else if (po_idx == 1)
169+
res = get_post_op_val(po2_src_, po_idx, data_l_off);
170+
else if (po_idx == 2)
171+
res = get_post_op_val(po3_src_, po_idx, data_l_off);
172+
else if (po_idx == 3)
173+
res = get_post_op_val(po4_src_, po_idx, data_l_off);
174+
else if (po_idx == 4)
175+
res = get_post_op_val(po5_src_, po_idx, data_l_off);
176+
177+
post_po_sr[po_idx] = res;
178+
}
179+
return post_po_sr;
180+
}
181+
182+
float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op,
183+
dim_t &idx, dim_t offset) const {
184+
auto src1_desc = conf_.binary_src_arr[idx];
185+
186+
const auto off = get_binary_src1_off(
187+
src1_desc, offset, dst_md().dims(), dst_md().ndims());
188+
189+
auto dst = load_float_value(
190+
src1_desc.data_type(), bin_src_op.get_pointer(), off);
191+
return dst;
192+
}
193+
194+
dim_t get_binary_src1_off(const xpu::sycl::md_t &src1_md, dim_t l_offset,
195+
const xpu::sycl::md_t::dims32_t &dst_dims,
196+
const xpu::sycl::md_t::dim32_t &dst_ndims) const {
197+
const dim_t mask_binary_po
198+
= get_dims_mask(dst_dims, src1_md.dims(), dst_ndims);
199+
return get_po_tensor_off(
200+
src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po);
201+
}
202+
203+
inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1,
204+
const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims,
205+
bool skip_dim_of_one = false) const {
206+
dim_t mask = 0;
207+
for (dim_t d = 0; d < ndims; ++d) {
208+
// Disable mask_bit for dimensions of `1` by request.
209+
dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d);
210+
mask += dims1[d] == dims2[d] ? mask_bit : 0;
211+
}
212+
return mask;
213+
}
214+
215+
inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md,
216+
dim_t l_offset, const xpu::sycl::md_t::dims32_t &dst_dims,
217+
const dim_t &dst_ndims, const dim_t &mask) const {
218+
dims_t l_dims_po {};
219+
get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask);
220+
221+
return tensor_md.off_v(l_dims_po);
222+
}
223+
224+
inline void get_l_dims_po(dims_t l_dims_po, dim_t l_offset,
225+
const xpu::sycl::md_t::dims32_t &dst_dims, const dim_t &dst_ndims,
226+
const dim_t &mask) const {
227+
228+
l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims);
229+
utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask);
230+
}
231+
232+
inline void l_dims_by_l_offset(dims_t dims_pos, dim_t l_offset,
233+
const xpu::sycl::md_t::dims32_t &dims, const dim_t &ndims) const {
234+
for (dim_t rd = 0; rd < ndims; ++rd) {
235+
const dim_t d = ndims - 1 - rd;
236+
/* switch to faster 32-bit division when possible. */
237+
if (l_offset <= INT32_MAX && dims[d] <= INT32_MAX) {
238+
dims_pos[d] = (int32_t)l_offset % (int32_t)dims[d];
239+
l_offset = (int32_t)l_offset / (int32_t)dims[d];
240+
} else {
241+
dims_pos[d] = l_offset % dims[d];
242+
l_offset /= dims[d];
243+
}
244+
}
245+
}
246+
149247
template <int width>
150248
::sycl::vec<float, width> compute_alg(::sycl::vec<float, width> src0,
151249
::sycl::vec<float, width> src1, alg_kind_t alg) const {
@@ -199,6 +297,11 @@ struct binary_kernel_vec_t {
199297
xpu::sycl::in_memory_arg_t src0_scale_;
200298
xpu::sycl::in_memory_arg_t src1_scale_;
201299
data_type_t scales_dt_;
300+
xpu::sycl::in_memory_arg_t po1_src_;
301+
xpu::sycl::in_memory_arg_t po2_src_;
302+
xpu::sycl::in_memory_arg_t po3_src_;
303+
xpu::sycl::in_memory_arg_t po4_src_;
304+
xpu::sycl::in_memory_arg_t po5_src_;
202305
};
203306

204307
} // namespace sycl

‎src/gpu/sycl/ref_binary.cpp

+21-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ status_t ref_binary_t::pd_t::init_conf() {
5252

5353
conf_.post_ops = sycl_post_ops_t(attr());
5454

55+
for (auto i = 0; i < conf_.post_ops.get_post_op(); ++i) {
56+
const auto &e = attr()->post_ops_.entry_[i];
57+
if (e.is_binary() || e.is_prelu()) {
58+
conf_.binary_src_arr[i] = xpu::sycl::md_t(
59+
arg_md(DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1));
60+
}
61+
}
5562
return status::success;
5663
}
5764

@@ -62,6 +69,7 @@ status_t ref_binary_t::init(engine_t *engine) {
6269
}
6370

6471
status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {
72+
6573
parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
6674
auto src0_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0);
6775
auto src1_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_1);
@@ -76,9 +84,21 @@ status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {
7684
.data_type()
7785
: data_type_t::dnnl_f32;
7886

87+
auto src_mem_po_1 = CTX_IN_SYCL_KERNEL_MEMORY(
88+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1));
89+
auto src_mem_po_2 = CTX_IN_SYCL_KERNEL_MEMORY(
90+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1));
91+
auto src_mem_po_3 = CTX_IN_SYCL_KERNEL_MEMORY(
92+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1));
93+
auto src_mem_po_4 = CTX_IN_SYCL_KERNEL_MEMORY(
94+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1));
95+
auto src_mem_po_5 = CTX_IN_SYCL_KERNEL_MEMORY(
96+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(4) | DNNL_ARG_SRC_1));
97+
7998
binary_kernel_vec_t binary_kernel(pd()->conf_, src0_mem_arg,
8099
src1_mem_arg, dst_mem_arg, src0_scale_mem_arg,
81-
src1_scale_mem_arg, scales_dt);
100+
src1_scale_mem_arg, scales_dt, src_mem_po_1, src_mem_po_2,
101+
src_mem_po_3, src_mem_po_4, src_mem_po_5);
82102

83103
const int block_size = pd()->conf_.block_size;
84104
const int wg_size = pd()->conf_.wg_size;

‎src/gpu/sycl/ref_binary.hpp

+7-12
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
4848
const memory_desc_wrapper dst_d(dst_md());
4949

5050
const bool ok = set_default_params() == status::success
51+
&& attr_.set_default_formats(dst_md()) == status::success
5152
&& check_data_types(src0_d, src1_d, dst_d)
5253
&& check_formats(src0_d, src1_d, dst_d)
5354
&& attr()->has_default_values(
@@ -72,18 +73,12 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
7273
}
7374

7475
bool post_ops_ok() const {
75-
for (int i = 0; i < attr()->post_ops_.len(); i++) {
76-
const auto &e = attr()->post_ops_.entry_[i];
77-
if (!IMPLICATION(e.is_eltwise(),
78-
utils::one_of(e.eltwise.alg, alg_kind::eltwise_relu,
79-
alg_kind::eltwise_linear))) {
80-
return false;
81-
}
82-
}
83-
// Binary, prelu and dw conv post-ops are not supported.
76+
// Dw conv post-ops are not supported.
8477
return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops
8578
&& attr()->post_ops_.has_default_values(
86-
{primitive_kind::eltwise});
79+
{primitive_kind::eltwise, primitive_kind::binary,
80+
primitive_kind::prelu,
81+
primitive_kind::sum});
8782
}
8883

8984
static bool check_data_types(const memory_desc_wrapper &src0,
@@ -100,7 +95,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
10095
}
10196

10297
return IMPLICATION(utils::one_of(bf16, src0_dt, src1_dt, dst_dt),
103-
src0_dt == src1_dt == dst_dt);
98+
src0_dt == dst_dt && src1_dt == dst_dt);
10499
}
105100

106101
static bool check_formats(const memory_desc_wrapper &src0,
@@ -109,7 +104,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
109104
using namespace format_tag;
110105

111106
for (const auto &mdw : {src0, src1, dst}) {
112-
if (mdw.matches_one_of_tag(ab, abc, abcd, abcde) == undef) {
107+
if (mdw.matches_one_of_tag(a, ab, abc, abcd, abcde) == undef) {
113108
return false;
114109
}
115110
}

‎src/gpu/sycl/sycl_primitive_conf.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ struct sycl_binary_conf_t {
4444
int wg_size;
4545
int wk_size;
4646

47+
xpu::sycl::md_t binary_src_arr[8];
48+
4749
sycl_post_ops_t post_ops;
4850
};
4951

0 commit comments

Comments
 (0)
Please sign in to comment.