@@ -161,14 +161,20 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
161
161
static bool applicable (const prb_t &p) {
162
162
using namespace data_type ;
163
163
164
+ bool bf16_ok
165
+ = (mayiuse_bf16 () && (p.itype == bf16) && (p.otype == bf16)
166
+ && !interim_f32_needed (p, false ) && p.beta == 0 .f )
167
+ || (p.itype != bf16 && p.otype != bf16)
168
+ || (p.itype == f32 && p.otype == bf16 && mayiuse_bf16 ()
169
+ && p.beta == 0 .f );
170
+
164
171
bool ok = true && p.ndims > 0
165
- && utils::one_of (p.itype , f32, s32, data_type::s8, u8)
172
+ && utils::one_of (p.itype , f32, bf16, s32, data_type::s8, u8)
166
173
&& utils::one_of (p.otype , f32, bf16, s32, data_type::s8, u8)
167
174
&& utils::everyone_is (0 , p.ioff , p.ooff ) /* do we need this? */
168
175
&& utils::one_of (p.beta , 0 .f , 1 .f ) /* anything else? */
169
176
&& simple_impl_desc_init (p, nullptr ) && prb_has_small_strides (p)
170
- && IMPLICATION (
171
- p.otype == bf16, p.itype == f32 && mayiuse_bf16 ());
177
+ && bf16_ok;
172
178
173
179
return ok;
174
180
}
@@ -702,7 +708,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
702
708
const int load_tail_step
703
709
= !can_load_xmm && can_store_xmm ? ur_step : load_step;
704
710
705
- const bool interim_f32 = interim_f32_needed ();
711
+ const bool interim_f32 = interim_f32_needed (prb_, compensation_needed_ );
706
712
707
713
const bool need_saturation
708
714
= (utils::one_of (prb_.otype , u8, data_type::s8, s32)
@@ -1285,17 +1291,17 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
1285
1291
}
1286
1292
}
1287
1293
1288
- bool interim_f32_needed () {
1294
+ static bool interim_f32_needed (const prb_t &prb, bool compensation_needed ) {
1289
1295
using namespace data_type ;
1290
-
1291
- return utils::one_of (f32, prb_.itype , prb_.otype )
1292
- || prb_.src_scale_type != scale_type_t ::NONE
1293
- || prb_.dst_scale_type != scale_type_t ::NONE || prb_.beta != 0 .f
1294
- || ((prb_.req_src_zp || prb_.req_dst_zp )
1295
- ? !(prb_.itype == s32 && prb_.otype == s32)
1296
+ bool ret = utils::one_of (f32, prb.itype , prb.otype )
1297
+ || prb.src_scale_type != scale_type_t ::NONE
1298
+ || prb.dst_scale_type != scale_type_t ::NONE || prb.beta != 0 .f
1299
+ || ((prb.req_src_zp || prb.req_dst_zp )
1300
+ ? !(prb.itype == s32 && prb.otype == s32)
1296
1301
: false )
1297
- || (prb_.itype != f32 && compensation_needed_)
1298
- || prb_.scale_adjust != 1 .f ;
1302
+ || (prb.itype != f32 && compensation_needed)
1303
+ || prb.scale_adjust != 1 .f ;
1304
+ return ret;
1299
1305
}
1300
1306
1301
1307
void process_unroll_generic (
@@ -1313,7 +1319,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
1313
1319
1314
1320
int curr = 0 ; // will switch between 0 and 1
1315
1321
1316
- const bool interim_f32 = interim_f32_needed ();
1322
+ const bool interim_f32 = interim_f32_needed (prb_, compensation_needed_ );
1317
1323
1318
1324
if (prb_.req_src_zp ) {
1319
1325
add_imm (X_DEFAULT_ADDR, PARAM (src_zp), X_TMP_0);
0 commit comments