Skip to content

Commit 6b99866

Browse files
author
dmitrygo
committed
[FORK][FEATURE] IP weights compression: mxfp4 (wei=f4e2m1, scales=f8e8m0) support
1 parent b1c677d commit 6b99866

29 files changed

+345
-83
lines changed

include/oneapi/dnnl/dnnl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
403403
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
404404
dnnl_primitive_attr_t attr, int arg, int mask);
405405
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims(
406-
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims);
406+
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type);
407407

408408
/// Sets primitive attributes scaling factors for primitive operations for a
409409
/// given memory argument. The scaling factors must be passed at execution time

include/oneapi/dnnl/dnnl.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,10 @@ struct memory : public handle<dnnl_memory_t> {
884884
bin = dnnl_bin,
885885
/// 4-bit normalized float.
886886
nf4 = dnnl_nf4,
887+
/// 8-bit floating-point with an 8-bit exponent and a 0-bit mantissa.
888+
f8_e8m0 = dnnl_f8_e8m0,
889+
/// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa.
890+
f4_e2m1 = dnnl_f4_e2m1
887891
};
888892

889893
/// Returns size of data type in bytes.
@@ -4142,8 +4146,8 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
41424146
error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
41434147
"could not set scales primitive attribute");
41444148
}
4145-
void set_scales_dims(int arg, const memory::dims& dims) {
4146-
error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size()),
4149+
void set_scales_dims(int arg, const memory::dims& dims, memory::data_type data_type = memory::data_type::f32) {
4150+
error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(data_type)),
41474151
"could not set scales primitive attribute");
41484152
}
41494153

include/oneapi/dnnl/dnnl_common_types.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,12 @@ typedef enum {
104104
dnnl_u4 = 12,
105105
/// 4-bit normalized float.
106106
dnnl_nf4 = 13,
107+
/// 8-bit floating-point with an 8-bit exponent and a 0-bit mantissa.
108+
dnnl_f8_e8m0 = 14,
109+
/// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa.
110+
dnnl_f4_e2m1 = 15,
107111
/// 1-bit integer.
108-
dnnl_bin = 14,
112+
dnnl_bin = 16,
109113

110114
/// Parameter to allow internal only data_types without undefined behavior.
111115
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.

src/common/c_types_map.hpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,13 @@ const data_type_t u4 = dnnl_u4;
169169
const data_type_t boolean = dnnl_boolean;
170170
const data_type_t data_type_max = dnnl_data_type_max;
171171

172-
// Not exposed through API as all current uses are internal only
173-
const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
174-
175172
const data_type_t bin = dnnl_bin;
176173
const data_type_t nf4 = dnnl_nf4;
174+
const data_type_t f4_e2m1 = dnnl_f4_e2m1;
175+
const data_type_t f8_e8m0 = dnnl_f8_e8m0;
176+
177+
// Not exposed through API as all current uses are internal only
178+
const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
177179
} // namespace data_type
178180

179181
using fpmath_mode_t = dnnl_fpmath_mode_t;

src/common/dnnl_debug_autogenerated.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
5858
if (v == dnnl_nf4) return "nf4";
5959
if (v == dnnl_s4) return "s4";
6060
if (v == dnnl_u4) return "u4";
61+
if (v == dnnl_f8_e8m0) return "f8_e8m0";
62+
if (v == dnnl_f4_e2m1) return "f4_e2m1";
6163
if (v == dnnl_data_type_max) return "data_type_max";
6264
assert(!"unknown dt");
6365
return "unknown dt";

src/common/dnnl_traits.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,16 @@ template <> struct prec_traits<data_type::nf4> {
100100
typedef uint8_t type;
101101
};
102102

