Skip to content

Commit f82148b

Browse files
author
dmitrygo
committed
[FORK][FEATURE] InnerProduct primitive: src dynamic quantization
1 parent 538fbf9 commit f82148b

27 files changed

+1080
-128
lines changed

include/oneapi/dnnl/dnnl.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims(
357357
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
358358
dnnl_primitive_attr_t attr, int arg, int mask);
359359
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_dims(
360-
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims);
360+
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type);
361361

362362
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_compensations(
363363
dnnl_primitive_attr_t attr, int count, int mask);
@@ -2402,6 +2402,9 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
24022402
const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
24032403
const float **scales);
24042404

2405+
dnnl_status_t DNNL_API dnnl_primitive_attr_set_src_dyn_quant_params(
2406+
dnnl_primitive_attr_t attr, uint64_t group_size);
2407+
24052408
/// @} dnnl_api_attributes
24062409

24072410
/// @addtogroup dnnl_api_rnn

include/oneapi/dnnl/dnnl.hpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -3924,9 +3924,9 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
39243924
dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask),
39253925
"could not set zero points primitive attribute");
39263926
}
3927-
void set_zero_points_dims(int arg, const memory::dims& dims) {
3927+
void set_zero_points_dims(int arg, const memory::dims& dims, memory::data_type dt) {
39283928
error::wrap_c_api(
3929-
dnnl_primitive_attr_set_zero_points_dims(get(), arg, dims.data(), dims.size()),
3929+
dnnl_primitive_attr_set_zero_points_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(dt)),
39303930
"could not set zero points primitive attribute");
39313931
}
39323932

@@ -4171,6 +4171,11 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
41714171
for (dnnl_dim_t c = 0; c < count; c++)
41724172
scales[c] = c_scales[c];
41734173
}
4174+
4175+
void set_src_dyn_quant_params(uint64_t group_size) {
4176+
error::wrap_c_api(dnnl_primitive_attr_set_src_dyn_quant_params(get(), group_size),
4177+
"could not set src dynamic quantization parameters primitive attribute");
4178+
}
41744179
};
41754180

41764181
/// @} dnnl_api_attributes

src/common/inner_product.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
127127
data_type::s32);
128128
if (engine->kind() == engine_kind::cpu)
129129
is_int8 |= one_of(wei_dt, data_type::u8, data_type::nf4, data_type::s4, data_type::u4);
130-
if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime;
130+
if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime | smask_t::src_dyn_quant_params;
131131

132132
VCHECK_IP_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
133133
VERBOSE_UNSUPPORTED_ATTR);

