Skip to content

Commit 90483d1

Browse files
dmitry-gorokhov authored and azhai219 committed on Dec 9, 2024
[FORK][FEATURE] IP weights compression: mxfp4 (wei=f4e2m1, scales=f8e8m0) support
1 parent 3a0d17d commit 90483d1

29 files changed

+348
-83
lines changed
 

‎include/oneapi/dnnl/dnnl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
422422
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
423423
dnnl_primitive_attr_t attr, int arg, int mask);
424424
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims(
425-
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims);
425+
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type);
426426

427427
/// Sets primitive attributes scaling factors for primitive operations for a
428428
/// given memory argument. The scaling factors must be passed at execution time

‎include/oneapi/dnnl/dnnl.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,10 @@ struct memory : public handle<dnnl_memory_t> {
903903
bin = dnnl_bin,
904904
/// 4-bit normalized float.
905905
nf4 = dnnl_nf4,
906+
/// 8-bit floating-point with an 8-bit exponent and a 0-bit mantissa.
907+
f8_e8m0 = dnnl_f8_e8m0,
908+
/// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa.
909+
f4_e2m1 = dnnl_f4_e2m1
906910
};
907911

908912
/// Returns size of data type in bytes.
@@ -4239,8 +4243,8 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
42394243
error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
42404244
"could not set scales primitive attribute");
42414245
}
4242-
void set_scales_dims(int arg, const memory::dims& dims) {
4243-
error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size()),
4246+
void set_scales_dims(int arg, const memory::dims& dims, memory::data_type data_type = memory::data_type::f32) {
4247+
error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(data_type)),
42444248
"could not set scales primitive attribute");
42454249
}
42464250

‎include/oneapi/dnnl/dnnl_common_types.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,12 @@ typedef enum {
106106
dnnl_e8m0 = 13,
107107
/// 4-bit normalized float.
108108
dnnl_nf4 = 14,
109+
/// 8-bit floating-point with an 8-bit exponent and a 0-bit mantissa.
110+
dnnl_f8_e8m0 = 15,
111+
/// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa.
112+
dnnl_f4_e2m1 = 16,
109113
/// 1-bit integer.
110-
dnnl_bin = 15,
114+
dnnl_bin = 17,
111115

112116
/// Parameter to allow internal only data_types without undefined behavior.
113117
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.

‎src/common/c_types_map.hpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,13 @@ const data_type_t u4 = dnnl_u4;
170170
const data_type_t boolean = dnnl_boolean;
171171
const data_type_t data_type_max = dnnl_data_type_max;
172172

173-
// Not exposed through API as all current uses are internal only
174-
const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
175-
176173
const data_type_t bin = dnnl_bin;
177174
const data_type_t nf4 = dnnl_nf4;
175+
const data_type_t f4_e2m1 = dnnl_f4_e2m1;
176+
const data_type_t f8_e8m0 = dnnl_f8_e8m0;
177+
178+
// Not exposed through API as all current uses are internal only
179+
const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
178180
} // namespace data_type
179181

180182
using fpmath_mode_t = dnnl_fpmath_mode_t;

‎src/common/dnnl_debug_autogenerated.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
6161
if (v == dnnl_nf4) return "nf4";
6262
if (v == dnnl_s4) return "s4";
6363
if (v == dnnl_u4) return "u4";
64+
if (v == dnnl_f8_e8m0) return "f8_e8m0";
65+
if (v == dnnl_f4_e2m1) return "f4_e2m1";
6466
if (v == dnnl_data_type_max) return "data_type_max";
6567
assert(!"unknown dt");
6668
return "unknown dt";

‎src/common/dnnl_traits.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,16 @@ template <> struct prec_traits<data_type::nf4> {
104104
typedef uint8_t type;
105105
};
106106