103+
template <>
104+
struct prec_traits<data_type::f8_e8m0> {
105+
typedef uint8_t type;
106+
};
107+
108+
template <>
109+
struct prec_traits<data_type::f4_e2m1> {
110+
typedef uint8_t type;
111+
};
112+
103113
template <>
104114
struct data_traits<float8_e5m2_t> {
105115
static constexpr data_type_t data_type = data_type::f8_e5m2;

src/common/inner_product.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
111111
if (attr == nullptr) return status::success;
112112
const data_type_t src_dt = desc.src_desc.data_type;
113113
const data_type_t wei_dt = desc.weights_desc.data_type;
114-
bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) && one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4)) ||
114+
bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) &&
115+
one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) ||
115116
(one_of(src_dt, data_type::f32) && one_of(wei_dt, data_type::f16, data_type::bf16));
116117
auto attr_mask = smask_t::none;
117118
// From oneDNN 3.5, those checks must be skipped if wei_decomp is enabled
@@ -140,7 +141,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
140141
|| utils::one_of(dst_dt, data_type::s8, data_type::u8,
141142
data_type::s32);
142143
if (engine->kind() == engine_kind::cpu)
143-
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
144+
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1);
144145
if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime | smask_t::src_dyn_quant_params;
145146
if (is_weight_compression) {
146147
fwd_attr_mask |= attr_mask;

src/common/memory_zero_pad.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
294294
case u4: return typed_zero_pad<u8>(memory, ctx);
295295
case bin: return typed_zero_pad<u8>(memory, ctx);
296296
case nf4: return typed_zero_pad<u8>(memory, ctx);
297+
case f8_e8m0: return typed_zero_pad<u8>(memory, ctx);
298+
case f4_e2m1: return typed_zero_pad<u8>(memory, ctx);
297299
default: assert(!"memory is undefined"); return unimplemented;
298300
}
299301
return unimplemented;

src/common/primitive_attr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -656,11 +656,11 @@ status_t dnnl_primitive_attr_set_scales_mask(
656656
return attr->scales_.set(arg, mask);
657657
}
658658
status_t dnnl_primitive_attr_set_scales_dims(
659-
primitive_attr_t *attr, int arg, const dims_t dims, int ndims) {
659+
primitive_attr_t *attr, int arg, const dims_t dims, int ndims, data_type_t data_type) {
660660
bool ok = attr && arg >= 0 && ndims > 0
661661
&& attr->output_scales_.has_default_values();
662662
if (!ok) return invalid_arguments;
663-
return attr->scales_.set(arg, dims, ndims);
663+
return attr->scales_.set(arg, dims, ndims, data_type);
664664
}
665665

666666
status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg,

src/common/primitive_attr.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,12 @@ struct runtime_scales_t : public c_compatible {
265265
return status::success;
266266
}
267267

268-
status_t set(const dims_t dims, int ndims) {
268+
status_t set(const dims_t dims, int ndims, data_type_t data_type = data_type::f32) {
269269
is_set_ = true;
270270
ndims_ = ndims;
271271
mask_ = 1;
272272
utils::array_copy(dims_, dims, ndims_);
273+
data_type_ = data_type;
273274
return status::success;
274275
}
275276

@@ -348,9 +349,9 @@ struct arg_scales_t : public c_compatible {
348349
if (!check_arg(arg)) return status::invalid_arguments;
349350
return scales_[arg].set(mask);
350351
}
351-
status_t set(int arg, const dims_t dims, int ndims) {
352+
status_t set(int arg, const dims_t dims, int ndims, data_type_t data_type) {
352353
if (!check_arg(arg)) return status::invalid_arguments;
353-
return scales_[arg].set(dims, ndims);
354+
return scales_[arg].set(dims, ndims, data_type);
354355
}
355356

356357
status_t set(int arg, int mask, int ndims, const dims_t group_dims,

src/common/type_helpers.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ inline size_t data_type_size(data_type_t data_type) {
9898
case boolean: return sizeof(prec_traits<boolean>::type);
9999
case bin: return sizeof(prec_traits<u8>::type);
100100
case nf4: return sizeof(prec_traits<u8>::type);
101+
case f8_e8m0: return sizeof(prec_traits<f8_e8m0>::type);
102+
case f4_e2m1: return sizeof(prec_traits<f4_e2m1>::type);
101103
case data_type::undef:
102104
default: assert(!"unknown data_type");
103105
}
@@ -318,7 +320,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
318320

319321
/* prop_kind doesn't matter */
320322
if (everyone_is(f32, src_dt, wei_dt)) return f32;
321-
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4)) return f32;
323+
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) return f32;
322324
if (everyone_is(f64, src_dt, wei_dt)) return f64;
323325

324326
if (one_of(prop_kind, forward_training, forward_inference)) {
@@ -1086,7 +1088,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
10861088

10871089
bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
10881090
&& utils::one_of(data_type, f8_e5m2, f8_e4m3, f16, bf16, f32, f64,
1089-
s32, s8, u8, nf4, s4, u4, bin);
1091+
s32, s8, u8, nf4, s4, u4, bin, f8_e8m0, f4_e2m1);
10901092
if (!ok) return false;
10911093

