1
1
/* ******************************************************************************
2
2
* Copyright 2021-2023 Intel Corporation
3
- * Copyright 2024 FUJITSU LIMITED
3
+ * Copyright 2024-2025 FUJITSU LIMITED
4
4
*
5
5
* Licensed under the Apache License, Version 2.0 (the "License");
6
6
* you may not use this file except in compliance with the License.
@@ -43,8 +43,8 @@ using namespace jit_uni_brgemm_conv_comp_pad_kernel;
43
43
#define ndims_pick (v5, v4, v3 ) \
44
44
((ndims == 5 ) ? (v5) : (ndims == 4 ) ? (v4) : (ndims == 3 ) ? (v3) : 0 )
45
45
46
- template <cpu_isa_t isa, bool use_inversion >
47
- void brgemm_convolution_fwd_t <isa, use_inversion >::pd_t ::init_batch(int icc,
46
+ template <cpu_isa_t isa>
47
+ void brgemm_convolution_fwd_t <isa>::pd_t ::init_batch(int icc,
48
48
const char *src_base, const char *wei_base, int n_ic_blocks,
49
49
int ic_block_s, int iid_b, int iih_b, int iiw_b,
50
50
const dim_t *const __restrict kw_top_vpads,
@@ -117,8 +117,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init_batch(int icc,
117
117
}
118
118
}
119
119
120
- template <cpu_isa_t isa, bool use_inversion >
121
- inline void brgemm_convolution_fwd_t <isa, use_inversion >::pd_t ::get_A_B(int icc,
120
+ template <cpu_isa_t isa>
121
+ inline void brgemm_convolution_fwd_t <isa>::pd_t ::get_A_B(int icc,
122
122
const char *src_base, const char *wei_base, int ic_block_s, int iid_b,
123
123
int iih_b, int iiw_b, int kd_b, int kh_b, const void *&ptrA,
124
124
const void *&ptrB) const {
@@ -147,10 +147,9 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::get_A_B(int icc,
147
147
ptrB = wei_base_kh + wei_kw * wei_kw_offset;
148
148
}
149
149
150
- template <cpu_isa_t isa, bool use_inversion>
151
- status_t brgemm_convolution_fwd_t <isa, use_inversion>::pd_t ::add_brg_descriptor(
152
- int vM, int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b,
153
- int kh_e) {
150
+ template <cpu_isa_t isa>
151
+ status_t brgemm_convolution_fwd_t <isa>::pd_t ::add_brg_descriptor(int vM,
152
+ int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
154
153
155
154
const auto src_type = src_md (0 )->data_type ;
156
155
const auto wei_type = weights_md (0 )->data_type ;
@@ -287,9 +286,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::add_brg_descriptor(
287
286
return status::success;
288
287
}
289
288
290
- template <cpu_isa_t isa, bool use_inversion>
291
- status_t brgemm_convolution_fwd_t <isa, use_inversion>::pd_t ::init(
292
- engine_t *engine) {
289
+ template <cpu_isa_t isa>
290
+ status_t brgemm_convolution_fwd_t <isa>::pd_t ::init(engine_t *engine) {
293
291
using namespace data_type ;
294
292
using namespace utils ;
295
293
brgemm_descriptors_
@@ -306,7 +304,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
306
304
// executing 'use_inversion == true' as FWD. This can only work if the
307
305
// diff_src_desc and diff_dst_desc are defined in the aforementioned.
308
306
const convolution_desc_t &cd = *desc ();
309
- if (use_inversion
307
+ if (cd. use_inversion
310
308
&& one_of (true , types::is_zero_md (&cd.diff_src_desc ),
311
309
types::is_zero_md (&cd.diff_dst_desc )))
312
310
return status::unimplemented;
@@ -336,6 +334,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
336
334
// For exec_base it makes sense to use unrolled kernel only if
337
335
// there is no padding by width.
338
336
// 2. For exec_trans block by kw is always KW
337
+ // 3. 'false' is used intentionally to disable the condition, ensuring that
338
+ // the assert fails only when jcp_.use_uker is true, regardless of exec_type.
339
339
assert (IMPLICATION (jcp_.use_uker ,
340
340
false && one_of (jcp_.exec_type , exec_base, exec_trans)));
341
341
assert (IMPLICATION (jcp_.use_interleave_stores , jcp_.use_uker ));
@@ -535,13 +535,12 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
535
535
return status::success;
536
536
}
537
537
538
- template <cpu_isa_t isa, bool use_inversion>
539
- brgemm_convolution_fwd_t <isa, use_inversion>::brgemm_convolution_fwd_t (
540
- const pd_t *apd)
538
+ template <cpu_isa_t isa>
539
+ brgemm_convolution_fwd_t <isa>::brgemm_convolution_fwd_t (const pd_t *apd)
541
540
: primitive_t (apd), bias_d(pd()->weights_md (1 )) {}
542
541
543
- template <cpu_isa_t isa, bool use_inversion >
544
- void brgemm_convolution_fwd_t <isa, use_inversion >::get_kw_range(
542
+ template <cpu_isa_t isa>
543
+ void brgemm_convolution_fwd_t <isa>::get_kw_range(
545
544
int ow, int &kw_s, int &kw_full_s, int &kw_full_f, int &kw_f) const {
546
545
// This function needed for exec_base only
547
546
const auto _pd = pd ();
@@ -570,8 +569,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::get_kw_range(
570
569
if (kw_full_f == -1 ) kw_full_s = kw_full_f = kw_f;
571
570
}
572
571
573
- template <cpu_isa_t isa, bool use_inversion >
574
- inline void brgemm_convolution_fwd_t <isa, use_inversion >::get_ow_range(
572
+ template <cpu_isa_t isa>
573
+ inline void brgemm_convolution_fwd_t <isa>::get_ow_range(
575
574
int ow, int kw, int &ow_s, int &ow_f) const {
576
575
// This function needed for exec_base only
577
576
const auto _pd = pd ();
@@ -602,9 +601,9 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::get_ow_range(
602
601
ow_f = nstl::min (nstl::max (ow_f, ow_s), ow + M);
603
602
}
604
603
605
- template <cpu_isa_t isa, bool use_inversion >
606
- status_t brgemm_convolution_fwd_t <isa, use_inversion >::add_brg_kernel(int M,
607
- int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
604
+ template <cpu_isa_t isa>
605
+ status_t brgemm_convolution_fwd_t <isa>::add_brg_kernel(int M, int i_N, int i_K ,
606
+ int i_init, int kd_b, int kd_e, int kh_b, int kh_e) {
608
607
if (M <= 0 ) return status::success;
609
608
const auto _pd = pd ();
610
609
const auto &jcp = _pd->jcp_ ;
@@ -623,8 +622,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_brg_kernel(int M,
623
622
return status::success;
624
623
}
625
624
626
- template <cpu_isa_t isa, bool use_inversion >
627
- status_t brgemm_convolution_fwd_t <isa, use_inversion >::add_po_kernel(
625
+ template <cpu_isa_t isa>
626
+ status_t brgemm_convolution_fwd_t <isa>::add_po_kernel(
628
627
brgemm_t *bcfg, int ker_idx, bool is_init) {
629
628
if (!bcfg) return status::success;
630
629
const auto _pd = pd ();
@@ -641,8 +640,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
641
640
return status::success;
642
641
}
643
642
644
- template <cpu_isa_t isa, bool use_inversion >
645
- void brgemm_convolution_fwd_t <isa, use_inversion >::add_po_kernels(
643
+ template <cpu_isa_t isa>
644
+ void brgemm_convolution_fwd_t <isa>::add_po_kernels(
646
645
int i_N, int init_bcast_dim, int po_bcast_dim) {
647
646
const auto _pd = pd ();
648
647
const auto &jcp = _pd->jcp_ ;
@@ -676,10 +675,10 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernels(
676
675
}
677
676
}
678
677
}
679
- template <cpu_isa_t isa, bool use_inversion >
680
- int brgemm_convolution_fwd_t <isa, use_inversion >::get_comp_ker_idx(
681
- const int kd_b , const int kd_e , const int kh_b , const int kh_e ,
682
- const int kw_b, const int kw_e) const {
678
+ template <cpu_isa_t isa>
679
+ int brgemm_convolution_fwd_t <isa>::get_comp_ker_idx(const int kd_b,
680
+ const int kd_e , const int kh_b , const int kh_e , const int kw_b ,
681
+ const int kw_e) const {
683
682
const auto _pd = pd ();
684
683
const auto &jcp = _pd->jcp_ ;
685
684
@@ -696,11 +695,10 @@ int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_ker_idx(
696
695
return -1 ;
697
696
}
698
697
699
- template <cpu_isa_t isa, bool use_inversion>
700
- inline int brgemm_convolution_fwd_t <isa, use_inversion>::get_comp_offset(
701
- const int g, const int ocb, const int ow, const int kd_b,
702
- const int kd_e, const int kh_b, const int kh_e, const int kw_b,
703
- const int kw_e) const {
698
+ template <cpu_isa_t isa>
699
+ inline int brgemm_convolution_fwd_t <isa>::get_comp_offset(const int g,
700
+ const int ocb, const int ow, const int kd_b, const int kd_e,
701
+ const int kh_b, const int kh_e, const int kw_b, const int kw_e) const {
704
702
const auto _pd = pd ();
705
703
const auto &jcp = _pd->jcp_ ;
706
704
@@ -714,8 +712,8 @@ inline int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_offset(
714
712
: (g * jcp.nb_oc + ocb) * jcp.oc_block ;
715
713
}
716
714
717
- template <cpu_isa_t isa, bool use_inversion >
718
- status_t brgemm_convolution_fwd_t <isa, use_inversion >::init(engine_t *engine) {
715
+ template <cpu_isa_t isa>
716
+ status_t brgemm_convolution_fwd_t <isa>::init(engine_t *engine) {
719
717
720
718
const auto _pd = pd ();
721
719
const auto &jcp = _pd->jcp_ ;
@@ -1054,8 +1052,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {
1054
1052
1055
1053
return status::success;
1056
1054
}
1057
- template <cpu_isa_t isa, bool use_inversion >
1058
- struct brgemm_convolution_fwd_t <isa, use_inversion >::brgemm_thread_ctx_t {
1055
+ template <cpu_isa_t isa>
1056
+ struct brgemm_convolution_fwd_t <isa>::brgemm_thread_ctx_t {
1059
1057
brgemm_thread_ctx_t (brgemm_exec_ctx_t &brgemm_ctx_, int ithr_,
1060
1058
brgemm_batch_element_t *__restrict brg_batch_, char *c_buffer_,
1061
1059
char *wsp_tile_)
@@ -1082,9 +1080,8 @@ struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
1082
1080
const float *dst_scales {nullptr };
1083
1081
};
1084
1082
1085
- template <cpu_isa_t isa, bool use_inversion>
1086
- status_t brgemm_convolution_fwd_t <isa, use_inversion>::execute(
1087
- const exec_ctx_t &ctx) const {
1083
+ template <cpu_isa_t isa>
1084
+ status_t brgemm_convolution_fwd_t <isa>::execute(const exec_ctx_t &ctx) const {
1088
1085
const auto _pd = pd ();
1089
1086
const auto &jcp = _pd->jcp_ ;
1090
1087
@@ -1266,8 +1263,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
1266
1263
return status::success;
1267
1264
}
1268
1265
1269
- template <cpu_isa_t isa, bool use_inversion >
1270
- status_t brgemm_convolution_fwd_t <isa, use_inversion >::cal_compensation(
1266
+ template <cpu_isa_t isa>
1267
+ status_t brgemm_convolution_fwd_t <isa>::cal_compensation(
1271
1268
const char *__restrict weights, int32_t *src_zp_buffer,
1272
1269
int32_t *s8s8_comp_buffer) const {
1273
1270
const auto _pd = pd ();
@@ -1332,8 +1329,8 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::cal_compensation(
1332
1329
return status::success;
1333
1330
}
1334
1331
1335
- template <cpu_isa_t isa, bool use_inversion >
1336
- void brgemm_convolution_fwd_t <isa, use_inversion >::perform_outwork(
1332
+ template <cpu_isa_t isa>
1333
+ void brgemm_convolution_fwd_t <isa>::perform_outwork(
1337
1334
const brgemm_thread_ctx_t &btc, char *dst_base, const char *bias_w,
1338
1335
int ow, int g_oc, bool is_oc_tail, int ker_ow_s, int ker_ow_f, int kd_l,
1339
1336
int kh_l, bool maybe_do_init, bool do_postwork,
@@ -1417,8 +1414,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
1417
1414
}
1418
1415
}
1419
1416
1420
- template <cpu_isa_t isa, bool use_inversion >
1421
- inline void brgemm_convolution_fwd_t <isa, use_inversion >::call_brgemm_kernel(
1417
+ template <cpu_isa_t isa>
1418
+ inline void brgemm_convolution_fwd_t <isa>::call_brgemm_kernel(
1422
1419
const brgemm_thread_ctx_t &btc, const brgemm_kernel_t *brg_ker,
1423
1420
int batch_size, char *ptr_C, char *ptr_D, const char *bias_w, int g_oc,
1424
1421
bool do_postops, int comp_ker_offs, bool do_only_comp) const {
@@ -1467,8 +1464,8 @@ inline void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
1467
1464
ptr_C, static_cast <void *>(btc.wsp_tile ));
1468
1465
}
1469
1466
1470
- template <cpu_isa_t isa, bool use_inversion >
1471
- void brgemm_convolution_fwd_t <isa, use_inversion >::maybe_conv_inp(int ithr,
1467
+ template <cpu_isa_t isa>
1468
+ void brgemm_convolution_fwd_t <isa>::maybe_conv_inp(int ithr,
1472
1469
const char *__restrict src, char *__restrict inp_buffer,
1473
1470
uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int odb,
1474
1471
int ohb, int owb, int last_g, int last_n, int last_icc, int last_odb,
@@ -1648,9 +1645,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::maybe_conv_inp(int ithr,
1648
1645
char *ptr_D; \
1649
1646
int kd_b (0 ), kd_e(0 ), kh_b(0 ), kh_e(0 ), k_l(0 ), iiw_b(0 );
1650
1647
1651
- template <cpu_isa_t isa, bool use_inversion>
1652
- void brgemm_convolution_fwd_t <isa, use_inversion>::ker_base(
1653
- brgemm_thread_ctx_t &btc) const {
1648
+ template <cpu_isa_t isa>
1649
+ void brgemm_convolution_fwd_t <isa>::ker_base(brgemm_thread_ctx_t &btc) const {
1654
1650
1655
1651
const auto _pd = pd ();
1656
1652
const auto &jcp = _pd->jcp_ ;
@@ -1799,8 +1795,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
1799
1795
}
1800
1796
}
1801
1797
1802
- template <cpu_isa_t isa, bool use_inversion >
1803
- void brgemm_convolution_fwd_t <isa, use_inversion >::ker_trans(
1798
+ template <cpu_isa_t isa>
1799
+ void brgemm_convolution_fwd_t <isa>::ker_trans(
1804
1800
brgemm_thread_ctx_t &btc, char *inp_buffer) const {
1805
1801
1806
1802
const auto _pd = pd ();
@@ -1924,9 +1920,8 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
1924
1920
}
1925
1921
}
1926
1922
1927
- template <cpu_isa_t isa, bool use_inversion>
1928
- void brgemm_convolution_fwd_t <isa, use_inversion>::ker_vpad(
1929
- brgemm_thread_ctx_t &btc) const {
1923
+ template <cpu_isa_t isa>
1924
+ void brgemm_convolution_fwd_t <isa>::ker_vpad(brgemm_thread_ctx_t &btc) const {
1930
1925
1931
1926
const auto _pd = pd ();
1932
1927
const auto &jcp = _pd->jcp_ ;
0 commit comments