107+
template <>
108+
struct prec_traits<data_type::f8_e8m0> {
109+
typedef uint8_t type;
110+
};
111+
112+
template <>
113+
struct prec_traits<data_type::f4_e2m1> {
114+
typedef uint8_t type;
115+
};
116+
107117
template <>
108118
struct data_traits<float8_e5m2_t> {
109119
static constexpr data_type_t data_type = data_type::f8_e5m2;

‎src/common/inner_product.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
111111
if (attr == nullptr) return status::success;
112112
const data_type_t src_dt = desc.src_desc.data_type;
113113
const data_type_t wei_dt = desc.weights_desc.data_type;
114-
bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) && one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4)) ||
114+
bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) &&
115+
one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) ||
115116
(one_of(src_dt, data_type::f32) && one_of(wei_dt, data_type::f16, data_type::bf16));
116117
auto attr_mask = smask_t::none;
117118
// From oneDNN 3.5, those checks must be skipped if wei_decomp is enabled
@@ -140,7 +141,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
140141
|| utils::one_of(dst_dt, data_type::s8, data_type::u8,
141142
data_type::s32);
142143
if (engine->kind() == engine_kind::cpu)
143-
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
144+
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1);
144145
if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime | smask_t::src_dyn_quant_params;
145146
if (is_weight_compression) {
146147
fwd_attr_mask |= attr_mask;

‎src/common/memory_zero_pad.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
295295
case u4: return typed_zero_pad<u8>(memory, ctx);
296296
case bin: return typed_zero_pad<u8>(memory, ctx);
297297
case nf4: return typed_zero_pad<u8>(memory, ctx);
298+
case f8_e8m0: return typed_zero_pad<u8>(memory, ctx);
299+
case f4_e2m1: return typed_zero_pad<u8>(memory, ctx);
298300
default: assert(!"memory is undefined"); return unimplemented;
299301
}
300302
return unimplemented;

‎src/common/primitive_attr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -697,11 +697,11 @@ status_t dnnl_primitive_attr_set_scales_mask(
697697
return attr->scales_.set(arg, mask);
698698
}
699699
status_t dnnl_primitive_attr_set_scales_dims(
700-
primitive_attr_t *attr, int arg, const dims_t dims, int ndims) {
700+
primitive_attr_t *attr, int arg, const dims_t dims, int ndims, data_type_t data_type) {
701701
bool ok = attr && arg >= 0 && ndims > 0
702702
&& attr->output_scales_.has_default_values();
703703
if (!ok) return invalid_arguments;
704-
return attr->scales_.set(arg, dims, ndims);
704+
return attr->scales_.set(arg, dims, ndims, data_type);
705705
}
706706

707707
status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg,

‎src/common/primitive_attr.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,12 @@ struct runtime_scales_t : public c_compatible {
264264
return status::success;
265265
}
266266

