Skip to content

Commit 685bce3

Browse files
dmitry-gorokhov authored and azhai219 committed on Nov 27, 2024
[FORK][FEATURE] Asymmetric quantization for activations
1 parent 44bec63 commit 685bce3

30 files changed

+542
-211
lines changed
 

‎include/oneapi/dnnl/dnnl.h

+18
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,24 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_rounding(
515515
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rounding(
516516
dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t *mode);
517517

518+
/// Returns the output compensations stored in primitive attributes.
///
/// @param attr Primitive attributes to query.
/// @param count Output: number of compensation values.
/// @param mask Output: mask stored together with the compensations.
/// @param compensations Output: pointer to the stored int32 values. The
///     pointer refers to storage held by @p attr, not to a copy.
/// @returns An invalid-arguments status if any argument is NULL (per the
///     implementation in src/common/primitive_attr.cpp), success otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_compensations(
519+
const_dnnl_primitive_attr_t attr, int *count, int *mask, const int32_t **compensations);
520+
521+
/// Sets the output compensations in primitive attributes.
///
/// @param attr Primitive attributes to modify.
/// @param count Number of values in @p compensations; must be positive.
/// @param mask Mask to store with the values; must be non-negative.
/// @param compensations Array of @p count int32 values; must not be NULL.
/// @returns An invalid-arguments status if a requirement above is violated
///     (per the implementation in src/common/primitive_attr.cpp); otherwise
///     the status of storing the values.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_compensations(
522+
dnnl_primitive_attr_t attr, int count, int mask, const int32_t *compensations);
523+
524+
/// Returns the input zero points stored in primitive attributes.
///
/// @param attr Primitive attributes to query.
/// @param count Output: number of zero-point values.
/// @param mask Output: mask stored together with the zero points.
/// @param zero_points Output: pointer to the stored uint8 values. The
///     pointer refers to storage held by @p attr, not to a copy.
/// @returns An invalid-arguments status if any argument is NULL (per the
///     implementation in src/common/primitive_attr.cpp), success otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_input_zero_points(
525+
const_dnnl_primitive_attr_t attr, int *count, int *mask, const uint8_t **zero_points);
526+
527+
/// Sets the input zero points in primitive attributes.
///
/// @param attr Primitive attributes to modify.
/// @param count Number of values in @p zero_points; must be positive.
/// @param mask Mask to store with the values; must be non-negative.
/// @param zero_points Array of @p count uint8 values; must not be NULL.
/// @returns An invalid-arguments status if a requirement above is violated
///     (per the implementation in src/common/primitive_attr.cpp); otherwise
///     the status of storing the values.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_input_zero_points(
528+
dnnl_primitive_attr_t attr, int count, int mask, const uint8_t *zero_points);
529+
530+
/// Returns the weights zero points stored in primitive attributes.
///
/// @param attr Primitive attributes to query.
/// @param count Output: number of zero-point values.
/// @param mask Output: mask stored together with the zero points.
/// @param zero_points Output: pointer to the stored float values. The
///     pointer refers to storage held by @p attr, not to a copy.
/// @returns An invalid-arguments status if any argument is NULL (per the
///     implementation in src/common/primitive_attr.cpp), success otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_weights_zero_points(
531+
const_dnnl_primitive_attr_t attr, int *count, int *mask, const float **zero_points);
532+
533+
/// Sets the weights zero points in primitive attributes.
///
/// @param attr Primitive attributes to modify.
/// @param count Number of values in @p zero_points; must be positive.
/// @param mask Mask to store with the values; must be non-negative.
/// @param zero_points Array of @p count float values; must not be NULL.
/// @returns An invalid-arguments status if a requirement above is violated
///     (per the implementation in src/common/primitive_attr.cpp); otherwise
///     the status of storing the values.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_weights_zero_points(
534+
dnnl_primitive_attr_t attr, int count, int mask, const float *zero_points);
535+
518536
/// Returns primitive attributes post-ops.
519537
///
520538
/// @warning

‎include/oneapi/dnnl/dnnl.hpp

+63
Original file line numberDiff line numberDiff line change
@@ -4204,6 +4204,69 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
42044204
"could not set zero points primitive attribute");
42054205
}
42064206

