@@ -151,7 +151,7 @@ jit_brgemm_ip_conf_t::get_desired_weights_tag() const {
151
151
const bool is_xf16 = utils::one_of (jbgp.wei_dt , bf16, f16);
152
152
const bool is_not_vnni_tag = (jbgp.wei_dt == f32
153
153
|| (jbgp.wei_dt == f16 && jbgp.isa == avx512_core_fp16)) && !jbgp.weights_decompression ;
154
- if (is_not_vnni_tag || (jbgp.weights_decompression && jbgp.orig_wei_dt == u8 && jbgp. wei_dt != bf16 )) {
154
+ if (is_not_vnni_tag || (jbgp.weights_decompression && jbgp.orig_wei_dt == u8)) {
155
155
if (is_superset (jbgp.isa , avx512_core))
156
156
return {{64 ,
157
157
pick (n_sp_dims, OI16i64o, OIw16i64o, OIhw16i64o,
@@ -176,7 +176,7 @@ jit_brgemm_ip_conf_t::get_desired_weights_tag() const {
176
176
pick (n_sp_dims, OI8i16o, OIw8i16o, OIhw8i16o,
177
177
OIdhw8i16o)},
178
178
{8 , pick (n_sp_dims, OI8i8o, OIw8i8o, OIhw8i8o, OIdhw8i8o)}};
179
- } else if (is_xf16 || (jbgp. weights_decompression && jbgp. orig_wei_dt == u8 && jbgp. wei_dt == bf16) ) {
179
+ } else if (is_xf16) {
180
180
if (jbgp.is_amx ) {
181
181
return {{64 ,
182
182
pick (n_sp_dims, OI16i64o2i, OIw16i64o2i,
@@ -374,10 +374,11 @@ int jit_brgemm_ip_conf_t::get_adjusted_oc_block() const {
374
374
// time for weights reorder are key optimization points there.
375
375
const size_t wei_size = static_cast <size_t >(jbgp.ic * jbgp.oc ) * types::data_type_size (jbgp.wei_dt );
376
376
// Use oc block to be 32 if weight size >= 8MB on amx bf16 to optimized memory consumption.
377
- if (jbgp.is_amx && jbgp.wei_dt == bf16 && !jbgp.is_bf32 && wei_size >= 8 * (1 << 20 ))
377
+ if (jbgp.is_amx && jbgp.orig_wei_dt == bf16 && !jbgp.is_bf32 && wei_size >= 8 * (1 << 20 ))
378
378
return 32 ;
379
379
// Use oc block to be 64 if weight size >= 16MB on avx512 f32 to optimized memory consumption.
380
- if (is_f32_compute_avx512 && wei_size >= 16 * (1 << 20 ))
380
+ if ((is_f32_compute_avx512 || (jbgp.is_amx && jbgp.orig_wei_dt != bf16 && !jbgp.is_bf32 ))
381
+ && wei_size >= 16 * (1 << 20 ))
381
382
return 64 ;
382
383
// Use oc block to be 24 if weight size >= 16MB on avx2 f32 to optimized memory consumption.
383
384
if (is_f32_compute_avx2 && wei_size >= 16 * (1 << 20 ))
@@ -1339,8 +1340,7 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
1339
1340
1340
1341
jbgp.weights_decompression = one_of (jbgp.src_dt , f32, bf16) &&
1341
1342
one_of (jbgp.wei_dt , u8, nf4, s4, u4);
1342
- jbgp.wei_decomp_algo = jbgp.is_amx ? weights_decomp_kind_t ::prepack
1343
- : weights_decomp_kind_t ::immediate;
1343
+ jbgp.wei_decomp_algo = weights_decomp_kind_t ::immediate;
1344
1344
jbgp.orig_wei_dt = jbgp.wei_dt ;
1345
1345
jbgp.with_grouped_weights_decompression = false ;
1346
1346
if (jbgp.weights_decompression ) {
@@ -1366,6 +1366,10 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
1366
1366
}
1367
1367
}
1368
1368
1369
+ // Current AMX implementation cannot provide perfromance benefit for immediate algorithm over avx512 version
1370
+ if (jbgp.is_amx && jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t ::immediate)
1371
+ return status::unimplemented;
1372
+
1369
1373
jbgp.bia_dt = jbgp.with_bias
1370
1374
? pick_by_prop_kind (jbgp.prop_kind , ipd.bias_desc .data_type ,
1371
1375
data_type::undef, ipd.diff_bias_desc .data_type )
@@ -1382,7 +1386,8 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
1382
1386
everyone_is (bf16, jbgp.wei_dt , jbgp.dst_dt )
1383
1387
&& jbgp.src_dt == f32,
1384
1388
everyone_is (bf16, jbgp.src_dt , jbgp.dst_dt )
1385
- && jbgp.wei_dt == f32);
1389
+ && jbgp.wei_dt == f32)
1390
+ || (jbgp.weights_decompression && jbgp.src_dt == bf16 && one_of (jbgp.dst_dt , f32, bf16));
1386
1391
const bool is_f16 = everyone_is (f16, jbgp.src_dt , jbgp.wei_dt , jbgp.dst_dt )
1387
1392
|| pick_by_prop_kind (jbgp.prop_kind ,
1388
1393
everyone_is (f16, jbgp.src_dt , jbgp.wei_dt )
0 commit comments