Skip to content

Commit f5b0373

Browse files
committed
generic: sycl: refactor how post op arguments are loaded
1 parent 7cfab18 commit f5b0373

14 files changed: +186 / -504 lines changed

src/gpu/generic/sycl/binary_kernels.hpp

+11-93
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,7 @@ struct binary_kernel_vec_t {
4848
| DNNL_ARG_SRC_0)
4949
.data_type()
5050
: data_type_t::dnnl_f32)
51-
, po1_src_(CTX_IN_SYCL_KERNEL_MEMORY(
52-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1)))
53-
, po2_src_(CTX_IN_SYCL_KERNEL_MEMORY(
54-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1)))
55-
, po3_src_(CTX_IN_SYCL_KERNEL_MEMORY(
56-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1)))
57-
, po4_src_(CTX_IN_SYCL_KERNEL_MEMORY(
58-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1)))
59-
, po5_src_(CTX_IN_SYCL_KERNEL_MEMORY(
60-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(4) | DNNL_ARG_SRC_1))) {}
51+
, po_args_(cgh, ctx) {}
6152

6253
void operator()(::sycl::nd_item<1> item) const {
6354
memory_tensor_t src0_mem(src0_, conf_.src0_md);
@@ -85,18 +76,19 @@ struct binary_kernel_vec_t {
8576
bool any_broadcast = false;
8677
bool is_same_tag = true;
8778
for (int i = 0; i < max_supported_ndims; i++) {
88-
if (i < dst_md().ndims()) {
89-
dims[i] = dst_md().dims()[i];
90-
strides[i] = dst_md().strides()[i];
79+
if (i < dst_mem.md().ndims()) {
80+
dims[i] = dst_mem.md().dims()[i];
81+
strides[i] = dst_mem.md().strides()[i];
9182
any_broadcast |= conf_.broadcast_dims0[i];
9283
any_broadcast |= conf_.broadcast_dims1[i];
9384
} else {
9485
dims[i] = 1;
9586
strides[i] = INT_MAX;
9687
}
97-
if (i < src0_md().ndims()) {
88+
if (i < src0_mem.md().ndims()) {
9889
is_same_tag = is_same_tag
99-
&& (src0_md().strides()[i] == src1_md().strides()[i]);
90+
&& (src0_mem.md().strides()[i]
91+
== src1_mem.md().strides()[i]);
10092
}
10193
}
10294
if (!any_broadcast && conf_.post_ops.get_post_op() == 0
@@ -106,7 +98,6 @@ struct binary_kernel_vec_t {
10698
for (int i = 0; i < conf_.block_size / vec_len; i++) {
10799
auto src0_vec = src0_mem.load_vec<vec_len>(vec_base_idx + i);
108100
auto src1_vec = src1_mem.load_vec<vec_len>(vec_base_idx + i);
109-
auto dst_vec = dst_mem.load_vec<vec_len>(vec_base_idx + i);
110101

111102
if (conf_.do_scale_src0)
112103
src0_vec *= ::sycl::vec<float, vec_len>(sm_0);
@@ -117,7 +108,6 @@ struct binary_kernel_vec_t {
117108
// TODO: Adding post-ops seems to be interfering with compiler's
118109
// optimizations. Figure out how to make the compiler generate
119110
// the right code.
120-
acc_vec = conf_.post_ops.apply(acc_vec, dst_vec);
121111
dst_mem.store_vec(acc_vec, vec_base_idx + i);
122112
}
123113
} else {
@@ -135,89 +125,21 @@ struct binary_kernel_vec_t {
135125

136126
auto src0 = src0_mem.load_md(off0);
137127
auto src1 = src1_mem.load_md(off1);
138-
auto dst = dst_mem.load(idx);
139128

140129
if (conf_.do_scale_src0) src0 *= sm_0;
141130
if (conf_.do_scale_src1) src1 *= sm_1;
142131

143132
auto acc = compute_alg_n(src0, src1, conf_.alg_kind);
144-
::sycl::vec<float, 16> post_po_sr
145-
= post_op_src_val(off_dst);
146-
acc = conf_.post_ops.apply(acc, dst, post_po_sr);
133+
134+
acc = conf_.post_ops.apply(
135+
acc, dst_, idx, po_args_, off_dst);
147136
dst_mem.store(acc, idx);
148137
}
149138
}
150139
}
151140
}
152141

153142
private:
154-
const xpu::sycl::md_t &src0_md() const { return conf_.src0_md; }
155-
const xpu::sycl::md_t &src1_md() const { return conf_.src1_md; }
156-
const xpu::sycl::md_t &dst_md() const { return conf_.dst_md; }
157-
158-
inline ::sycl::vec<float, 16> post_op_src_val(dims_t data_off) const {
159-
::sycl::vec<float, 16> post_po_sr;
160-
const auto maxPostPo = conf_.post_ops.get_post_op();
161-
162-
for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) {
163-
float res = 0.0f;
164-
if (po_idx == 0)
165-
res = get_post_op_val(po1_src_, po_idx, data_off);
166-
else if (po_idx == 1)
167-
res = get_post_op_val(po2_src_, po_idx, data_off);
168-
else if (po_idx == 2)
169-
res = get_post_op_val(po3_src_, po_idx, data_off);
170-
else if (po_idx == 3)
171-
res = get_post_op_val(po4_src_, po_idx, data_off);
172-
else if (po_idx == 4)
173-
res = get_post_op_val(po5_src_, po_idx, data_off);
174-
175-
post_po_sr[po_idx] = res;
176-
}
177-
return post_po_sr;
178-
}
179-
180-
float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op,
181-
dim_t &idx, dims_t offset) const {
182-
auto src1_desc = conf_.binary_src_arr[idx];
183-
184-
const auto off = get_binary_src1_off(
185-
src1_desc, offset, dst_md().dims(), dst_md().ndims());
186-
187-
auto dst = load_float_value(
188-
src1_desc.data_type(), bin_src_op.get_pointer(), off);
189-
return dst;
190-
}
191-
192-
dim_t get_binary_src1_off(const xpu::sycl::md_t &src1_md, dims_t offset,
193-
const xpu::sycl::md_t::dims32_t &dst_dims,
194-
const xpu::sycl::md_t::dim32_t &dst_ndims) const {
195-
const dim_t mask_binary_po
196-
= get_dims_mask(dst_dims, src1_md.dims(), dst_ndims);
197-
return get_po_tensor_off(
198-
src1_md, offset, dst_dims, dst_ndims, mask_binary_po);
199-
}
200-
201-
inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1,
202-
const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims,
203-
bool skip_dim_of_one = false) const {
204-
dim_t mask = 0;
205-
for (dim_t d = 0; d < ndims; ++d) {
206-
// Disable mask_bit for dimensions of `1` by request.
207-
dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d);
208-
mask += dims1[d] == dims2[d] ? mask_bit : 0;
209-
}
210-
return mask;
211-
}
212-
213-
inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md,
214-
dims_t offset, const xpu::sycl::md_t::dims32_t &dst_dims,
215-
const dim_t &dst_ndims, const dim_t &mask) const {
216-
dims_t offset_po {};
217-
utils::copy_dims_with_mask(offset_po, offset, dst_ndims, mask);
218-
return tensor_md.off_v(offset_po);
219-
}
220-
221143
template <int width>
222144
::sycl::vec<float, width> compute_alg(::sycl::vec<float, width> src0,
223145
::sycl::vec<float, width> src1, alg_kind_t alg) const {
@@ -271,11 +193,7 @@ struct binary_kernel_vec_t {
271193
xpu::sycl::in_memory_arg_t src0_scale_;
272194
xpu::sycl::in_memory_arg_t src1_scale_;
273195
data_type_t scales_dt_;
274-
xpu::sycl::in_memory_arg_t po1_src_;
275-
xpu::sycl::in_memory_arg_t po2_src_;
276-
xpu::sycl::in_memory_arg_t po3_src_;
277-
xpu::sycl::in_memory_arg_t po4_src_;
278-
xpu::sycl::in_memory_arg_t po5_src_;
196+
post_op_input_args po_args_;
279197
};
280198

281199
} // namespace sycl

