@@ -36,14 +36,24 @@ struct binary_kernel_vec_t {
             xpu::sycl::in_memory_arg_t &src0, xpu::sycl::in_memory_arg_t &src1,
             xpu::sycl::out_memory_arg_t &dst,
             xpu::sycl::in_memory_arg_t &src0_scale,
-            xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt)
+            xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt,
+            xpu::sycl::in_memory_arg_t &po1_src,
+            xpu::sycl::in_memory_arg_t &po2_src,
+            xpu::sycl::in_memory_arg_t &po3_src,
+            xpu::sycl::in_memory_arg_t &po4_src,
+            xpu::sycl::in_memory_arg_t &po5_src)
         : conf_(conf)
         , src0_(src0)
         , src1_(src1)
         , dst_(dst)
         , src0_scale_(src0_scale)
         , src1_scale_(src1_scale)
-        , scales_dt_(scales_dt) {}
+        , scales_dt_(scales_dt)
+        , po1_src_(po1_src)
+        , po2_src_(po2_src)
+        , po3_src_(po3_src)
+        , po4_src_(po4_src)
+        , po5_src_(po5_src) {}
 
     void operator()(::sycl::nd_item<1> item) const {
         auto sg = item.get_sub_group();
@@ -73,7 +83,7 @@ struct binary_kernel_vec_t {
                     any_broadcast |= conf_.broadcast_dims[i];
                 }
             }
-            if (!any_broadcast
+            if (!any_broadcast && conf_.post_ops.get_post_op() == 0
                     && sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
                             < conf_.wk_size) {
                 for (int i = 0; i < conf_.block_size / vec_len; i++) {
@@ -123,7 +133,8 @@ struct binary_kernel_vec_t {
                 if (conf_.do_scale_src1) src1 *= sm_1;
 
                 auto acc = compute_alg_n(src0, src1, conf_.alg_kind);
-                acc = conf_.post_ops.apply(acc, dst);
+                ::sycl::vec<float, 16> post_po_sr = post_op_src_val(idx);
+                acc = conf_.post_ops.apply(acc, dst, post_po_sr);
                 store_float_value(
                         dst_md().data_type(), acc, dst_ptr(), idx);
             }
@@ -146,6 +157,93 @@ struct binary_kernel_vec_t {
         return static_cast<float *>(src1_scale_.get_pointer());
     }
 
+    inline ::sycl::vec<float, 16> post_op_src_val(dim_t data_l_off) const {
+        ::sycl::vec<float, 16> post_po_sr;
+        const auto maxPostPo = conf_.post_ops.get_post_op();
+
+        for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) {
+            float res = 0.0f;
+            if (po_idx == 0)
+                res = get_post_op_val(po1_src_, po_idx, data_l_off);
+            else if (po_idx == 1)
+                res = get_post_op_val(po2_src_, po_idx, data_l_off);
+            else if (po_idx == 2)
+                res = get_post_op_val(po3_src_, po_idx, data_l_off);
+            else if (po_idx == 3)
+                res = get_post_op_val(po4_src_, po_idx, data_l_off);
+            else if (po_idx == 4)
+                res = get_post_op_val(po5_src_, po_idx, data_l_off);
+
+            post_po_sr[po_idx] = res;
+        }
+        return post_po_sr;
+    }
+
+    float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op,
+            dim_t &idx, dim_t offset) const {
+        auto src1_desc = conf_.binary_src_arr[idx];
+
+        const auto off = get_binary_src1_off(
+                src1_desc, offset, dst_md().dims(), dst_md().ndims());
+
+        auto dst = load_float_value(
+                src1_desc.data_type(), bin_src_op.get_pointer(), off);
+        return dst;
+    }
+
+    dim_t get_binary_src1_off(const xpu::sycl::md_t &src1_md, dim_t l_offset,
+            const xpu::sycl::md_t::dims32_t &dst_dims,
+            const xpu::sycl::md_t::dim32_t &dst_ndims) const {
+        const dim_t mask_binary_po
+                = get_dims_mask(dst_dims, src1_md.dims(), dst_ndims);
+        return get_po_tensor_off(
+                src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po);
+    }
+
+    inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1,
+            const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims,
+            bool skip_dim_of_one = false) const {
+        dim_t mask = 0;
+        for (dim_t d = 0; d < ndims; ++d) {
+            // Disable mask_bit for dimensions of `1` by request.
+            dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d);
+            mask += dims1[d] == dims2[d] ? mask_bit : 0;
+        }
+        return mask;
+    }
+
+    inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md,
+            dim_t l_offset, const xpu::sycl::md_t::dims32_t &dst_dims,
+            const dim_t &dst_ndims, const dim_t &mask) const {
+        dims_t l_dims_po {};
+        get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask);
+
+        return tensor_md.off_v(l_dims_po);
+    }
+
+    inline void get_l_dims_po(dims_t l_dims_po, dim_t l_offset,
+            const xpu::sycl::md_t::dims32_t &dst_dims, const dim_t &dst_ndims,
+            const dim_t &mask) const {
+
+        l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims);
+        utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask);
+    }
+
+    inline void l_dims_by_l_offset(dims_t dims_pos, dim_t l_offset,
+            const xpu::sycl::md_t::dims32_t &dims, const dim_t &ndims) const {
+        for (dim_t rd = 0; rd < ndims; ++rd) {
+            const dim_t d = ndims - 1 - rd;
+            /* switch to faster 32-bit division when possible. */
+            if (l_offset <= INT32_MAX && dims[d] <= INT32_MAX) {
+                dims_pos[d] = (int32_t)l_offset % (int32_t)dims[d];
+                l_offset = (int32_t)l_offset / (int32_t)dims[d];
+            } else {
+                dims_pos[d] = l_offset % dims[d];
+                l_offset /= dims[d];
+            }
+        }
+    }
+
     template <int width>
     ::sycl::vec<float, width> compute_alg(::sycl::vec<float, width> src0,
             ::sycl::vec<float, width> src1, alg_kind_t alg) const {
@@ -199,6 +297,11 @@ struct binary_kernel_vec_t {
     xpu::sycl::in_memory_arg_t src0_scale_;
     xpu::sycl::in_memory_arg_t src1_scale_;
     data_type_t scales_dt_;
+    xpu::sycl::in_memory_arg_t po1_src_;
+    xpu::sycl::in_memory_arg_t po2_src_;
+    xpu::sycl::in_memory_arg_t po3_src_;
+    xpu::sycl::in_memory_arg_t po4_src_;
+    xpu::sycl::in_memory_arg_t po5_src_;
 };
 
 } // namespace sycl
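
The helpers added above map a linear offset in the destination tensor to an offset in a (possibly broadcast) binary post-op source: the linear offset is decomposed into per-dimension coordinates, the coordinates of dimensions the source broadcasts over are zeroed out via a bit mask built by get_dims_mask(), and the masked coordinates are folded back into an offset for that source. The standalone sketch below illustrates the same idea outside the kernel; the dense row-major layout and all names (kMaxDims, coords_from_offset, apply_mask) are illustrative assumptions, not part of this patch.

// A minimal, self-contained sketch of the broadcast-offset technique used by
// get_binary_src1_off(); names and the dense row-major layout are assumptions.
#include <cstdint>
#include <iostream>

constexpr int kMaxDims = 6;

// Decompose a linear (row-major) offset into per-dimension coordinates.
void coords_from_offset(int64_t pos[kMaxDims], int64_t offset,
        const int64_t dims[kMaxDims], int ndims) {
    for (int rd = 0; rd < ndims; ++rd) {
        const int d = ndims - 1 - rd;
        pos[d] = offset % dims[d];
        offset /= dims[d];
    }
}

// Zero the coordinate of every dimension whose mask bit is not set,
// i.e. dimensions the post-op source broadcasts over.
void apply_mask(int64_t pos[kMaxDims], int ndims, int64_t mask) {
    for (int d = 0; d < ndims; ++d)
        if (!((mask >> d) & 1)) pos[d] = 0;
}

int main() {
    const int ndims = 2;
    const int64_t dst_dims[kMaxDims] = {4, 3}; // dst is 4x3
    const int64_t src_dims[kMaxDims] = {1, 3}; // post-op source broadcasts over dim 0

    // Mask bit d is set when the post-op source matches dst in dimension d.
    int64_t mask = 0;
    for (int d = 0; d < ndims; ++d)
        if (dst_dims[d] == src_dims[d]) mask |= int64_t(1) << d;

    int64_t pos[kMaxDims] = {};
    coords_from_offset(pos, /*linear dst offset=*/7, dst_dims, ndims); // -> (2, 1)
    apply_mask(pos, ndims, mask);                                      // -> (0, 1)

    // Dense row-major offset into the broadcast source.
    const int64_t src_off = pos[0] * src_dims[1] + pos[1];
    std::cout << "post-op source offset: " << src_off << "\n"; // prints 1
    return 0;
}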