src/common/memory_tracking.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ enum {
289289
key_wino_M,
290290
key_decompression_scales,
291291
key_decompression_zero_points,
292+
key_src_quantized,
293+
key_src_dequantized_scales,
292294
// These two keys should always be the last ones,
293295
// even though they are not in alphabetical order
294296
key_nested,

src/common/primitive_attr.cpp

+15-3
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ status_t zero_points_t::set(int arg, int mask) {
120120
return status::success;
121121
}
122122

123-
status_t zero_points_t::set(int arg, const dims_t dims, int ndims) {
123+
status_t zero_points_t::set(int arg, const dims_t dims, int ndims, data_type_t data_type) {
124124
const bool supported_arg
125125
= utils::one_of(arg, DNNL_ARG_WEIGHTS);
126126
if (!supported_arg) return status::unimplemented;
@@ -131,6 +131,7 @@ status_t zero_points_t::set(int arg, const dims_t dims, int ndims) {
131131
ndims_wei = ndims;
132132
mask_wei = 1;
133133
utils::array_copy(dims_wei, dims, ndims);
134+
data_type_wei = data_type;
134135
break;
135136
}
136137
return status::success;
@@ -159,6 +160,8 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
159160
CHECK_MASK(smask_t::oscale_runtime, output_scales_);
160161
CHECK_MASK(smask_t::scales, scales_);
161162
CHECK_MASK(smask_t::zero_points, zero_points_);
163+
CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::zero_points),
164+
zero_points_.has_default_data_type()));
162165
CHECK_MASK(smask_t::input_zero_points, input_zero_points_);
163166
CHECK_MASK(smask_t::weights_zero_points, weights_zero_points_);
164167
CHECK_MASK(smask_t::output_compensations, output_compensations_);
@@ -167,6 +170,7 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
167170
CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_);
168171
CHECK_MASK(smask_t::rnn_weights_projection_qparams,
169172
rnn_weights_projection_qparams_);
173+
CHECK_MASK(smask_t::src_dyn_quant_params, src_dyn_quant_params_);
170174
CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::sum_dt),
171175
post_ops_.sum_with_default_dt(dst_dt)));
172176
bool gpu_attr_ok = IMPLICATION((bool)(~mask & smask_t::gpu_attr),
@@ -192,6 +196,7 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
192196
CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_);
193197
CHECK_MASK(smask_t::rnn_weights_projection_qparams,
194198
rnn_weights_projection_qparams_);
199+
CHECK_MASK(smask_t::src_dyn_quant_params, src_dyn_quant_params_);
195200
return ok;
196201
#undef CHECK_MASK
197202
#undef CHECK_ARG
@@ -581,11 +586,11 @@ status_t dnnl_primitive_attr_set_zero_points_mask(
581586
return attr->zero_points_.set(arg, mask);
582587
}
583588
status_t dnnl_primitive_attr_set_zero_points_dims(
584-
primitive_attr_t *attr, int arg, const dims_t dims, int ndims) {
589+
primitive_attr_t *attr, int arg, const dims_t dims, int ndims, dnnl_data_type_t data_type) {
585590
bool ok = attr && ndims > 0;
586591
if (!ok) return invalid_arguments;
587592

588-
return attr->zero_points_.set(arg, dims, ndims);
593+
return attr->zero_points_.set(arg, dims, ndims, data_type);
589594
}
590595

