@@ -43,10 +43,10 @@ static bcast_set_t get_supported_bcast_strategies() {
43
43
}
44
44
45
45
template <cpu_isa_t isa>
46
- jit_uni_pool_kernel <isa>::~jit_uni_pool_kernel () = default ;
46
+ jit_uni_pool_kernel_t <isa>::~jit_uni_pool_kernel_t () = default ;
47
47
48
48
template <cpu_isa_t isa>
49
- jit_uni_pool_kernel <isa>::jit_uni_pool_kernel (
49
+ jit_uni_pool_kernel_t <isa>::jit_uni_pool_kernel_t (
50
50
const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md)
51
51
: jit_generator_t (jit_name(), isa), jpp(ajpp) {
52
52
@@ -161,7 +161,7 @@ static status_t set_binary_postops_formats(
161
161
}
162
162
163
163
template <cpu_isa_t isa>
164
- bool jit_uni_pool_kernel <isa>::has_large_buffers(const pooling_pd_t *ppd) {
164
+ bool jit_uni_pool_kernel_t <isa>::has_large_buffers(const pooling_pd_t *ppd) {
165
165
auto is_large = [](const memory_desc_t &md) {
166
166
memory_desc_wrapper mdw (md);
167
167
return mdw.size ()
@@ -180,7 +180,7 @@ bool jit_uni_pool_kernel<isa>::has_large_buffers(const pooling_pd_t *ppd) {
180
180
}
181
181
182
182
template <cpu_isa_t isa>
183
- status_t jit_uni_pool_kernel <isa>::init_conf(
183
+ status_t jit_uni_pool_kernel_t <isa>::init_conf(
184
184
jit_pool_conf_t &jpp, primitive_attr_t &attr, const pooling_pd_t *ppd) {
185
185
186
186
const auto &pd = *ppd->desc ();
@@ -479,7 +479,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(
479
479
}
480
480
481
481
template <cpu_isa_t isa>
482
- void jit_uni_pool_kernel <isa>::init_scratchpad(
482
+ void jit_uni_pool_kernel_t <isa>::init_scratchpad(
483
483
const jit_pool_conf_t &jpp, memory_tracking::registrar_t &scratchpad) {
484
484
485
485
// scratchpad for c_block slice of input and/or output
@@ -510,43 +510,43 @@ static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
510
510
};
511
511
512
512
template <cpu_isa_t isa>
513
- inline void jit_uni_pool_kernel <isa>::put_one_in_vmm() {
513
+ inline void jit_uni_pool_kernel_t <isa>::put_one_in_vmm() {
514
514
mov (tmp_gpr, 1 );
515
515
uni_broadcast_reg_val (tmp_gpr.getIdx (), vmm_one.getIdx ());
516
516
}
517
517
518
518
template <cpu_isa_t isa>
519
- inline void jit_uni_pool_kernel <isa>::uni_broadcast_reg_val(
519
+ inline void jit_uni_pool_kernel_t <isa>::uni_broadcast_reg_val(
520
520
const int reg_idx, const int vmm_idx) {
521
521
uni_vmovq (Xmm (vmm_idx), reg64_t (reg_idx));
522
522
uni_vpbroadcastd (Vmm (vmm_idx), Xmm (vmm_idx));
523
523
}
524
524
525
525
template <cpu_isa_t isa>
526
- inline void jit_uni_pool_kernel <isa>::push_vmm_val(const int idx) {
526
+ inline void jit_uni_pool_kernel_t <isa>::push_vmm_val(const int idx) {
527
527
Vmm val_to_store (idx);
528
528
sub (rsp, val_to_store.getBit ());
529
529
uni_vmovups (ptr[rsp], val_to_store);
530
530
}
531
531
532
532
template <cpu_isa_t isa>
533
- inline void jit_uni_pool_kernel <isa>::pop_vmm_val(const int idx) {
533
+ inline void jit_uni_pool_kernel_t <isa>::pop_vmm_val(const int idx) {
534
534
Vmm val_to_load (idx);
535
535
uni_vmovups (val_to_load, ptr[rsp]);
536
536
add (rsp, val_to_load.getBit ());
537
537
}
538
538
539
539
template <cpu_isa_t isa>
540
- inline void jit_uni_pool_kernel <isa>::load(const data_type_t dt, const int idx ,
541
- const reg64_t &reg_ptr, const int offset,
540
+ inline void jit_uni_pool_kernel_t <isa>::load(const data_type_t dt,
541
+ const int idx, const reg64_t &reg_ptr, const int offset,
542
542
const bool is_c_tail_proccessing) {
543
543
io_[dt]->load (vmmword[reg_ptr + offset], Vmm (idx),
544
544
is_c_tail_proccessing && !jpp.is_c_padded );
545
545
}
546
546
547
547
template <cpu_isa_t isa>
548
- inline void jit_uni_pool_kernel <isa>::store(const data_type_t dt, const int idx ,
549
- const reg64_t &reg_ptr, const int offset,
548
+ inline void jit_uni_pool_kernel_t <isa>::store(const data_type_t dt,
549
+ const int idx, const reg64_t &reg_ptr, const int offset,
550
550
const bool is_c_tail_proccessing) {
551
551
if (is_c_tail_proccessing && jpp.is_c_padded && jpp.with_postops )
552
552
pad_with_zeros (idx);
@@ -555,7 +555,7 @@ inline void jit_uni_pool_kernel<isa>::store(const data_type_t dt, const int idx,
555
555
}
556
556
557
557
template <cpu_isa_t isa>
558
- inline void jit_uni_pool_kernel <isa>::pad_with_zeros(const int idx) {
558
+ inline void jit_uni_pool_kernel_t <isa>::pad_with_zeros(const int idx) {
559
559
if (isa == sse41) {
560
560
uni_vxorps (xmm_tmp_1, xmm_tmp_1, xmm_tmp_1);
561
561
if (jpp.c_tail <= sse41_single_block_size && sse_high_half) {
@@ -575,7 +575,7 @@ inline void jit_uni_pool_kernel<isa>::pad_with_zeros(const int idx) {
575
575
}
576
576
577
577
template <cpu_isa_t isa>
578
- inline void jit_uni_pool_kernel <isa>::load_indices(
578
+ inline void jit_uni_pool_kernel_t <isa>::load_indices(
579
579
const int indr_i, const int step_index, bool is_c_tail_processing) {
580
580
if (jpp.ind_dt == data_type::u8) {
581
581
auto indvr = vreg (indr_i);
@@ -619,7 +619,7 @@ inline void jit_uni_pool_kernel<isa>::load_indices(
619
619
}
620
620
621
621
template <cpu_isa_t isa>
622
- inline void jit_uni_pool_kernel <isa>::store_indices(const int indr_i,
622
+ inline void jit_uni_pool_kernel_t <isa>::store_indices(const int indr_i,
623
623
const int step_index, const bool is_c_tail_processing,
624
624
const bool is_first_w_block) {
625
625
if (jpp.ind_dt == data_type::u8) {
@@ -717,7 +717,7 @@ inline void jit_uni_pool_kernel<isa>::store_indices(const int indr_i,
717
717
}
718
718
719
719
template <cpu_isa_t isa>
720
- bool jit_uni_pool_kernel <isa>::post_ops_ok(jit_pool_conf_t &jpp,
720
+ bool jit_uni_pool_kernel_t <isa>::post_ops_ok(jit_pool_conf_t &jpp,
721
721
const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
722
722
const auto &post_ops = attr.post_ops_ ;
723
723
const auto &entries = post_ops.entry_ ;
@@ -757,7 +757,7 @@ bool jit_uni_pool_kernel<isa>::post_ops_ok(jit_pool_conf_t &jpp,
757
757
}
758
758
759
759
template <cpu_isa_t isa>
760
- void jit_uni_pool_kernel <isa>::apply_postops(int ur_bc, int ur_w, int c_block,
760
+ void jit_uni_pool_kernel_t <isa>::apply_postops(int ur_bc, int ur_w, int c_block,
761
761
const std::function<bool (int , bool )> &is_tail_predicate) {
762
762
binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
763
763
const int end_idx = vmm_idx_upper_bound () + 1 ;
@@ -803,7 +803,7 @@ void jit_uni_pool_kernel<isa>::apply_postops(int ur_bc, int ur_w, int c_block,
803
803
}
804
804
805
805
template <cpu_isa_t isa>
806
- inline void jit_uni_pool_kernel <isa>::maybe_recalculate_divisor(
806
+ inline void jit_uni_pool_kernel_t <isa>::maybe_recalculate_divisor(
807
807
int jj, int ur_w, int pad_l, int pad_r, bool with_c_tail_proccessing) {
808
808
if (jpp.alg == pooling_avg_exclude_padding) {
809
809
int kw = jpp.kw ;
@@ -834,7 +834,7 @@ inline void jit_uni_pool_kernel<isa>::maybe_recalculate_divisor(
834
834
}
835
835
836
836
template <cpu_isa_t isa>
837
- inline void jit_uni_pool_kernel <isa>::avg_step(int ur_w, int ur_bc, int pad_l,
837
+ inline void jit_uni_pool_kernel_t <isa>::avg_step(int ur_w, int ur_bc, int pad_l,
838
838
int pad_r, bool with_c_tail_proccessing) {
839
839
840
840
auto iw = jpp.iw ;
@@ -959,7 +959,7 @@ inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
959
959
}
960
960
961
961
template <cpu_isa_t isa>
962
- inline void jit_uni_pool_kernel <isa>::max_step_fwd(int ur_w, int ur_bc,
962
+ inline void jit_uni_pool_kernel_t <isa>::max_step_fwd(int ur_w, int ur_bc,
963
963
int pad_l, int pad_r, bool with_c_tail_proccessing) {
964
964
int iw = jpp.iw ;
965
965
int kw = jpp.kw ;
@@ -1115,7 +1115,7 @@ inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int ur_bc,
1115
1115
}
1116
1116
1117
1117
template <cpu_isa_t isa>
1118
- inline void jit_uni_pool_kernel <isa>::max_step_bwd(int ur_w, int ur_bc,
1118
+ inline void jit_uni_pool_kernel_t <isa>::max_step_bwd(int ur_w, int ur_bc,
1119
1119
int pad_l, int pad_r, bool with_c_tail_proccessing) {
1120
1120
1121
1121
int iw = jpp.iw ;
@@ -1264,7 +1264,7 @@ inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
1264
1264
}
1265
1265
1266
1266
template <cpu_isa_t isa>
1267
- void jit_uni_pool_kernel <isa>::zero_diff_src(
1267
+ void jit_uni_pool_kernel_t <isa>::zero_diff_src(
1268
1268
int ur_bc, bool with_c_tail_proccessing) {
1269
1269
const int c_off = jpp.needs_f32_accum_for_bf16
1270
1270
? jpp.f32_accum_block_size
@@ -1345,7 +1345,7 @@ void jit_uni_pool_kernel<isa>::zero_diff_src(
1345
1345
}
1346
1346
1347
1347
template <cpu_isa_t isa>
1348
- void jit_uni_pool_kernel <isa>::generate() {
1348
+ void jit_uni_pool_kernel_t <isa>::generate() {
1349
1349
1350
1350
this ->preamble ();
1351
1351
@@ -1553,12 +1553,12 @@ void jit_uni_pool_kernel<isa>::generate() {
1553
1553
io_.prepare_table_fp8 ();
1554
1554
}
1555
1555
1556
- template struct jit_uni_pool_kernel <sse41>;
1557
- template struct jit_uni_pool_kernel <avx>;
1558
- template struct jit_uni_pool_kernel <avx2>;
1559
- template struct jit_uni_pool_kernel <avx2_vnni_2>;
1560
- template struct jit_uni_pool_kernel <avx512_core>;
1561
- template struct jit_uni_pool_kernel <avx512_core_fp16>;
1556
+ template struct jit_uni_pool_kernel_t <sse41>;
1557
+ template struct jit_uni_pool_kernel_t <avx>;
1558
+ template struct jit_uni_pool_kernel_t <avx2>;
1559
+ template struct jit_uni_pool_kernel_t <avx2_vnni_2>;
1560
+ template struct jit_uni_pool_kernel_t <avx512_core>;
1561
+ template struct jit_uni_pool_kernel_t <avx512_core_fp16>;
1562
1562
1563
1563
} // namespace x64
1564
1564
} // namespace cpu
0 commit comments