Skip to content

Commit 12bafbe

Browse files
Shreyas-fujmgouicem
authored and committed
cpu : aarch64 : reorder : reenabled bf16 jit uni reorders
1 parent e3782e8 commit 12bafbe

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

.gitignore

+1 -1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#===============================================================================
1717

18-
build
18+
build*
1919
external
2020
.vs
2121
.vscode

src/cpu/aarch64/jit_uni_reorder.cpp

+20 -14
Original file line number | Diff line number | Diff line change
@@ -161,14 +161,20 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
161161
static bool applicable(const prb_t &p) {
162162
using namespace data_type;
163163

164+
bool bf16_ok
165+
= (mayiuse_bf16() && (p.itype == bf16) && (p.otype == bf16)
166+
&& !interim_f32_needed(p, false) && p.beta == 0.f)
167+
|| (p.itype != bf16 && p.otype != bf16)
168+
|| (p.itype == f32 && p.otype == bf16 && mayiuse_bf16()
169+
&& p.beta == 0.f);
170+
164171
bool ok = true && p.ndims > 0
165-
&& utils::one_of(p.itype, f32, s32, data_type::s8, u8)
172+
&& utils::one_of(p.itype, f32, bf16, s32, data_type::s8, u8)
166173
&& utils::one_of(p.otype, f32, bf16, s32, data_type::s8, u8)
167174
&& utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
168175
&& utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
169176
&& simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p)
170-
&& IMPLICATION(
171-
p.otype == bf16, p.itype == f32 && mayiuse_bf16());
177+
&& bf16_ok;
172178

173179
return ok;
174180
}
@@ -702,7 +708,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
702708
const int load_tail_step
703709
= !can_load_xmm && can_store_xmm ? ur_step : load_step;
704710

705-
const bool interim_f32 = interim_f32_needed();
711+
const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_);
706712

707713
const bool need_saturation
708714
= (utils::one_of(prb_.otype, u8, data_type::s8, s32)
@@ -1285,17 +1291,17 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
12851291
}
12861292
}
12871293

1288-
bool interim_f32_needed() {
1294+
static bool interim_f32_needed(const prb_t &prb, bool compensation_needed) {
12891295
using namespace data_type;
1290-
1291-
return utils::one_of(f32, prb_.itype, prb_.otype)
1292-
|| prb_.src_scale_type != scale_type_t::NONE
1293-
|| prb_.dst_scale_type != scale_type_t::NONE || prb_.beta != 0.f
1294-
|| ((prb_.req_src_zp || prb_.req_dst_zp)
1295-
? !(prb_.itype == s32 && prb_.otype == s32)
1296+
bool ret = utils::one_of(f32, prb.itype, prb.otype)
1297+
|| prb.src_scale_type != scale_type_t::NONE
1298+
|| prb.dst_scale_type != scale_type_t::NONE || prb.beta != 0.f
1299+
|| ((prb.req_src_zp || prb.req_dst_zp)
1300+
? !(prb.itype == s32 && prb.otype == s32)
12961301
: false)
1297-
|| (prb_.itype != f32 && compensation_needed_)
1298-
|| prb_.scale_adjust != 1.f;
1302+
|| (prb.itype != f32 && compensation_needed)
1303+
|| prb.scale_adjust != 1.f;
1304+
return ret;
12991305
}
13001306

13011307
void process_unroll_generic(
@@ -1313,7 +1319,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
13131319

13141320
int curr = 0; // will switch between 0 and 1
13151321

1316-
const bool interim_f32 = interim_f32_needed();
1322+
const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_);
13171323

13181324
if (prb_.req_src_zp) {
13191325
add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0);

0 commit comments

Comments
 (0)