10921094
bool has_runtime_dims = false;

src/cpu/cpu_inner_product_list.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
7171
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
7272
nullptr,
7373
}},
74+
{{forward, f32, f4_e2m1, f32}, {
75+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
76+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
77+
nullptr,
78+
}},
7479
{{forward, f32, s4, f32}, {
7580
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
7681
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
@@ -123,7 +128,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
123128
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
124129
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
125130
nullptr,
126-
}},
131+
}},
127132
{{forward, bf16, s8, bf16}, {
128133
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
129134
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
@@ -139,6 +144,16 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
139144
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
140145
nullptr,
141146
}},
147+
{{forward, bf16, f4_e2m1, f32}, {
148+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
149+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
150+
nullptr,
151+
}},
152+
{{forward, bf16, f4_e2m1, bf16}, {
153+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
154+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
155+
nullptr,
156+
}},
142157
{{forward, bf16, s4, f32}, {
143158
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
144159
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)

src/cpu/cpu_primitive.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
VCHECK_ATTR(scales != nullptr, \
7676
"Scales buffer for arg %d is missing", arg); \
7777
const auto scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | arg); \
78-
bool ok = scales_d.data_type() == data_type::f32 \
78+
bool ok = (scales_d.data_type() == data_type::f32 || scales_d.data_type() == data_type::f8_e8m0) \
7979
&& (scales_d.ndims() == 1 || scales_d.ndims() == 2); \
8080
if (!ok) return status::invalid_arguments; \
8181
if (scales_d.dims()[0] == 1) { \

src/cpu/reorder/cpu_reorder.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ regular_impl_list_map() {
3535
{{f32, u8, 0}, &regular_f32_u8_impl_list_map()},
3636
{{f8_e5m2, data_type::undef, 0}, &regular_fp8_impl_list_map()},
3737
{{f8_e4m3, data_type::undef, 0}, &regular_fp8_impl_list_map()},
38+
{{f8_e8m0, data_type::undef, 0}, &regular_fp8_impl_list_map()},
3839
{{f32, bin, 0}, &regular_f32_bin_impl_list_map()},
3940
{{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()},
4041
{{f16, data_type::undef, 0}, &regular_f16_impl_list_map()},
@@ -47,6 +48,7 @@ regular_impl_list_map() {
4748
{{u4, f32, 0}, &regular_u4_impl_list_map()},
4849
{{bin, data_type::undef, 0}, &regular_bin_impl_list_map()},
4950
{{nf4, data_type::undef, 0}, &regular_nf4_impl_list_map()},
51+
{{f4_e2m1, data_type::undef, 0}, &regular_f4_impl_list_map()},
5052
{{s4, data_type::undef, 0}, &regular_s4_impl_list_map()},
5153
{{u4, data_type::undef, 0}, &regular_u4_impl_list_map()},
5254
};

src/cpu/reorder/cpu_reorder.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ extern const impl_list_map_t &regular_s4_impl_list_map();
9191
extern const impl_list_map_t &regular_u4_impl_list_map();
9292
extern const impl_list_map_t &regular_bin_impl_list_map();
9393
extern const impl_list_map_t &regular_nf4_impl_list_map();
94+
extern const impl_list_map_t &regular_f4_impl_list_map();
9495
extern const impl_list_map_t &regular_s4_impl_list_map();
9596
extern const impl_list_map_t &regular_u4_impl_list_map();
9697

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*******************************************************************************
2+
* Copyright 2021 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "cpu/reorder/cpu_reorder.hpp"
18+
19+
namespace dnnl {
20+
namespace impl {
21+
namespace cpu {
22+
23+
// clang-format off
24+
25+
const impl_list_map_t &regular_f4_impl_list_map() {
26+
static const impl_list_map_t the_map = REG_REORDER_P({
27+
// f4_e2m1 ->
28+
{{f4_e2m1, data_type::undef, 0}, {
29+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i8o2i, fmt_order_keep)
30+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i16o2i, fmt_order_keep)
31+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i24o2i, fmt_order_keep)
32+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i32o2i, fmt_order_keep)
33+
REG_SR(f4_e2m1, any, f4_e2m1, OI8i64o2i, fmt_order_keep)
34+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i16o2i, fmt_order_keep)
35+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i32o2i, fmt_order_keep)
36+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i48o2i, fmt_order_keep)
37+
REG_SR(f4_e2m1, any, f4_e2m1, OI16i64o2i, fmt_order_keep)
38+
nullptr,
39+
}},
40+
});
41+
return the_map;
42+
}
43+
44+
// clang-format on
45+
46+
} // namespace cpu
47+
} // namespace impl
48+
} // namespace dnnl

