@@ -36,16 +36,7 @@ struct eltwise_fwd_kernel_vec_t {
36
36
::sycl::handler &cgh, const exec_ctx_t &ctx)
37
37
: conf_(conf)
38
38
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
39
- , srcOp1_(CTX_IN_SYCL_KERNEL_MEMORY(
40
- (DNNL_ARG_ATTR_MULTIPLE_POST_OP(0 ) | DNNL_ARG_SRC_1)))
41
- , srcOp2_(CTX_IN_SYCL_KERNEL_MEMORY(
42
- (DNNL_ARG_ATTR_MULTIPLE_POST_OP(1 ) | DNNL_ARG_SRC_1)))
43
- , srcOp3_(CTX_IN_SYCL_KERNEL_MEMORY(
44
- (DNNL_ARG_ATTR_MULTIPLE_POST_OP(2 ) | DNNL_ARG_SRC_1)))
45
- , srcOp4_(CTX_IN_SYCL_KERNEL_MEMORY(
46
- (DNNL_ARG_ATTR_MULTIPLE_POST_OP(3 ) | DNNL_ARG_SRC_1)))
47
- , srcOp5_(CTX_IN_SYCL_KERNEL_MEMORY(
48
- (DNNL_ARG_ATTR_MULTIPLE_POST_OP(4 ) | DNNL_ARG_SRC_1)))
39
+ , po_args_(cgh, ctx)
49
40
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST)) {}
50
41
51
42
void operator ()(::sycl::nd_item<1 > item) const {
@@ -62,20 +53,23 @@ struct eltwise_fwd_kernel_vec_t {
62
53
63
54
auto operation = [&](dim_t &idx, dim_t &n, dim_t &c, dim_t &d, dim_t &h,
64
55
dim_t &w) {
65
- dim_t src_offset = data_offset (src_md (), n, c, d, h, w);
66
-
56
+ dim_t src_offset = data_offset (src_mem.md (), n, c, d, h, w);
67
57
auto src = src_mem.load (src_offset);
68
- auto dst = dst_mem.load (src_offset);
69
58
70
- dim_t data_l_off = (((n * conf_.c + c) * conf_.d + d) * conf_.h + h)
71
- * conf_.w
72
- + w;
59
+ float acc = compute_alg_n (
60
+ src, conf_.alpha , conf_.beta , conf_.alg_kind );
73
61
74
- ::sycl::vec<float , 16 > post_po_sr = post_op_src_val (data_l_off);
62
+ dims_t po_off {n, c, d, h, w};
63
+ switch (src_mem.md ().ndims ()) {
64
+ case 3 : po_off[2 ] = w; break ;
65
+ case 4 :
66
+ po_off[2 ] = h;
67
+ po_off[3 ] = w;
68
+ break ;
69
+ }
70
+ acc = conf_.post_ops .apply (acc, dst_, src_offset, po_args_, po_off);
75
71
76
- dst = compute_alg_n (src, conf_.alpha , conf_.beta , conf_.alg_kind );
77
- dst = conf_.post_ops .apply (dst, post_po_sr);
78
- dst_mem.store (dst, src_offset);
72
+ dst_mem.store (acc, src_offset);
79
73
};
80
74
81
75
for (dim_t blk_idx = 0 ; blk_idx < conf_.block_size ; blk_idx++) {
@@ -98,9 +92,6 @@ struct eltwise_fwd_kernel_vec_t {
98
92
}
99
93
100
94
private:
101
- const xpu::sycl::md_t &src_md () const { return conf_.src_md ; }
102
- const xpu::sycl::md_t &dst_md () const { return conf_.dst_md ; }
103
-
104
95
float compute_alg_n (const float &s, const float &alpha, const float &beta,
105
96
const alg_kind_t &alg) const {
106
97
switch (alg) {
@@ -196,28 +187,6 @@ struct eltwise_fwd_kernel_vec_t {
196
187
}
197
188
}
198
189
199
- inline ::sycl::vec<float , 16 > post_op_src_val (dim_t &data_l_off) const {
200
- ::sycl::vec<float , 16 > post_po_sr;
201
- const auto maxPostPo = conf_.post_po_len ;
202
-
203
- for (dim_t po_idx = 0 ; po_idx < maxPostPo; po_idx++) {
204
- float res = 0 .0f ;
205
- if (po_idx == 0 )
206
- res = get_post_op_val (srcOp1_, po_idx, data_l_off);
207
- else if (po_idx == 1 )
208
- res = get_post_op_val (srcOp2_, po_idx, data_l_off);
209
- else if (po_idx == 2 )
210
- res = get_post_op_val (srcOp3_, po_idx, data_l_off);
211
- else if (po_idx == 3 )
212
- res = get_post_op_val (srcOp4_, po_idx, data_l_off);
213
- else if (po_idx == 4 )
214
- res = get_post_op_val (srcOp5_, po_idx, data_l_off);
215
-
216
- post_po_sr[po_idx] = res;
217
- }
218
- return post_po_sr;
219
- }
220
-
221
190
inline dim_t data_offset (const xpu::sycl::md_t &mem, dim_t &n, dim_t &c,
222
191
dim_t &d, dim_t &h, dim_t &w) const {
223
192
const auto ndims = mem.ndims ();
@@ -232,78 +201,9 @@ struct eltwise_fwd_kernel_vec_t {
232
201
return -1 ;
233
202
}
234
203
235
- float get_post_op_val (const xpu::sycl::in_memory_arg_t &bin_src_op,
236
- dim_t &idx, dim_t &offset) const {
237
- auto src1_desc = conf_.binary_src_arr [idx];
238
-
239
- const auto off = get_binary_src1_off (
240
- src1_desc, offset, dst_md ().dims (), dst_md ().ndims ());
241
-
242
- auto dst = load_float_value (
243
- src1_desc.data_type (), bin_src_op.get_pointer (), off);
244
- return dst;
245
- }
246
-
247
- dim_t get_binary_src1_off (const xpu::sycl::md_t &src1_md,
248
- const dim_t &l_offset, const xpu::sycl::md_t ::dims32_t &dst_dims,
249
- const xpu::sycl::md_t ::dim32_t &dst_ndims) const {
250
- const dim_t mask_binary_po
251
- = get_dims_mask (dst_dims, src1_md.dims (), dst_ndims);
252
- return get_po_tensor_off (
253
- src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po);
254
- }
255
-
256
- inline dim_t get_dims_mask (const xpu::sycl::md_t ::dims32_t &dims1,
257
- const xpu::sycl::md_t ::dims32_t &dims2, const dim_t &ndims,
258
- bool skip_dim_of_one = false ) const {
259
- dim_t mask = 0 ;
260
- for (dim_t d = 0 ; d < ndims; ++d) {
261
- // Disable mask_bit for dimensions of `1` by request.
262
- dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d);
263
- mask += dims1[d] == dims2[d] ? mask_bit : 0 ;
264
- }
265
- return mask;
266
- }
267
-
268
- inline dim_t get_po_tensor_off (const xpu::sycl::md_t &tensor_md,
269
- const dim_t &l_offset, const xpu::sycl::md_t ::dims32_t &dst_dims,
270
- const dim_t &dst_ndims, const dim_t &mask) const {
271
- dims_t l_dims_po {};
272
- get_l_dims_po (l_dims_po, l_offset, dst_dims, dst_ndims, mask);
273
-
274
- return tensor_md.off_v (l_dims_po);
275
- }
276
-
277
- inline void get_l_dims_po (dims_t l_dims_po, dim_t l_offset,
278
- const xpu::sycl::md_t ::dims32_t &dst_dims, const dim_t &dst_ndims,
279
- const dim_t &mask) const {
280
-
281
- l_dims_by_l_offset (l_dims_po, l_offset, dst_dims, dst_ndims);
282
- utils::apply_mask_on_dims (l_dims_po, dst_ndims, mask);
283
- }
284
-
285
- inline void l_dims_by_l_offset (dims_t dims_pos, dim_t l_offset,
286
- const xpu::sycl::md_t ::dims32_t &dims, const dim_t &ndims) const {
287
- for (dim_t rd = 0 ; rd < ndims; ++rd) {
288
- const dim_t d = ndims - 1 - rd;
289
- /* switch to faster 32-bit division when possible. */
290
- if (l_offset <= INT32_MAX && dims[d] <= INT32_MAX) {
291
- dims_pos[d] = (int32_t )l_offset % (int32_t )dims[d];
292
- l_offset = (int32_t )l_offset / (int32_t )dims[d];
293
- } else {
294
- dims_pos[d] = l_offset % dims[d];
295
- l_offset /= dims[d];
296
- }
297
- }
298
- }
299
-
300
204
sycl_eltwise_conf_t conf_;
301
205
xpu::sycl::in_memory_arg_t src_;
302
- xpu::sycl::in_memory_arg_t srcOp1_;
303
- xpu::sycl::in_memory_arg_t srcOp2_;
304
- xpu::sycl::in_memory_arg_t srcOp3_;
305
- xpu::sycl::in_memory_arg_t srcOp4_;
306
- xpu::sycl::in_memory_arg_t srcOp5_;
206
+ post_op_input_args po_args_;
307
207
xpu::sycl::out_memory_arg_t dst_;
308
208
};
309
209
@@ -342,10 +242,6 @@ struct eltwise_bwd_kernel_vec_t {
342
242
}
343
243
344
244
private:
345
- const xpu::sycl::md_t &src_md () const { return conf_.src_md ; }
346
- const xpu::sycl::md_t &diff_src_md () const { return conf_.diff_src_md ; }
347
- const xpu::sycl::md_t &diff_dst_md () const { return conf_.diff_dst_md ; }
348
-
349
245
inline float compute_alg_n (const float &dd, const float &s,
350
246
const float &alpha, const float &beta,
351
247
const alg_kind_t &alg) const {
0 commit comments