1
1
/* ******************************************************************************
2
- * Copyright 2021-2022 Intel Corporation
3
- * Copyright 2022 FUJITSU LIMITED
2
+ * Copyright 2021-2024 Intel Corporation
3
+ * Copyright 2022-2024 FUJITSU LIMITED
4
4
*
5
5
* Licensed under the Apache License, Version 2.0 (the "License");
6
6
* you may not use this file except in compliance with the License.
@@ -47,9 +47,12 @@ jit_uni_shuffle_kernel_t<isa>::jit_uni_shuffle_kernel_t(
47
47
template <cpu_isa_t isa>
48
48
void jit_uni_shuffle_kernel_t <isa>::prepare_mask() {
49
49
using namespace data_type ;
50
+ using namespace types ;
50
51
if (conf_.simd_tail > 0 ) {
51
- assert (utils::one_of (conf_.data_type , f32, s32));
52
- assert (conf_.simd_tail < isa_sveLen / sizeof (float ));
52
+ /* Because "LD1H { <Zt>.S }, <Pg>/Z, [<Xn|SP>, <Zm>.S, UXTW #1]" is used
53
+ to gather data for bf16 (one element per 32-bit lane), simd_tail must
54
+ be evaluated with sizeof(uint32_t). */
55
+ assert (conf_.simd_tail < isa_sveLen / sizeof (uint32_t ));
53
56
index (vmm_tmp_.s , 0 , 1 );
54
57
cmplt (k_tail_mask_.s , P_ALL_ONE / T_z, vmm_tmp_.s , conf_.simd_tail );
55
58
}
@@ -68,13 +71,17 @@ void jit_uni_shuffle_kernel_t<asimd>::prepare_mask() {}
68
71
template <cpu_isa_t isa>
69
72
void jit_uni_shuffle_kernel_t <isa>::gather_data(const XReg ®_src_addr,
70
73
const int indices_idx, const int data_idx, const bool is_tail) {
71
- if (conf_.dt_size == sizeof (float )) {
72
- const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;
74
+ using namespace data_type ;
75
+ const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;
76
+
77
+ if (utils::one_of (conf_.data_type , f32, s32)) {
73
78
lsr (TRegS (indices_idx), TRegS (indices_idx), 2 );
74
79
ld1w (TRegS (data_idx), mask / T_z,
75
80
ptr (reg_src_addr, TRegS (indices_idx), UXTW, 2 ));
76
- } else {
77
- assert (!" unsupported emu_gather_data" );
81
+ } else if (conf_.data_type == bf16) {
82
+ lsr (TRegS (indices_idx), TRegS (indices_idx), 1 );
83
+ ld1h (TRegS (data_idx), mask / T_z,
84
+ ptr (reg_src_addr, TRegS (indices_idx), UXTW, 1 ));
78
85
}
79
86
}
80
87
@@ -97,21 +104,26 @@ void jit_uni_shuffle_kernel_t<asimd>::gather_data(const XReg &addr,
97
104
template <cpu_isa_t isa>
98
105
void jit_uni_shuffle_kernel_t <isa>::store_data(const int data_idx,
99
106
const XReg ®_dst_addr, const int offset, const bool is_tail) {
107
+ using namespace data_type ;
100
108
const auto extend_for_padding
101
109
= is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w ;
110
+ const PReg &mask = is_tail ? k_tail_mask_ : P_ALL_ONE;
111
+
112
+ add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
113
+
102
114
if (extend_for_padding) {
103
115
sel (vmm_tmp_.s , k_tail_mask_, TRegS (data_idx), vmm_zero_.s );
104
- add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
105
- st1w (vmm_tmp_.s , P_ALL_ONE, ptr (X_DEFAULT_ADDR));
116
+ if (utils::one_of (conf_.data_type , f32, s32))
117
+ st1w (vmm_tmp_.s , P_ALL_ONE, ptr (X_DEFAULT_ADDR));
118
+ else // bf16
119
+ st1h (vmm_tmp_.s , P_ALL_ONE, ptr (X_DEFAULT_ADDR));
106
120
} else {
107
- if (is_tail) {
108
- add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
109
- st1w (TRegS (data_idx), k_tail_mask_, ptr (X_DEFAULT_ADDR));
110
- } else {
111
- add_imm (X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
112
- st1w (TRegS (data_idx), P_ALL_ONE, ptr (X_DEFAULT_ADDR));
113
- }
121
+ if (utils::one_of (conf_.data_type , f32, s32))
122
+ st1w (TRegS (data_idx), mask, ptr (X_DEFAULT_ADDR));
123
+ else // bf16
124
+ st1h (TRegS (data_idx), mask, ptr (X_DEFAULT_ADDR));
114
125
}
126
+
115
127
append_zero_padding (
116
128
reg_dst_, isa_sveLen > 128 ? extend_for_padding : false );
117
129
}
0 commit comments