267-
status_t set(const dims_t dims, int ndims) {
267+
status_t set(const dims_t dims, int ndims, data_type_t data_type = data_type::f32) {
268268
is_set_ = true;
269269
ndims_ = ndims;
270270
mask_ = 1;
271271
utils::array_copy(dims_, dims, ndims_);
272+
data_type_ = data_type;
272273
return status::success;
273274
}
274275

@@ -346,9 +347,9 @@ struct arg_scales_t : public c_compatible {
346347
status_t set(int arg, int mask) {
347348
return set(arg, mask, 0, {}, data_type::f32);
348349
}
349-
status_t set(int arg, const dims_t dims, int ndims) {
350+
status_t set(int arg, const dims_t dims, int ndims, data_type_t data_type) {
350351
if (!check_arg(arg)) return status::invalid_arguments;
351-
return scales_[arg].set(dims, ndims);
352+
return scales_[arg].set(dims, ndims, data_type);
352353
}
353354

354355
status_t set(int arg, int mask, int ndims, const dims_t group_dims,

‎src/common/type_helpers.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ inline size_t data_type_size(data_type_t data_type) {
9898
case boolean: return sizeof(prec_traits<boolean>::type);
9999
case bin: return sizeof(prec_traits<u8>::type);
100100
case nf4: return sizeof(prec_traits<u8>::type);
101+
case f8_e8m0: return sizeof(prec_traits<f8_e8m0>::type);
102+
case f4_e2m1: return sizeof(prec_traits<f4_e2m1>::type);
101103
case data_type::undef:
102104
default: assert(!"unknown data_type");
103105
}
@@ -421,7 +423,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
421423

422424
/* prop_kind doesn't matter */
423425
if (everyone_is(f32, src_dt, wei_dt)) return f32;
424-
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4)) return f32;
426+
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) return f32;
425427
if (everyone_is(f64, src_dt, wei_dt)) return f64;
426428

427429
if (one_of(prop_kind, forward_training, forward_inference)) {
@@ -1263,7 +1265,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
12631265

12641266
bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
12651267
&& utils::one_of(data_type, f8_e5m2, f8_e4m3, f16, bf16, f32, f64,
1266-
s32, s8, u8, nf4, s4, u4, bin);
1268+
s32, s8, u8, nf4, s4, u4, bin, f8_e8m0, f4_e2m1);
12671269
if (!ok) return false;
12681270

12691271
bool has_runtime_dims = false;

‎src/cpu/cpu_inner_product_list.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
7171
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
7272
nullptr,
7373
}},
74+
{{forward, f32, f4_e2m1, f32}, {
75+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
76+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
77+
nullptr,
78+
}},
7479
{{forward, f32, s4, f32}, {
7580
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
7681
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
@@ -123,7 +128,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
123128
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
124129
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
125130
nullptr,
126-
}},
131+
}},
127132
{{forward, bf16, s8, bf16}, {
128133
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
129134
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
@@ -139,6 +144,16 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
139144
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
140145
nullptr,
141146
}},
147+
{{forward, bf16, f4_e2m1, f32}, {
148+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
149+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
150+
nullptr,
151+
}},
152+
{{forward, bf16, f4_e2m1, bf16}, {
153+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
154+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
155+
nullptr,
156+
}},
142157
{{forward, bf16, s4, f32}, {
143158
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
144159
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)

‎src/cpu/cpu_primitive.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
VCHECK_ATTR(scales != nullptr, \
7171
"Scales buffer for arg %d is missing", arg); \
7272
const auto scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | arg); \
73-
bool ok = scales_d.data_type() == data_type::f32 \
73+
bool ok = (scales_d.data_type() == data_type::f32 || scales_d.data_type() == data_type::f8_e8m0) \
7474
&& (scales_d.ndims() == 1 || scales_d.ndims() == 2); \
7575
if (!ok) return status::invalid_arguments; \
7676
if (scales_d.dims()[0] == 1) { \

‎src/cpu/reorder/cpu_reorder.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ regular_impl_list_map() {
3636
{{f32, u8, 0}, &regular_f32_u8_impl_list_map()},
3737
{{f8_e5m2, data_type::undef, 0}, &regular_fp8_impl_list_map()},
3838
{{f8_e4m3, data_type::undef, 0}, &regular_fp8_impl_list_map()},
39+
{{f8_e8m0, data_type::undef, 0}, &regular_fp8_impl_list_map()},
3940
{{f32, bin, 0}, &regular_f32_bin_impl_list_map()},
4041
{{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()},
4142
{{f16, data_type::undef, 0}, &regular_f16_impl_list_map()},
@@ -48,6 +49,7 @@ regular_impl_list_map() {
4849
{{u4, f32, 0}, &regular_u4_impl_list_map()},
4950
{{bin, data_type::undef, 0}, &regular_bin_impl_list_map()},
5051
{{nf4, data_type::undef, 0}, &regular_nf4_impl_list_map()},
52+
{{f4_e2m1, data_type::undef, 0}, &regular_f4_impl_list_map()},
5153
{{s4, data_type::undef, 0}, &regular_s4_impl_list_map()},
5254
{{u4, data_type::undef, 0}, &regular_u4_impl_list_map()},
5355
};

‎src/cpu/reorder/cpu_reorder.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ extern const impl_list_map_t &regular_s4_impl_list_map();
9292
extern const impl_list_map_t &regular_u4_impl_list_map();
9393
extern const impl_list_map_t &regular_bin_impl_list_map();
9494
extern const impl_list_map_t &regular_nf4_impl_list_map();
95+
extern const impl_list_map_t &regular_f4_impl_list_map();
9596
extern const impl_list_map_t &regular_s4_impl_list_map();
9697
extern const impl_list_map_t &regular_u4_impl_list_map();
9798

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*******************************************************************************
2+
* Copyright 2021 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "cpu/reorder/cpu_reorder.hpp"
18+
19+
namespace dnnl {
20+
namespace impl {
21+
namespace cpu {
22+
23+
// clang-format off
24+
25+
const impl_list_map_t &regular_f4_impl_list_map() {
26+
static const impl_list_map_t the_map = REG_REORDER_P({
27+
// f4_e2m1 ->
28+
{{f4_e2m1, data_type::undef, 0}, {
29+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i8o2i, fmt_order_keep)
30+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i16o2i, fmt_order_keep)
31+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i24o2i, fmt_order_keep)
32+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i32o2i, fmt_order_keep)
33+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i64o2i, fmt_order_keep)
34+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i16o2i, fmt_order_keep)
35+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i32o2i, fmt_order_keep)
36+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i48o2i, fmt_order_keep)
37+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i64o2i, fmt_order_keep)
38+
nullptr,
39+
}},
40+
});
41+
return the_map;
42+
}
43+
44+
// clang-format on
45+
46+
} // namespace cpu
47+
} // namespace impl
48+
} // namespace dnnl