src/gpu/generic/sycl/convolution_kernels.hpp

+2-12
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,7 @@ struct convolution_kernel_fwd_t {
208208
accumulator += bias;
209209
}
210210

211-
auto dst = load_float_value(
212-
conf_.post_ops.sum_dt_ == dnnl_data_type_undef
213-
? dst_md().data_type()
214-
: conf_.post_ops.sum_dt_,
215-
dst_ptr(), idx);
216-
accumulator = conf_.post_ops.apply(accumulator, dst);
211+
accumulator = conf_.post_ops.apply(accumulator, dst_, idx);
217212

218213
if (conf_.do_scale_dst) { accumulator /= sm_dst; }
219214
if (conf_.use_dst_zeropoints) {
@@ -446,12 +441,7 @@ struct convolution_kernel_bwd_data_t {
446441
accumulator += bias;
447442
}
448443

449-
auto diff_data = load_float_value(
450-
conf_.post_ops.sum_dt_ == dnnl_data_type_undef
451-
? diff_data_md().data_type()
452-
: conf_.post_ops.sum_dt_,
453-
diff_data_ptr(), idx);
454-
accumulator = conf_.post_ops.apply(accumulator, diff_data);
444+
accumulator = conf_.post_ops.apply(accumulator, diff_data_, idx);
455445

456446
if (conf_.do_scale_dst) { accumulator /= sm_dst; }
457447
if (conf_.use_dst_zeropoints) {

src/gpu/generic/sycl/eltwise_kernels.hpp

+15-119
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,7 @@ struct eltwise_fwd_kernel_vec_t {
3636
::sycl::handler &cgh, const exec_ctx_t &ctx)
3737
: conf_(conf)
3838
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
39-
, srcOp1_(CTX_IN_SYCL_KERNEL_MEMORY(
40-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1)))
41-
, srcOp2_(CTX_IN_SYCL_KERNEL_MEMORY(
42-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1)))
43-
, srcOp3_(CTX_IN_SYCL_KERNEL_MEMORY(
44-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1)))
45-
, srcOp4_(CTX_IN_SYCL_KERNEL_MEMORY(
46-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1)))
47-
, srcOp5_(CTX_IN_SYCL_KERNEL_MEMORY(
48-
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(4) | DNNL_ARG_SRC_1)))
39+
, po_args_(cgh, ctx)
4940
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST)) {}
5041