591596
status_t dnnl_primitive_attr_set_output_compensations(primitive_attr_t *attr,
@@ -887,6 +892,13 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams(
887892
return attr->rnn_tparams_.set(mode, ngates, scales, cscale);
888893
}
889894

895+
status_t dnnl_primitive_attr_set_src_dyn_quant_params(
896+
primitive_attr_t *attr, const uint64_t group_size) {
897+
if (attr == nullptr) return invalid_arguments;
898+
899+
return attr->src_dyn_quant_params_.set(group_size);
900+
}
901+
890902
template struct dnnl::impl::shifts_t<uint8_t>;
891903
template struct dnnl::impl::shifts_t<int32_t>;
892904
template struct dnnl::impl::shifts_t<float>;

src/common/primitive_attr.hpp

+44-4
Original file line numberDiff line numberDiff line change
@@ -371,26 +371,33 @@ struct zero_points_t : public c_compatible {
371371
return mask_src == rhs.mask_src && mask_wei == rhs.mask_wei
372372
&& mask_dst == rhs.mask_dst && is_set_src == rhs.is_set_src
373373
&& is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst
374-
&& IMPLICATION(ndims_wei > 0, ndims_wei == rhs.ndims_wei && utils::array_cmp(dims_wei, rhs.dims_wei, ndims_wei));
374+
&& IMPLICATION(ndims_wei > 0, ndims_wei == rhs.ndims_wei && utils::array_cmp(dims_wei, rhs.dims_wei, ndims_wei))
375+
&& data_type_wei == rhs.data_type_wei;
375376
}
376377

377378
// arg-specific checks
378379
bool common(int arg) const { return get_mask(arg) == 0; }
379380
bool defined(int arg) const { return has_default_values(arg); }
380-
bool has_default_values(int arg) const { return is_set(arg) == false; }
381+
bool has_default_values(int arg) const { return is_set(arg) == false && has_default_data_type(arg); }
382+
bool has_default_data_type(int arg) const {
383+
return get_data_type(arg) == data_type::s32;
384+
}
381385

382386
// same checks but for all supported arguments at once
383387
bool common() const { return check_all(&zero_points_t::common); }
384388
bool defined() const { return has_default_values(); }
385389
bool has_default_values() const {
386390
return check_all(&zero_points_t::has_default_values);
387391
}
392+
bool has_default_data_type() const {
393+
return check_all(&zero_points_t::has_default_data_type);
394+
}
388395

389396
status_t get(int arg, int *mask) const;
390397
int get(int arg) const; // Returns 0 if dimension is unset
391398

392399
status_t set(int arg, int mask);
393-
status_t set(int arg, const dims_t dims, int ndims);
400+
status_t set(int arg, const dims_t dims, int ndims, data_type_t data_type);
394401
status_t set(int arg) { return set(arg, 0); }
395402

396403
const dims_t & get_dims(int /*arg*/) const {
@@ -403,9 +410,15 @@ struct zero_points_t : public c_compatible {
403410
}
404411
}
405412

413+
data_type_t get_data_type(int arg) const {
414+
if (arg == DNNL_ARG_WEIGHTS) return data_type_wei;
415+
return data_type::s32;
416+
}
417+
406418
private:
407419
bool is_set_src = false, is_set_wei = false, is_set_dst = false;
408420
int mask_src = 0, mask_wei = 0, mask_dst = 0;
421+
data_type_t data_type_wei = data_type::s32;
409422

410423
int ndims_wei = 0;
411424
dnnl::impl::dims_t dims_wei;
@@ -470,6 +483,28 @@ struct legacy_zero_points_t : public c_compatible {
470483
int mask_ = 0;
471484
};
472485

486+
struct src_dyn_quant_params_t : public c_compatible {
487+
src_dyn_quant_params_t() : group_size_(0) {}
488+
bool has_default_values() const {
489+
return (group_size_ == 0);
490+
}
491+
bool defined() const {
492+
return true;
493+
}
494+
495+
status_t set(uint64_t group_size) {
496+
group_size_ = group_size;
497+
return status::success;
498+
}
499+
500+
bool operator==(const src_dyn_quant_params_t &rhs) const {
501+
using namespace utils;
502+
return group_size_ == rhs.group_size_;
503+
}
504+
505+
uint64_t group_size_;
506+
};
507+
473508
} // namespace impl
474509
} // namespace dnnl
475510

@@ -882,6 +917,7 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
882917
input_zero_points_ = (other.input_zero_points_);
883918
weights_zero_points_ = (other.weights_zero_points_);
884919
output_compensations_ = (other.output_compensations_);
920+
src_dyn_quant_params_ = other.src_dyn_quant_params_;
885921

886922
return status::success;
887923
}
@@ -906,6 +942,7 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
906942
input_zero_points = 1 << 13,
907943
weights_zero_points = 1 << 14,
908944
output_compensations = 1 << 15,
945+
src_dyn_quant_params = 1u << 16,
909946
};
910947

