@@ -73,6 +73,29 @@ status_t scales_t::set(dim_t count, int mask, const float *scales) {
73
73
return status::success;
74
74
}
75
75
76
+
77
+ template <typename T>
78
+ status_t shifts_t <T>::set(int count, int mask, const T *shifts) {
79
+ cleanup ();
80
+
81
+ count_ = count;
82
+ mask_ = mask;
83
+
84
+ if (count_ == 1 ) {
85
+ shifts_ = shifts_buf_;
86
+ utils::array_set (shifts_, shifts[0 ], shifts_buf_size);
87
+ } else {
88
+ shifts_ = (T *)impl::malloc (count_ * sizeof (*shifts_), 64 );
89
+ if (shifts_ == nullptr )
90
+ return status::out_of_memory;
91
+
92
+ for (int c = 0 ; c < count_; ++c)
93
+ shifts_[c] = shifts[c];
94
+ }
95
+
96
+ return status::success;
97
+ }
98
+
76
99
status_t zero_points_t::get (int arg, int *mask, data_type_t *dt) const {
77
100
if (mask) *mask = get_mask (arg);
78
101
if (dt) *dt = get_data_type (arg);
@@ -182,6 +205,14 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
182
205
#undef CHECK_ARG
183
206
}
184
207
208
+ bool primitive_attr_t::has_asymmetric_quantization () const {
209
+ return true
210
+ && output_scales_.has_default_values ()
211
+ && rnn_data_qparams_.has_default_values ()
212
+ && rnn_weights_qparams_.has_default_values ()
213
+ && (!input_zero_points_.has_default_values () || !weights_zero_points_.has_default_values ());
214
+ }
215
+
185
216
bool primitive_attr_t::defined (dnnl_primitive_attr::skip_mask_t mask) const {
186
217
using smask_t = skip_mask_t ;
187
218
bool ok = true ;
@@ -313,6 +344,47 @@ status_t post_ops_t::append_prelu(int mask) {
313
344
return success;
314
345
}
315
346
347
+ status_t post_ops_t::append_depthwise (alg_kind_t alg, const float * weights_data, const float * biases_data) {
348
+ using namespace dnnl ::impl::alg_kind;
349
+ if (len () == post_ops_limit) return out_of_memory;
350
+ bool known_alg = one_of (alg, depthwise_scale_shift, depthwise_prelu);
351
+ if (!known_alg)
352
+ return invalid_arguments;
353
+
354
+ entry_.emplace_back ();
355
+ auto &e = entry_.back ();
356
+ e.kind = primitive_kind::depthwise;
357
+ e.depthwise .alg = alg;
358
+ e.depthwise .weights_data = weights_data;
359
+ e.depthwise .biases_data = biases_data;
360
+
361
+ return success;
362
+ }
363
+
364
+ status_t post_ops_t::append_quantization (alg_kind_t alg,
365
+ const void * crop_low, const void * crop_high,
366
+ const void * input_scale, const void * input_shift,
367
+ const void * output_scale, const void * output_shift) {
368
+ using namespace dnnl ::impl::alg_kind;
369
+ if (len () == post_ops_limit) return out_of_memory;
370
+ bool known_alg = one_of (alg, quantization_quantize_dequantize, quantization_quantize);
371
+ if (!known_alg)
372
+ return invalid_arguments;
373
+
374
+ entry_.emplace_back ();
375
+ auto &e = entry_.back ();
376
+ e.kind = primitive_kind::quantization;
377
+ e.quantization .alg = alg;
378
+ e.quantization .crop_low_data = reinterpret_cast <const shifts_t <float >*>(crop_low);
379
+ e.quantization .crop_high_data = reinterpret_cast <const shifts_t <float >*>(crop_high);
380
+ e.quantization .input_scale_data = reinterpret_cast <const scales_t *>(input_scale);
381
+ e.quantization .input_shift_data = reinterpret_cast <const shifts_t <float >*>(input_shift);
382
+ e.quantization .output_scale_data = reinterpret_cast <const scales_t *>(output_scale);
383
+ e.quantization .output_shift_data = reinterpret_cast <const shifts_t <float >*>(output_shift);
384
+
385
+ return success;
386
+ }
387
+
316
388
bool post_ops_t::defined () const {
317
389
for (int idx = 0 ; idx < len (); ++idx) {
318
390
auto kind = entry_[idx].kind ;
@@ -327,6 +399,10 @@ bool post_ops_t::defined() const {
327
399
primitive_kind::prelu,
328
400
primitive_kind::convolution)) {
329
401
// binary is always defined
402
+ } else if (kind == primitive_kind::depthwise) {
403
+ // depthwise is always defined
404
+ } else if (kind == primitive_kind::quantization) {
405
+ // quantization is always defined
330
406
} else {
331
407
assert (!" unreachable" );
332
408
}
@@ -787,6 +863,23 @@ status_t dnnl_post_ops_get_params_prelu(
787
863
return success;
788
864
}
789
865
866
+ status_t dnnl_post_ops_append_depthwise (dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
867
+ const float * weights_data, const float * biases_data) {
868
+ if (post_ops == nullptr ) return invalid_arguments;
869
+
870
+ return post_ops->append_depthwise (alg, weights_data, biases_data);
871
+ }
872
+
873
+ status_t dnnl_post_ops_append_quantization (post_ops_t *post_ops, alg_kind_t kind,
874
+ const void * crop_low, const void * crop_high,
875
+ const void * input_scale, const void * input_shift,
876
+ const void * output_scale, const void * output_shift) {
877
+ if (post_ops == nullptr )
878
+ return invalid_arguments;
879
+
880
+ return post_ops->append_quantization (kind, crop_low, crop_high, input_scale, input_shift, output_scale, output_shift);
881
+ }
882
+
790
883
status_t dnnl_primitive_attr_set_rnn_data_qparams (
791
884
primitive_attr_t *attr, const float scale, const float shift) {
792
885
if (attr == nullptr ) return invalid_arguments;
@@ -854,3 +947,7 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams(
854
947
855
948
return attr->rnn_tparams_ .set (mode, ngates, scales, cscale);
856
949
}
950
+
951
+ template struct dnnl ::impl::shifts_t <uint8_t >;
952
+ template struct dnnl ::impl::shifts_t <int32_t >;
953
+ template struct dnnl ::impl::shifts_t <float >;
0 commit comments