
Commit bd4f691

luweizhou2016 and Tingqian Li authored and committed
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
u8, i8, u4, i4, fp16 weights decompression, and dynamic quantization support on u8, i8, u4, i4. The oneDNN 3.5 migration has squashed the following changes into one commit:

[FORK][FEATURE] InnerProduct primitive: 8bit weights decompression support
[FORK][FEATURE] InnerProduct primitive: 8bit weights decompression support on AMX
[FORK][FEATURE] InnerProduct primitive: 4bit weights decompression support
[FORK][FEATURE] Enable prepack algorithm for 4bit weights decompression
[FORK][FEATURE] InnerProduct primitive: 4bit weights decompression support on SPR
[FORK][FEATURE] InnerProduct primitive: src dynamic quantization
[FORK][FIX] Fixed behavior for unaligned src and weights ic groups
[FORK][FEATURE] InnerProduct primitive: src dynamic quantization
[FORK][FEATURE] Support (f32, fp16, f32) inner-product
[FORK][FEATURE] Support s8 for weight-compression ip
[FORK][FIX] Squash FC weight compression fix when migrating to 3.5
[FORK][FIX] Fix ld_step
[FORK][FEATURE] Remove sub_byte_data_type_multiplier
[Fork][Fix] Fix offset for dynamic quantization
[FORK][FIX] Fix zp check and u4/s4 reorder
[FORK][FIX] Fix zp set API
[Fork][Fix] Skip runtime scale & zp check with weight compression
[Fork][Fix] Fix bug in brgemm introduced by (f32, fp16, f32) inner-product

Co-Authored-By: Tingqian Li <Tingqian.Li@intel.com>, Yi Zhang <Yi3.Zhang@intel.com>
1 parent db4e3d2 commit bd4f691

38 files changed: +2682 -207 lines changed

include/oneapi/dnnl/dnnl.h

+8
@@ -421,6 +421,8 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
 /// otherwise.
 dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
         dnnl_primitive_attr_t attr, int arg, int mask);
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims(
+        dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims);
 
 /// Sets primitive attributes scaling factors for primitive operations for a
 /// given memory argument. The scaling factors must be passed at execution time
@@ -468,6 +470,8 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
 /// otherwise.
 dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
         dnnl_primitive_attr_t attr, int arg, int mask);
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_dims(
+        dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type);
 
 /// Sets primitive attributes zero points for primitive operations for a given
 /// memory argument. The zero points must be passed at execution time
@@ -2703,6 +2707,10 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
         const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
         const float **scales);
 
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_src_dyn_quant_params(
+        dnnl_primitive_attr_t attr, uint64_t group_size);
+dnnl_status_t DNNL_API dnnl_primitive_attr_get_src_dyn_quant_params(
+        dnnl_primitive_attr_t attr, uint64_t* group_size);
 /// @} dnnl_api_attributes
 
 /// @addtogroup dnnl_api_rnn
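For orientation, a minimal sketch of how the new C entry points are intended to compose for a weight-compressed inner product. The function names come from the diff above; everything else (shapes, group size, and the {OC, IC/group} layout of the grouped dims) is an assumption for illustration, not taken from this commit.

    #include <stdint.h>
    #include <oneapi/dnnl/dnnl.h>

    /* Hypothetical helper: builds a primitive_attr for s8/u4 weight decompression
       with grouped scales/zero points plus src dynamic quantization. */
    static dnnl_primitive_attr_t make_wei_decompression_attr(void) {
        const int64_t OC = 4096, IC = 4096, GROUP = 128; /* illustrative shapes */
        dnnl_primitive_attr_t attr;
        dnnl_primitive_attr_create(&attr);

        /* Grouped decompression scales for the weights: one value per OC x IC-group. */
        dnnl_dims_t wei_group_dims = {OC, IC / GROUP};
        dnnl_primitive_attr_set_scales_dims(attr, DNNL_ARG_WEIGHTS, wei_group_dims, 2);

        /* u8 zero points with the same grouping. */
        dnnl_primitive_attr_set_zero_points_dims(
                attr, DNNL_ARG_WEIGHTS, wei_group_dims, 2, dnnl_u8);

        /* Quantize the f32 src on the fly in groups of GROUP values along IC. */
        dnnl_primitive_attr_set_src_dyn_quant_params(attr, (uint64_t)GROUP);
        return attr;
    }

At execution time the scale and zero-point buffers are still bound as runtime arguments (DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS), as with the existing mask-based API.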

include/oneapi/dnnl/dnnl.hpp

+22 -1
@@ -900,7 +900,9 @@ struct memory : public handle<dnnl_memory_t> {
         /// 4-bit unsigned integer.
         u4 = dnnl_u4,
         /// 1-bit integer
-        bin = dnnl_bin
+        bin = dnnl_bin,
+        /// 4-bit normalized float.
+        nf4 = dnnl_nf4,
     };
 
     /// Returns size of data type in bytes.
@@ -4237,6 +4239,10 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
         error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
                 "could not set scales primitive attribute");
     }
+    void set_scales_dims(int arg, const memory::dims& dims) {
+        error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size()),
+                "could not set scales primitive attribute");
+    }
 
     /// Sets scaling factors for primitive operations for a given memory
     /// argument. The scaling factors must be passed at execution time
@@ -4282,6 +4288,11 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
                 dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask),
                 "could not set zero points primitive attribute");
     }
+    void set_zero_points_dims(int arg, const memory::dims& dims, memory::data_type dt) {
+        error::wrap_c_api(
+                dnnl_primitive_attr_set_zero_points_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(dt)),
+                "could not set zero points primitive attribute");
+    }
 
     /// Sets zero points for primitive operations for a given memory argument.
     /// The zero points must be passed at execution time as an argument with
@@ -4550,6 +4561,16 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
         for (dnnl_dim_t c = 0; c < count; c++)
             scales[c] = c_scales[c];
     }
+
+    void set_src_dyn_quant_params(uint64_t group_size) {
+        error::wrap_c_api(dnnl_primitive_attr_set_src_dyn_quant_params(get(), group_size),
+                "could not set src dynamic quantization parameters primitive attribute");
+    }
+
+    void get_src_dyn_quant_params(uint64_t& group_size) const {
+        error::wrap_c_api(dnnl_primitive_attr_get_src_dyn_quant_params(get(), &group_size),
+                "could not get src dynamic quantization parameters primitive attribute");
+    }
 };
 
 /// @} dnnl_api_attributes
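A hedged end-to-end sketch of the new dnnl.hpp wrappers in use, assuming the stock oneDNN 3.x inner_product_forward::primitive_desc constructor. Shapes, format tags, the group size, and the {OC, IC/group} dims layout of the grouped scales/zero points are illustrative only.

    #include <cstdint>
    #include <oneapi/dnnl/dnnl.hpp>
    using namespace dnnl;

    // Illustrative sketch: f32 activations, s8 weights decompressed inside the
    // primitive, with src dynamically quantized in groups along IC.
    void build_compressed_ip(engine &eng) {
        const memory::dim MB = 32, IC = 4096, OC = 4096, GROUP = 128;

        memory::desc src_md({MB, IC}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc wei_md({OC, IC}, memory::data_type::s8, memory::format_tag::any);
        memory::desc dst_md({MB, OC}, memory::data_type::f32, memory::format_tag::ab);

        primitive_attr attr;
        // Grouped decompression scales and u8 zero points for the weights.
        attr.set_scales_dims(DNNL_ARG_WEIGHTS, {OC, IC / GROUP});
        attr.set_zero_points_dims(DNNL_ARG_WEIGHTS, {OC, IC / GROUP}, memory::data_type::u8);
        // Dynamically quantize src in groups of GROUP values along IC.
        attr.set_src_dyn_quant_params(static_cast<uint64_t>(GROUP));

        auto pd = inner_product_forward::primitive_desc(eng,
                prop_kind::forward_inference, src_md, wei_md, dst_md, attr);
        auto ip = inner_product_forward(pd);
        (void)ip; // executed later with src/weights/dst plus the runtime scale/zp buffers
    }

Whether a given src/weights data-type pair is accepted with these attributes is decided by ip_attr_check in src/common/inner_product.cpp, changed further down in this commit.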

include/oneapi/dnnl/dnnl_common_types.h

+4 -2
@@ -104,9 +104,11 @@ typedef enum {
     dnnl_u4 = 12,
     /// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.
     dnnl_e8m0 = 13,
-
+    /// 4-bit normalized float.
+    dnnl_nf4 = 14,
     /// 1-bit integer.
-    dnnl_bin = 14,
+    dnnl_bin = 15,
+
     /// Parameter to allow internal only data_types without undefined behavior.
     /// This parameter is chosen to be valid for so long as sizeof(int) >= 2.
     dnnl_data_type_max = 0x7fff,

src/common/c_types_map.hpp

+1
@@ -174,6 +174,7 @@ const data_type_t data_type_max = dnnl_data_type_max;
 const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
 
 const data_type_t bin = dnnl_bin;
+const data_type_t nf4 = dnnl_nf4;
 } // namespace data_type
 
 using fpmath_mode_t = dnnl_fpmath_mode_t;

src/common/dnnl_debug_autogenerated.cpp

+3
@@ -58,6 +58,9 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
     if (v == dnnl_u4) return "u4";
     if (v == dnnl_e8m0) return "e8m0";
     if (v == dnnl_bin) return "bin";
+    if (v == dnnl_nf4) return "nf4";
+    if (v == dnnl_s4) return "s4";
+    if (v == dnnl_u4) return "u4";
     if (v == dnnl_data_type_max) return "data_type_max";
     assert(!"unknown dt");
     return "unknown dt";

src/common/dnnl_traits.hpp

+4
@@ -100,6 +100,10 @@ template <> struct prec_traits<data_type::bin> {
     typedef uint8_t type;
 };
 
+template <> struct prec_traits<data_type::nf4> {
+    typedef uint8_t type;
+};
+
 template <>
 struct data_traits<float8_e5m2_t> {
     static constexpr data_type_t data_type = data_type::f8_e5m2;
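prec_traits<data_type::nf4> maps nf4 to uint8_t storage, i.e. two 4-bit elements per byte, the same convention as the existing s4/u4 handling. As a rough illustration of the arithmetic that weight decompression ultimately performs (the real brgemm/AMX kernels work on blocked layouts and vector registers, and nf4 dequantizes through a 16-entry lookup table rather than this affine formula), a scalar u4 dequantization might look like the sketch below; the nibble order is an assumption.

    #include <cstddef>
    #include <cstdint>

    // Illustrative only: decompress one u4 weight packed two-per-byte to f32
    // using a per-group scale and zero point.
    static inline float dequant_u4(const uint8_t *packed, std::size_t idx,
            float scale, uint8_t zero_point) {
        const uint8_t byte = packed[idx / 2];
        // Assumption: even index in the low nibble, odd index in the high nibble.
        const uint8_t q = (idx % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
        return (static_cast<float>(q) - static_cast<float>(zero_point)) * scale;
    }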

src/common/inner_product.cpp

+21 -3
@@ -109,13 +109,27 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
     using smask_t = primitive_attr_t::skip_mask_t;
 
     if (attr == nullptr) return status::success;
-    if (attr->has_default_values()) return status::success;
+    const data_type_t src_dt = desc.src_desc.data_type;
+    const data_type_t wei_dt = desc.weights_desc.data_type;
+    bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) && one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4)) ||
+            (one_of(src_dt, data_type::f32) && one_of(wei_dt, data_type::f16, data_type::bf16));
+    auto attr_mask = smask_t::none;
+    // From oneDNN 3.5, those checks must be skipped if wei_decomp is enabled
+    // reference from src/plugins/intel_cpu/thirdparty/onednn/src/common/matmul.cpp:L62
+    if (is_weight_compression) {
+        attr_mask |= smask_t::zero_points_runtime_data_type;
+        attr_mask |= smask_t::zero_points_runtime_groups;
+        attr_mask |= smask_t::scales_runtime_data_type;
+        attr_mask |= smask_t::scales_runtime_groups;
+    }
+    if (attr->has_default_values(attr_mask)) return status::success;
 
     // Check attributes
     if (utils::one_of(desc.prop_kind, prop_kind::forward_inference,
                 prop_kind::forward_training)) {
         const data_type_t src_dt = desc.src_desc.data_type;
         const data_type_t dst_dt = desc.dst_desc.data_type;
+        const data_type_t wei_dt = desc.weights_desc.data_type;
 
         auto fwd_attr_mask
                 = smask_t::post_ops | smask_t::sum_dt | smask_t::fpmath_mode;
@@ -125,8 +139,12 @@
         is_int8 = is_int8
                 || utils::one_of(dst_dt, data_type::s8, data_type::u8,
                         data_type::s32);
-        if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime;
-
+        if (engine->kind() == engine_kind::cpu)
+            is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
+        if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime | smask_t::src_dyn_quant_params;
+        if (is_weight_compression) {
+            fwd_attr_mask |= attr_mask;
+        }
         VCHECK_IP_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
                 VERBOSE_UNSUPPORTED_ATTR);
 

src/common/memory_desc_wrapper.hpp

+1 -2
@@ -298,8 +298,7 @@ struct memory_desc_wrapper : public c_compatible {
             max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
         }
 
-        size_t data_size = max_size * data_type_size()
-                / sub_byte_data_type_multiplier();
+        size_t data_size = max_size * data_type_size() / sub_byte_data_type_multiplier();
         if (is_additional_buffer()) {
             // The additional buffers, typically of data type int32_t, float
             // are stored at the end of data. Pad the data, so that the

src/common/memory_tracking.hpp

+4
@@ -316,6 +316,10 @@ enum {
     key_wino_V,
     key_wino_M,
     key_wino_workspace,
+    key_decompression_scales,
+    key_decompression_zero_points,
+    key_src_quantized,
+    key_src_dequantized_scales,
     // These two keys should always be the last ones,
     // even though they are not in alphabetical order
     key_nested,

src/common/memory_zero_pad.cpp

+1
@@ -294,6 +294,7 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
         case s4: return typed_zero_pad<s8>(memory, ctx);
         case u4: return typed_zero_pad<u8>(memory, ctx);
         case bin: return typed_zero_pad<u8>(memory, ctx);
+        case nf4: return typed_zero_pad<u8>(memory, ctx);
         default: assert(!"memory is undefined"); return unimplemented;
     }
     return unimplemented;

src/common/primitive_attr.cpp

+52
@@ -123,6 +123,8 @@ status_t zero_points_t::set(int arg, int mask, int ndims, const dims_t groups,
             data_type_wei = data_type;
             group_ndims_wei = ndims;
             utils::array_copy(group_dims_wei, groups, group_ndims_wei);
+            utils::array_copy(dims_wei, groups, ndims);
+            ndims_wei = ndims;
             break;
         case DNNL_ARG_DST:
             is_set_dst = true;
@@ -132,6 +134,23 @@
     return status::success;
 }
 
+status_t zero_points_t::set(int arg, const dims_t dims, int ndims, data_type_t data_type) {
+    const bool supported_arg
+            = utils::one_of(arg, DNNL_ARG_WEIGHTS);
+    if (!supported_arg) return status::unimplemented;
+
+    switch (arg) {
+        case DNNL_ARG_WEIGHTS:
+            is_set_wei = true;
+            ndims_wei = ndims;
+            mask_wei = 1;
+            utils::array_copy(dims_wei, dims, ndims);
+            data_type_wei = data_type;
+            break;
+    }
+    return status::success;
+}
+
 status_t dropout_t::set_default_formats(const memory_desc_t *dst_md) {
     auto is_any_or_undef = [](format_kind_t kind) {
         return one_of(kind, dnnl_format_kind_any, dnnl_format_kind_undef);
@@ -177,6 +196,8 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
     CHECK_ARG(
             IMPLICATION((bool)(~mask & smask_t::zero_points_runtime_data_type),
                     zero_points_.has_default_data_type()));
+    CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::zero_points),
+            zero_points_.has_default_data_type()));
     CHECK_MASK(smask_t::input_zero_points, input_zero_points_);
     CHECK_MASK(smask_t::weights_zero_points, weights_zero_points_);
     CHECK_MASK(smask_t::output_compensations, output_compensations_);
@@ -185,6 +206,7 @@
     CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_);
     CHECK_MASK(smask_t::rnn_weights_projection_qparams,
             rnn_weights_projection_qparams_);
+    CHECK_MASK(smask_t::src_dyn_quant_params, src_dyn_quant_params_);
     CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::sum_dt),
             post_ops_.sum_with_default_dt(dst_dt)));
     bool gpu_attr_ok = IMPLICATION((bool)(~mask & smask_t::gpu_attr),
@@ -222,6 +244,7 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
     CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_);
     CHECK_MASK(smask_t::rnn_weights_projection_qparams,
             rnn_weights_projection_qparams_);
+    CHECK_MASK(smask_t::src_dyn_quant_params, src_dyn_quant_params_);
     return ok;
 #undef CHECK_MASK
 #undef CHECK_ARG
@@ -673,6 +696,13 @@ status_t dnnl_primitive_attr_set_scales_mask(
     if (!ok) return invalid_arguments;
     return attr->scales_.set(arg, mask);
 }
+status_t dnnl_primitive_attr_set_scales_dims(
+        primitive_attr_t *attr, int arg, const dims_t dims, int ndims) {
+    bool ok = attr && arg >= 0 && ndims > 0
+            && attr->output_scales_.has_default_values();
+    if (!ok) return invalid_arguments;
+    return attr->scales_.set(arg, dims, ndims);
+}
 
 status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg,
         int mask, int ndims, const dims_t group_dims, data_type_t data_type) {
@@ -700,6 +730,13 @@ status_t dnnl_primitive_attr_set_zero_points_mask(
 
     return attr->zero_points_.set(arg, mask);
 }
+status_t dnnl_primitive_attr_set_zero_points_dims(
+        primitive_attr_t *attr, int arg, const dims_t dims, int ndims, dnnl_data_type_t data_type) {
+    bool ok = attr && ndims > 0;
+    if (!ok) return invalid_arguments;
+
+    return attr->zero_points_.set(arg, dims, ndims, data_type);
+}
 
 dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
         dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
@@ -1028,6 +1065,21 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams(
     return attr->rnn_tparams_.set(mode, ngates, scales, cscale);
 }
 
+status_t dnnl_primitive_attr_set_src_dyn_quant_params(
+        primitive_attr_t *attr, const uint64_t group_size) {
+    if (attr == nullptr) return invalid_arguments;
+
+    return attr->src_dyn_quant_params_.set(group_size);
+}
+
+status_t dnnl_primitive_attr_get_src_dyn_quant_params(
+        primitive_attr_t *attr, uint64_t* group_size) {
+    if (attr == nullptr) return invalid_arguments;
+
+    if (group_size) *group_size = attr->src_dyn_quant_params_.get();
+    return success;
+}
+
 template struct dnnl::impl::shifts_t<uint8_t>;
 template struct dnnl::impl::shifts_t<int32_t>;
 template struct dnnl::impl::shifts_t<float>;
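The src_dyn_quant_params attribute stores only a group size; the per-group scales are produced at execution time (the new key_src_quantized and key_src_dequantized_scales scratchpad keys above hold the quantized src and its scales). As a hedged scalar sketch of one common scheme, symmetric s8 with scale = max|x| / 127 per group; the fork's brgemm kernels may use a different quantization scheme.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Illustrative per-group dynamic quantization of an f32 src row.
    void dyn_quant_src(const float *src, std::size_t len, std::size_t group_size,
            std::vector<int8_t> &q, std::vector<float> &group_scales) {
        q.resize(len);
        group_scales.resize((len + group_size - 1) / group_size);
        for (std::size_t g = 0; g * group_size < len; ++g) {
            const std::size_t begin = g * group_size;
            const std::size_t end = std::min(len, begin + group_size);
            float amax = 0.f;
            for (std::size_t i = begin; i < end; ++i)
                amax = std::max(amax, std::fabs(src[i]));
            const float scale = amax / 127.f; // one scale per group of src values
            group_scales[g] = scale;
            const float inv = scale > 0.f ? 1.f / scale : 0.f;
            for (std::size_t i = begin; i < end; ++i)
                q[i] = static_cast<int8_t>(std::lrintf(src[i] * inv));
        }
    }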

0 commit comments