911948
/** Returns true if the attributes have default values.
@@ -933,7 +970,8 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
933970
|| (!gpu_attr_ && !rhs.gpu_attr_))
934971
&& input_zero_points_ == rhs.input_zero_points_
935972
&& weights_zero_points_ == rhs.weights_zero_points_
936-
&& output_compensations_ == rhs.output_compensations_;
973+
&& output_compensations_ == rhs.output_compensations_
974+
&& src_dyn_quant_params_ == rhs.src_dyn_quant_params_;
937975
return ret;
938976
}
939977

@@ -985,6 +1023,8 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
9851023
dnnl::impl::legacy_zero_points_t weights_zero_points_;
9861024
dnnl::impl::legacy_zero_points_t output_compensations_;
9871025

1026+
dnnl::impl::src_dyn_quant_params_t src_dyn_quant_params_;
1027+
9881028
dnnl_primitive_attr &operator=(const dnnl_primitive_attr &other) = delete;
9891029
};
9901030

src/common/primitive_hashing_utils.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ size_t get_attr_hash(const primitive_attr_t &attr) {
167167
if (attr.gpu_attr_) {
168168
seed = hash_combine(seed, attr.gpu_attr_->get_hash());
169169
}
170+
seed = hash_combine(seed, attr.src_dyn_quant_params_.group_size_);
170171
// Combined hash for attributes
171172
return seed;
172173
}

src/common/serialization.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ void serialize_attr(
213213
int zero = 0;
214214
sstream.write(&zero);
215215
}
216+
sstream.write(&attr.src_dyn_quant_params_.group_size_);
216217
}
217218

218219
void serialize_desc(

src/common/verbose.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,12 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) {
727727
const rnn_data_qparams_t &rnn_qp = attr->rnn_data_qparams_;
728728
if (!rnn_qp.has_default_values()) {
729729
ss << "rnn_data_qparams:" << rnn_qp.scale_ << ":" << rnn_qp.shift_
730-
<< ";";
730+
<< " ";
731+
}
732+
733+
const src_dyn_quant_params_t &dyn_qp = attr->src_dyn_quant_params_;
734+
if (!dyn_qp.has_default_values()) {
735+
ss << "src_dyn_quant_group_size:" << dyn_qp.group_size_ << ";";
731736
}
732737

733738
return ss;

src/cpu/cpu_inner_product_list.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
5353
nullptr,
5454
}},
5555
{{forward, f32, u8, f32}, {
56+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni)
5657
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
58+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2_vnni)
5759
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
5860
nullptr,
5961
}},
@@ -68,7 +70,9 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
6870
nullptr,
6971
}},
7072
{{forward, f32, u4, f32}, {
73+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni)
7174
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
75+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2_vnni)
7276
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
7377
nullptr,
7478
}},

src/cpu/reorder/cpu_reorder_regular_nf4.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ const impl_list_map_t &regular_nf4_impl_list_map() {
3535
REG_SR(nf4, any, nf4, OI16i32o2i, fmt_order_keep)
3636
REG_SR(nf4, any, nf4, OI16i48o2i, fmt_order_keep)
3737
REG_SR(nf4, any, nf4, OI16i64o2i, fmt_order_keep)
38+
REG_SR(nf4, any, f32, any, fmt_order_keep, spec::reference)
3839
nullptr,
3940
}},
4041
});

src/cpu/reorder/cpu_reorder_regular_s4.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ const impl_list_map_t &regular_s4_impl_list_map() {
3535
REG_SR(s4, any, s4, OI16i32o2i, fmt_order_keep)
3636
REG_SR(s4, any, s4, OI16i48o2i, fmt_order_keep)
3737
REG_SR(s4, any, s4, OI16i64o2i, fmt_order_keep)
38+
REG_SR(s4, any, u8, any, fmt_order_keep, spec::reference)
39+
REG_SR(s4, any, f32, any, fmt_order_keep, spec::reference)
3840
nullptr,
3941
}},
4042
});

src/cpu/reorder/cpu_reorder_regular_u4.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ const impl_list_map_t &regular_u4_impl_list_map() {
3535
REG_SR(u4, any, u4, OI16i32o2i, fmt_order_keep)
3636
REG_SR(u4, any, u4, OI16i48o2i, fmt_order_keep)
3737
REG_SR(u4, any, u4, OI16i64o2i, fmt_order_keep)
38+
REG_SR(u4, any, u4, OI16i16o4i, fmt_order_keep)
39+
REG_SR(u4, any, u4, OI16i32o4i, fmt_order_keep)
40+
REG_SR(u4, any, u4, OI16i48o4i, fmt_order_keep)
41+
REG_SR(u4, any, u4, OI16i64o4i, fmt_order_keep)
42+
REG_SR(u4, any, u8, any, fmt_order_keep, spec::reference)
43+
REG_SR(u4, any, f32, any, fmt_order_keep, spec::reference)
3844
nullptr,
3945
}},
4046
});

0 commit comments

Comments
 (0)