5142
void operator()(::sycl::nd_item<1> item) const {
@@ -62,20 +53,23 @@ struct eltwise_fwd_kernel_vec_t {
6253

6354
auto operation = [&](dim_t &idx, dim_t &n, dim_t &c, dim_t &d, dim_t &h,
6455
dim_t &w) {
65-
dim_t src_offset = data_offset(src_md(), n, c, d, h, w);
66-
56+
dim_t src_offset = data_offset(src_mem.md(), n, c, d, h, w);
6757
auto src = src_mem.load(src_offset);
68-
auto dst = dst_mem.load(src_offset);
6958

70-
dim_t data_l_off = (((n * conf_.c + c) * conf_.d + d) * conf_.h + h)
71-
* conf_.w
72-
+ w;
59+
float acc = compute_alg_n(
60+
src, conf_.alpha, conf_.beta, conf_.alg_kind);
7361

74-
::sycl::vec<float, 16> post_po_sr = post_op_src_val(data_l_off);
62+
dims_t po_off {n, c, d, h, w};
63+
switch (src_mem.md().ndims()) {
64+
case 3: po_off[2] = w; break;
65+
case 4:
66+
po_off[2] = h;
67+
po_off[3] = w;
68+
break;
69+
}
70+
acc = conf_.post_ops.apply(acc, dst_, src_offset, po_args_, po_off);
7571

76-
dst = compute_alg_n(src, conf_.alpha, conf_.beta, conf_.alg_kind);
77-
dst = conf_.post_ops.apply(dst, post_po_sr);
78-
dst_mem.store(dst, src_offset);
72+
dst_mem.store(acc, src_offset);
7973
};
8074

8175
for (dim_t blk_idx = 0; blk_idx < conf_.block_size; blk_idx++) {
@@ -98,9 +92,6 @@ struct eltwise_fwd_kernel_vec_t {
9892
}
9993

10094
private:
101-
const xpu::sycl::md_t &src_md() const { return conf_.src_md; }
102-
const xpu::sycl::md_t &dst_md() const { return conf_.dst_md; }
103-
10495
float compute_alg_n(const float &s, const float &alpha, const float &beta,
10596
const alg_kind_t &alg) const {
10697
switch (alg) {
@@ -196,28 +187,6 @@ struct eltwise_fwd_kernel_vec_t {
196187
}
197188
}
198189

199-
inline ::sycl::vec<float, 16> post_op_src_val(dim_t &data_l_off) const {
200-
::sycl::vec<float, 16> post_po_sr;
201-
const auto maxPostPo = conf_.post_po_len;
202-
203-
for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) {
204-
float res = 0.0f;
205-
if (po_idx == 0)
206-
res = get_post_op_val(srcOp1_, po_idx, data_l_off);
207-
else if (po_idx == 1)
208-
res = get_post_op_val(srcOp2_, po_idx, data_l_off);
209-
else if (po_idx == 2)
210-
res = get_post_op_val(srcOp3_, po_idx, data_l_off);
211-
else if (po_idx == 3)
212-
res = get_post_op_val(srcOp4_, po_idx, data_l_off);
213-
else if (po_idx == 4)
214-
res = get_post_op_val(srcOp5_, po_idx, data_l_off);
215-
216-
post_po_sr[po_idx] = res;
217-
}
218-
return post_po_sr;
219-
}
220-
221190
inline dim_t data_offset(const xpu::sycl::md_t &mem, dim_t &n, dim_t &c,
222191
dim_t &d, dim_t &h, dim_t &w) const {
223192
const auto ndims = mem.ndims();
@@ -232,78 +201,9 @@ struct eltwise_fwd_kernel_vec_t {
232201
return -1;
233202
}
234203

235-
float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op,
236-
dim_t &idx, dim_t &offset) const {
237-
auto src1_desc = conf_.binary_src_arr[idx];
238-
239-
const auto off = get_binary_src1_off(
240-
src1_desc, offset, dst_md().dims(), dst_md().ndims());
241-
242-
auto dst = load_float_value(
243-
src1_desc.data_type(), bin_src_op.get_pointer(), off);
244-
return dst;
245-
}
246-
247-
dim_t get_binary_src1_off(const xpu::sycl::md_t &src1_md,
248-
const dim_t &l_offset, const xpu::sycl::md_t::dims32_t &dst_dims,
249-
const xpu::sycl::md_t::dim32_t &dst_ndims) const {
250-
const dim_t mask_binary_po
251-
= get_dims_mask(dst_dims, src1_md.dims(), dst_ndims);
252-
return get_po_tensor_off(
253-
src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po);
254-
}
255-
256-
inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1,
257-
const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims,
258-
bool skip_dim_of_one = false) const {
259-
dim_t mask = 0;
260-
for (dim_t d = 0; d < ndims; ++d) {
261-
// Disable mask_bit for dimensions of `1` by request.
262-
dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d);
263-
mask += dims1[d] == dims2[d] ? mask_bit : 0;
264-
}
265-
return mask;
266-
}
267-
268-
inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md,
269-
const dim_t &l_offset, const xpu::sycl::md_t::dims32_t &dst_dims,
270-
const dim_t &dst_ndims, const dim_t &mask) const {
271-
dims_t l_dims_po {};
272-
get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask);
273-
274-
return tensor_md.off_v(l_dims_po);
275-
}
276-
277-
inline void get_l_dims_po(dims_t l_dims_po, dim_t l_offset,
278-
const xpu::sycl::md_t::dims32_t &dst_dims, const dim_t &dst_ndims,
279-
const dim_t &mask) const {
280-
281-
l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims);
282-
utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask);
283-
}
284-
285-
inline void l_dims_by_l_offset(dims_t dims_pos, dim_t l_offset,
286-
const xpu::sycl::md_t::dims32_t &dims, const dim_t &ndims) const {
287-
for (dim_t rd = 0; rd < ndims; ++rd) {
288-
const dim_t d = ndims - 1 - rd;
289-
/* switch to faster 32-bit division when possible. */
290-
if (l_offset <= INT32_MAX && dims[d] <= INT32_MAX) {
291-
dims_pos[d] = (int32_t)l_offset % (int32_t)dims[d];
292-
l_offset = (int32_t)l_offset / (int32_t)dims[d];
293-
} else {
294-
dims_pos[d] = l_offset % dims[d];
295-
l_offset /= dims[d];
296-
}
297-
}
298-
}
299-
300204
sycl_eltwise_conf_t conf_;
301205
xpu::sycl::in_memory_arg_t src_;
302-
xpu::sycl::in_memory_arg_t srcOp1_;
303-
xpu::sycl::in_memory_arg_t srcOp2_;
304-
xpu::sycl::in_memory_arg_t srcOp3_;
305-
xpu::sycl::in_memory_arg_t srcOp4_;
306-
xpu::sycl::in_memory_arg_t srcOp5_;
206+
post_op_input_args po_args_;
307207
xpu::sycl::out_memory_arg_t dst_;
308208
};
309209

@@ -342,10 +242,6 @@ struct eltwise_bwd_kernel_vec_t {
342242
}
343243

344244
private:
345-
const xpu::sycl::md_t &src_md() const { return conf_.src_md; }
346-
const xpu::sycl::md_t &diff_src_md() const { return conf_.diff_src_md; }
347-
const xpu::sycl::md_t &diff_dst_md() const { return conf_.diff_dst_md; }
348-
349245
inline float compute_alg_n(const float &dd, const float &s,
350246
const float &alpha, const float &beta,
351247
const alg_kind_t &alg) const {

0 commit comments

Comments
 (0)