src/cpu/reorder/cpu_reorder_regular_fp8.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ const impl_list_map_t &regular_fp8_impl_list_map() {
4646
REG_SR(f8_e4m3, any, bf16, any, fmt_order::any, spec::reference)
4747
REG_SR(f8_e4m3, any, f32, any, fmt_order::any, spec::reference)
4848

49+
nullptr,
50+
}},
51+
// f8_e8m0 ->
52+
{{f8_e8m0, data_type::undef, 0}, {
53+
REG_SR(f8_e8m0, any, f8_e8m0, any, fmt_order::any, spec::reference)
54+
4955
nullptr,
5056
}},
5157
});

src/cpu/reorder/simple_reorder.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,8 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
16971697
tag_traits<tag_o>::ndims >= 4
16981698
&& tag_traits<tag_o>::ndims <= 6)
16991699
&& (type_i != dnnl_bin && type_o != dnnl_bin)
1700-
&& (type_i != dnnl_nf4 && type_o != dnnl_nf4)>::type> {
1700+
&& (type_i != dnnl_nf4 && type_o != dnnl_nf4)
1701+
&& (type_i != dnnl_f4_e2m1 && type_o != dnnl_f4_e2m1)>::type> {
17011702
PLAIN_TO_BLOCKED_IS_APPLICABLE();
17021703

17031704
GET_SCRATCHPAD_SIZE_ZERO();
@@ -2004,7 +2005,7 @@ template <SIMPLE_REORDER_TEMPL_DECL>
20042005
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
20052006
typename utils::enable_if<tag_i == format_tag::any &&
20062007
tag_traits<tag_o>::block_dims == bd::_AB &&
2007-
utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4) &&
2008+
utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1) &&
20082009
type_i == type_o>::type>
20092010
{
20102011
static bool is_applicable(const memory_desc_wrapper &input_d,
@@ -2483,8 +2484,8 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
24832484
typename utils::enable_if<tag_i == format_tag::any
24842485
&& tag_o == format_tag::any
24852486
&& order_keep == fmt_order::any
2486-
&& !(utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4) ||
2487-
utils::one_of(type_o, dnnl_nf4, dnnl_s4, dnnl_u4)),
2487+
&& !(utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1) ||
2488+
utils::one_of(type_o, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1)),
24882489
spec::reference>::type> {
24892490
static bool is_applicable(const memory_desc_wrapper &input_d,
24902491
const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {

src/cpu/x64/brgemm/brgemm.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
285285
brg->with_wei_decomp_scales = !wei_scales.has_default_values();
286286
brg->wei_decomp_scales_group_size = wei_d.dims()[1];
287287
if (brg->with_wei_decomp_scales) {
288+
brg->wei_decomp_scales_dt = wei_scales.data_type_;
289+
if (!one_of(brg->wei_decomp_scales_dt, f32, f8_e8m0))
290+
return status::unimplemented;
291+
288292
auto ld_dim = wei_scales.dims_[0];
289293
brg->wei_decomp_scales_stride = ld_dim > 1 ? ld_dim : 0;
290294
brg->wei_decomp_scales_group_size = wei_d.dims()[1] / wei_scales.dims_[1];

src/cpu/x64/brgemm/brgemm_types.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ struct brgemm_desc_t {
315315
int wei_decomp_zero_points_stride = 0;
316316
int wei_decomp_scales_group_size = 0;
317317
int wei_decomp_zero_points_group_size = 0;
318+
impl::data_type_t wei_decomp_scales_dt = data_type::undef;
318319
impl::data_type_t wei_decomp_zero_points_dt = data_type::undef;
319320
bool with_src_dyn_quant = false;
320321
int src_scales_group_size = 0;

0 commit comments

Comments
 (0)