‎src/cpu/reorder/cpu_reorder_regular_fp8.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ const impl_list_map_t &regular_fp8_impl_list_map() {
4646
REG_SR(f8_e4m3, any, bf16, any, fmt_order::any, spec::reference)
4747
REG_SR(f8_e4m3, any, f32, any, fmt_order::any, spec::reference)
4848

49+
nullptr,
50+
}},
51+
// f8_e8m0 ->
52+
{{f8_e8m0, data_type::undef, 0}, {
53+
REG_SR(f8_e8m0, any, f8_e8m0, any, fmt_order::any, spec::reference)
54+
4955
nullptr,
5056
}},
5157
});

‎src/cpu/reorder/simple_reorder.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,8 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
16971697
tag_traits<tag_o>::ndims >= 4
16981698
&& tag_traits<tag_o>::ndims <= 6)
16991699
&& (type_i != dnnl_bin && type_o != dnnl_bin)
1700-
&& (type_i != dnnl_nf4 && type_o != dnnl_nf4)>::type> {
1700+
&& (type_i != dnnl_nf4 && type_o != dnnl_nf4)
1701+
&& (type_i != dnnl_f4_e2m1 && type_o != dnnl_f4_e2m1)>::type> {
17011702
PLAIN_TO_BLOCKED_IS_APPLICABLE();
17021703

17031704
GET_SCRATCHPAD_SIZE_ZERO();
@@ -2004,7 +2005,7 @@ template <SIMPLE_REORDER_TEMPL_DECL>
20042005
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
20052006
typename utils::enable_if<tag_i == format_tag::any &&
20062007
tag_traits<tag_o>::block_dims == bd::_AB &&
2007-
utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4) &&
2008+
utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1) &&
20082009
type_i == type_o>::type>
20092010
{
20102011
static bool is_applicable(const memory_desc_wrapper &input_d,
@@ -2537,8 +2538,8 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
25372538
typename utils::enable_if<tag_i == format_tag::any
25382539
&& tag_o == format_tag::any
25392540
&& order_keep == fmt_order::any
2540-
&& !(utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4) ||
2541-
utils::one_of(type_o, dnnl_nf4, dnnl_s4, dnnl_u4)),
2541+
&& !(utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1) ||
2542+
utils::one_of(type_o, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1)),
25422543
spec::reference>::type> {
25432544
static bool is_applicable(const memory_desc_wrapper &input_d,
25442545
const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {

‎src/cpu/x64/brgemm/brgemm.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
285285
brg->with_wei_decomp_scales = !wei_scales.has_default_values();
286286
brg->wei_decomp_scales_group_size = wei_d.dims()[1];
287287
if (brg->with_wei_decomp_scales) {
288+
brg->wei_decomp_scales_dt = wei_scales.data_type_;
289+
if (!one_of(brg->wei_decomp_scales_dt, f32, f8_e8m0))
290+
return status::unimplemented;
291+
288292
auto ld_dim = wei_scales.dims_[0];
289293
brg->wei_decomp_scales_stride = ld_dim > 1 ? ld_dim : 0;
290294
brg->wei_decomp_scales_group_size = wei_d.dims()[1] / wei_scales.dims_[1];

‎src/cpu/x64/brgemm/brgemm_types.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ struct brgemm_desc_t {
316316
int wei_decomp_zero_points_stride = 0;
317317
int wei_decomp_scales_group_size = 0;
318318
int wei_decomp_zero_points_group_size = 0;
319+
impl::data_type_t wei_decomp_scales_dt = data_type::undef;
319320
impl::data_type_t wei_decomp_zero_points_dt = data_type::undef;
320321
bool with_src_dyn_quant = false;
321322
int src_scales_group_size = 0;

0 commit comments

Comments
 (0)