4207+
void get_output_compensations(int &mask, std::vector<int32_t> &compensations) const
4208+
{
4209+
int count, c_mask;
4210+
const int32_t *c_compensations;
4211+
error::wrap_c_api(dnnl_primitive_attr_get_output_compensations(get(),
4212+
&count, &c_mask, &c_compensations),
4213+
"could not get int output compensations");
4214+
compensations.resize(count);
4215+
4216+
mask = c_mask;
4217+
for (int c = 0; c < count; ++c)
4218+
compensations[c] = c_compensations[c];
4219+
}
4220+
4221+
void set_output_compensations(int mask, const std::vector<int32_t> &compensations)
4222+
{
4223+
error::wrap_c_api(dnnl_primitive_attr_set_output_compensations(get(),
4224+
(int)compensations.size(), mask, &compensations[0]),
4225+
"could not set int output compensations");
4226+
}
4227+
4228+
void get_input_zero_points(int &mask, std::vector<uint8_t> &zero_points) const
4229+
{
4230+
int count, c_mask;
4231+
const uint8_t *c_zero_points;
4232+
error::wrap_c_api(dnnl_primitive_attr_get_input_zero_points(get(),
4233+
&count, &c_mask, &c_zero_points),
4234+
"could not get int input zero_points");
4235+
zero_points.resize(count);
4236+
4237+
mask = c_mask;
4238+
for (int c = 0; c < count; ++c)
4239+
zero_points[c] = c_zero_points[c];
4240+
}
4241+
4242+
void set_input_zero_points(int mask, const std::vector<uint8_t> &zero_points)
4243+
{
4244+
error::wrap_c_api(dnnl_primitive_attr_set_input_zero_points(get(),
4245+
(int)zero_points.size(), mask, &zero_points[0]),
4246+
"could not set int input zero_points");
4247+
}
4248+
4249+
void get_weights_zero_points(int &mask, std::vector<int8_t> &zero_points) const
4250+
{
4251+
int count, c_mask;
4252+
const float *c_zero_points;
4253+
error::wrap_c_api(dnnl_primitive_attr_get_weights_zero_points(get(),
4254+
&count, &c_mask, &c_zero_points),
4255+
"could not get int weights zero_points");
4256+
zero_points.resize(count);
4257+
4258+
mask = c_mask;
4259+
for (int c = 0; c < count; ++c)
4260+
zero_points[c] = c_zero_points[c];
4261+
}
4262+
4263+
void set_weights_zero_points(int mask, const std::vector<float> &zero_points)
4264+
{
4265+
error::wrap_c_api(dnnl_primitive_attr_set_weights_zero_points(get(),
4266+
(int)zero_points.size(), mask, &zero_points[0]),
4267+
"could not set int weights zero_points");
4268+
}
4269+
42074270
/// Returns post-ops previously set via set_post_ops().
42084271
///
42094272
/// @returns Post-ops.

‎src/common/primitive_attr.cpp

+66-8
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
177177
CHECK_ARG(
178178
IMPLICATION((bool)(~mask & smask_t::zero_points_runtime_data_type),
179179
zero_points_.has_default_data_type()));
180+
CHECK_MASK(smask_t::input_zero_points, input_zero_points_);
181+
CHECK_MASK(smask_t::weights_zero_points, weights_zero_points_);
182+
CHECK_MASK(smask_t::output_compensations, output_compensations_);
180183
CHECK_MASK(smask_t::post_ops, post_ops_);
181184
CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_);
182185
CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_);
@@ -205,14 +208,6 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
205208
#undef CHECK_ARG
206209
}
207210

208-
// Returns true when output scales and both RNN quantization parameter sets
// hold default values AND at least one of the input zero points or weights
// zero points holds a non-default value.
// NOTE(review): every line of this definition carries the diff "-" marker,
// i.e. this commit deletes the function.
bool primitive_attr_t::has_asymmetric_quantization() const {
209-
return true
210-
&& output_scales_.has_default_values()
211-
&& rnn_data_qparams_.has_default_values()
212-
&& rnn_weights_qparams_.has_default_values()
213-
&& (!input_zero_points_.has_default_values() || !weights_zero_points_.has_default_values());
214-
}
215-
216211
bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
217212
using smask_t = skip_mask_t;
218213
bool ok = true;
@@ -700,6 +695,69 @@ status_t dnnl_primitive_attr_set_rounding(
700695
return attr->rounding_mode_.set(arg, mode);
701696
}
702697

