@@ -323,6 +323,36 @@ struct matmul_kernel_fwd_t {
323
323
}
324
324
}
325
325
}
326
+
327
+ void apply_post_ops_edge (sycl_post_ops_t post_ops,
328
+ register_block<Rows, Cols> prev_dst, dims_t off_po, int dim1,
329
+ const matmul_kernel_fwd_t *kernel, int rows, int cols) {
330
+ for (int row = 0 ; row < rows; row++) {
331
+ int col;
332
+ for (col = 0 ; col < cols / vec_len; col++) {
333
+ for (int v_el = 0 ; v_el < vec_len; v_el++) {
334
+ off_po[dim1] += row;
335
+ off_po[dim1 + 1 ] += col * vec_len + v_el;
336
+ data[row][col][v_el]
337
+ = post_ops.apply (data[row][col][v_el],
338
+ prev_dst.data [row][col][v_el],
339
+ kernel->po_args_ , off_po);
340
+ off_po[dim1] -= row;
341
+ off_po[dim1 + 1 ] -= col * vec_len + v_el;
342
+ }
343
+ }
344
+ int n_remaining = cols - col * vec_len;
345
+ for (int v_el = 0 ; v_el < n_remaining; v_el++) {
346
+ off_po[dim1] += row;
347
+ off_po[dim1 + 1 ] += col * vec_len + v_el;
348
+ data[row][col][v_el] = post_ops.apply (data[row][col][v_el],
349
+ prev_dst.data [row][col][v_el], kernel->po_args_ ,
350
+ off_po);
351
+ off_po[dim1] -= row;
352
+ off_po[dim1 + 1 ] -= col * vec_len + v_el;
353
+ }
354
+ }
355
+ }
326
356
};
327
357
328
358
matmul_kernel_fwd_t (const sycl_matmul_conf_t &conf, ::sycl::handler &cgh,
@@ -377,7 +407,7 @@ struct matmul_kernel_fwd_t {
377
407
, dropout_seed_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_ATTR_DROPOUT_SEED))
378
408
, dropout_probability_(
379
409
CTX_IN_SYCL_KERNEL_MEMORY (DNNL_ARG_ATTR_DROPOUT_PROBABILITY))
380
- , po_args_(cgh, ctx) {}
410
+ , po_args_(cgh, ctx, conf_.post_ops ) {}
381
411
382
412
void operator ()(::sycl::nd_item<1 > item) const {
383
413
using data_block_t = register_block<register_block_M, register_block_K>;
@@ -597,8 +627,13 @@ struct matmul_kernel_fwd_t {
597
627
if (conf_.transpose_dst ) {
598
628
std::swap (off_po[matmul_dim_1], off_po[matmul_dim_2]);
599
629
}
600
- dst_block.apply_post_ops (
601
- conf_.post_ops , prev_dst, off_po, matmul_dim_1, this );
630
+ if (is_dst_edge_block) {
631
+ dst_block.apply_post_ops_edge (conf_.post_ops , prev_dst, off_po,
632
+ matmul_dim_1, this , remaining_m, remaining_n);
633
+ } else {
634
+ dst_block.apply_post_ops (
635
+ conf_.post_ops , prev_dst, off_po, matmul_dim_1, this );
636
+ }
602
637
603
638
if (conf_.do_scale_dst ) {
604
639
dst_block.eltwise ([=](float &el) { el /= dst_scale; });
0 commit comments