698+
// Reports the count, mask and value pointer of the output compensations
// stored in `attr`. The value pointer aliases the attribute's own storage.
// Returns invalid_arguments if any pointer argument is null.
status_t dnnl_primitive_attr_get_output_compensations(const primitive_attr_t *attr,
        int *count, int *mask, const int32_t **compensations) {
    const bool has_null_arg = any_null(attr, count, mask, compensations);
    if (has_null_arg) return invalid_arguments;

    const auto &stored = attr->output_compensations_;
    *count = stored.count_;
    *mask = stored.mask_;
    *compensations = stored.shifts_;

    return success;
}
709+
710+
// Stores `count` output compensations with the given mask in `attr`.
// Rejects null pointers, a non-positive count and a negative mask with
// invalid_arguments; otherwise returns the status of the store.
status_t dnnl_primitive_attr_set_output_compensations(primitive_attr_t *attr,
        int count, int mask, const int32_t *compensations) {
    if (any_null(attr, compensations) || count <= 0 || mask < 0)
        return invalid_arguments;

    return attr->output_compensations_.set(count, mask, compensations);
}
718+
719+
// Reports the count, mask and value pointer of the input zero points
// stored in `attr`. The value pointer aliases the attribute's own storage.
// Returns invalid_arguments if any pointer argument is null.
status_t dnnl_primitive_attr_get_input_zero_points(const primitive_attr_t *attr,
        int *count, int *mask, const uint8_t **zero_points) {
    const bool has_null_arg = any_null(attr, count, mask, zero_points);
    if (has_null_arg) return invalid_arguments;

    const auto &stored = attr->input_zero_points_;
    *count = stored.count_;
    *mask = stored.mask_;
    *zero_points = stored.shifts_;

    return success;
}
730+
731+
// Stores `count` input zero points with the given mask in `attr`.
// Rejects null pointers, a non-positive count and a negative mask with
// invalid_arguments; otherwise returns the status of the store.
status_t dnnl_primitive_attr_set_input_zero_points(primitive_attr_t *attr,
        int count, int mask, const uint8_t *zero_points) {
    if (any_null(attr, zero_points) || count <= 0 || mask < 0)
        return invalid_arguments;

    return attr->input_zero_points_.set(count, mask, zero_points);
}
739+
740+
// Reports the count, mask and value pointer of the weights zero points
// stored in `attr`. The value pointer aliases the attribute's own storage.
// Returns invalid_arguments if any pointer argument is null.
status_t dnnl_primitive_attr_get_weights_zero_points(const primitive_attr_t *attr,
        int *count, int *mask, const float **zero_points) {
    const bool has_null_arg = any_null(attr, count, mask, zero_points);
    if (has_null_arg) return invalid_arguments;

    const auto &stored = attr->weights_zero_points_;
    *count = stored.count_;
    *mask = stored.mask_;
    *zero_points = stored.shifts_;

    return success;
}
751+
752+
// Stores `count` weights zero points with the given mask in `attr`.
// Rejects null pointers, a non-positive count and a negative mask with
// invalid_arguments; otherwise returns the status of the store.
status_t dnnl_primitive_attr_set_weights_zero_points(primitive_attr_t *attr,
        int count, int mask, const float *zero_points) {
    if (any_null(attr, zero_points) || count <= 0 || mask < 0)
        return invalid_arguments;

    return attr->weights_zero_points_.set(count, mask, zero_points);
}
760+
703761
status_t dnnl_primitive_attr_get_post_ops(
704762
const primitive_attr_t *attr, const post_ops_t **post_ops) {
705763
if (any_null(attr, post_ops)) return invalid_arguments;

‎src/common/primitive_attr.hpp

+26-14
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,15 @@ struct shifts_t: public c_compatible {
190190
shifts_t(): count_(1), mask_(0), shifts_(shifts_buf_)
191191
{ set(0); }
192192

193-
shifts_t(const shifts_t &rhs): shifts_t()
194-
{ set(rhs.count_, rhs.mask_, rhs.shifts_); }
195-
196193
~shifts_t() { cleanup(); }
197194

198-
shifts_t &operator=(const shifts_t &rhs) {
199-
if (&rhs == this)
200-
return *this;
201-
status_t status = set(rhs.count_, rhs.mask_, rhs.shifts_);
202-
assert(status == status::success);
203-
(void)status;
204-
return *this;
195+
// Two shift sets compare equal when count and mask match, neither value
// pointer is null, both agree on defined(), and, if defined, the first
// count_ values match element-wise.
bool operator==(const shifts_t<T> &rhs) const {
    if (count_ != rhs.count_ || mask_ != rhs.mask_) return false;
    // Null check precedes defined(), which reads shifts_[0].
    if (utils::any_null(shifts_, rhs.shifts_)) return false;

    const bool lhs_defined = defined();
    if (lhs_defined != rhs.defined()) return false;

    // Values are only compared when they are defined.
    if (!lhs_defined) return true;
    return utils::array_cmp(shifts_, rhs.shifts_, count_);
}
206203

207204
bool has_default_values() const {
@@ -211,10 +208,16 @@ struct shifts_t: public c_compatible {
211208
return true;
212209
}
213210

211+
// Returns true when the first stored shift is not a runtime-value
// placeholder. Only shifts_[0] is inspected, regardless of count_.
bool defined() const { return !is_runtime_value(shifts_[0]); }
212+
214213
status_t set(int count, int mask, const T *zero_points);
215214
status_t set(T single_zero_point) { return this->set(1, 0, &single_zero_point); }
216215

217-
int count_;
216+
// Replaces this object's count, mask and values with those of `other`.
// Returns the status reported by set().
status_t copy_from(const shifts_t &other) {
    // Self-copy is a no-op: set() would otherwise receive a pointer into
    // this object's own buffer, and whether it tolerates aliased input is
    // not visible here.
    if (&other == this) return status::success;
    return set(other.count_, other.mask_, other.shifts_);
}
219+
220+
dim_t count_;
218221
int mask_;
219222
T *shifts_;
220223

@@ -230,6 +233,8 @@ struct shifts_t: public c_compatible {
230233
mask_ = 0;
231234
shifts_ = shifts_buf_;
232235
}
236+
237+
DNNL_DISALLOW_COPY_AND_ASSIGN(shifts_t);
233238
};
234239

235240
struct runtime_scales_t : public c_compatible {
@@ -948,6 +953,9 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
948953
CHECK(rnn_tparams_.copy_from(other.rnn_tparams_));
949954
if (other.gpu_attr_) gpu_attr_ = other.gpu_attr_->clone();
950955
dropout_ = other.dropout_;
956+
input_zero_points_.copy_from(other.input_zero_points_);
957+
weights_zero_points_.copy_from(other.weights_zero_points_);
958+
output_compensations_.copy_from(other.output_compensations_);
951959

952960
return status::success;
953961
}
@@ -978,6 +986,9 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
978986
= (unsigned)zero_points_runtime | (1u << 18),
979987
dropout = 1u << 19,
980988
rounding_mode = 1u << 20,
989+
input_zero_points = 1 << 21,
990+
weights_zero_points = 1 << 22,
991+
output_compensations = 1 << 23,
981992
};
982993

983994
/** Returns true if the attributes have default values.
@@ -986,8 +997,6 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
986997
bool has_default_values(skip_mask_t mask = skip_mask_t::none,
987998
dnnl::impl::data_type_t dst_dt = dnnl_data_type_undef) const;
988999

989-
bool has_asymmetric_quantization() const;
990-
9911000
/** Returns true if the attributes are fully defined. */
9921001
bool defined(skip_mask_t mask = skip_mask_t::none) const;
9931002

@@ -1007,7 +1016,10 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
10071016
&& gpu_attr_->is_equal(*rhs.gpu_attr_))
10081017
|| (!gpu_attr_ && !rhs.gpu_attr_))
10091018
&& dropout_ == rhs.dropout_
1010-
&& rounding_mode_ == rhs.rounding_mode_;
1019+
&& rounding_mode_ == rhs.rounding_mode_
1020+
&& input_zero_points_ == rhs.input_zero_points_
1021+
&& weights_zero_points_ == rhs.weights_zero_points_
1022+
&& output_compensations_ == rhs.output_compensations_;
10111023
return ret;
10121024
}
10131025

0 commit comments

Comments
 (0)