@@ -371,26 +371,33 @@ struct zero_points_t : public c_compatible {
371
371
return mask_src == rhs.mask_src && mask_wei == rhs.mask_wei
372
372
&& mask_dst == rhs.mask_dst && is_set_src == rhs.is_set_src
373
373
&& is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst
374
- && IMPLICATION (ndims_wei > 0 , ndims_wei == rhs.ndims_wei && utils::array_cmp (dims_wei, rhs.dims_wei , ndims_wei));
374
+ && IMPLICATION (ndims_wei > 0 , ndims_wei == rhs.ndims_wei && utils::array_cmp (dims_wei, rhs.dims_wei , ndims_wei))
375
+ && data_type_wei == rhs.data_type_wei ;
375
376
}
376
377
377
378
// arg-specific checks
378
379
// True when the mask reported by get_mask() for `arg` equals zero.
bool common(int arg) const {
    const auto arg_mask = get_mask(arg);
    return arg_mask == 0;
}
379
380
// Forwards to has_default_values(arg): an argument is reported as "defined"
// exactly when it still holds its default values.
bool defined(int arg) const {
    const bool still_default = has_default_values(arg);
    return still_default;
}
380
- bool has_default_values (int arg) const { return is_set (arg) == false ; }
381
+ bool has_default_values (int arg) const { return is_set (arg) == false && has_default_data_type (arg); }
382
+ bool has_default_data_type (int arg) const {
383
+ return get_data_type (arg) == data_type::s32;
384
+ }
381
385
382
386
// same checks but for all supported arguments at once
383
387
// All-arguments variant: runs the per-argument common(arg) check through
// check_all().
bool common() const {
    const bool all_common = check_all(&zero_points_t::common);
    return all_common;
}
384
388
// All-arguments variant: identical to has_default_values() with no argument.
bool defined() const {
    const bool all_default = has_default_values();
    return all_default;
}
385
389
bool has_default_values () const {
386
390
return check_all (&zero_points_t ::has_default_values);
387
391
}
392
+ bool has_default_data_type () const {
393
+ return check_all (&zero_points_t ::has_default_data_type);
394
+ }
388
395
389
396
status_t get (int arg, int *mask) const ;
390
397
int get (int arg) const ; // Returns 0 if dimension is unset
391
398
392
399
status_t set (int arg, int mask);
393
- status_t set (int arg, const dims_t dims, int ndims);
400
+ status_t set (int arg, const dims_t dims, int ndims, data_type_t data_type );
394
401
// Convenience overload: sets `arg` using a mask of 0 (the value common(arg)
// tests for) by delegating to set(arg, mask).
status_t set(int arg) {
    const int zero_mask = 0;
    return set(arg, zero_mask);
}
395
402
396
403
const dims_t & get_dims (int /* arg*/ ) const {
@@ -403,9 +410,15 @@ struct zero_points_t : public c_compatible {
403
410
}
404
411
}
405
412
413
+ data_type_t get_data_type (int arg) const {
414
+ if (arg == DNNL_ARG_WEIGHTS) return data_type_wei;
415
+ return data_type::s32;
416
+ }
417
+
406
418
private:
407
419
bool is_set_src = false , is_set_wei = false , is_set_dst = false ;
408
420
int mask_src = 0 , mask_wei = 0 , mask_dst = 0 ;
421
+ data_type_t data_type_wei = data_type::s32;
409
422
410
423
int ndims_wei = 0 ;
411
424
dnnl::impl::dims_t dims_wei;
@@ -470,6 +483,28 @@ struct legacy_zero_points_t : public c_compatible {
470
483
int mask_ = 0 ;
471
484
};
472
485
486
+ struct src_dyn_quant_params_t : public c_compatible {
487
+ src_dyn_quant_params_t () : group_size_(0 ) {}
488
+ bool has_default_values () const {
489
+ return (group_size_ == 0 );
490
+ }
491
+ bool defined () const {
492
+ return true ;
493
+ }
494
+
495
+ status_t set (uint64_t group_size) {
496
+ group_size_ = group_size;
497
+ return status::success;
498
+ }
499
+
500
+ bool operator ==(const src_dyn_quant_params_t &rhs) const {
501
+ using namespace utils ;
502
+ return group_size_ == rhs.group_size_ ;
503
+ }
504
+
505
+ uint64_t group_size_;
506
+ };
507
+
473
508
} // namespace impl
474
509
} // namespace dnnl
475
510
@@ -882,6 +917,7 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
882
917
input_zero_points_ = (other.input_zero_points_ );
883
918
weights_zero_points_ = (other.weights_zero_points_ );
884
919
output_compensations_ = (other.output_compensations_ );
920
+ src_dyn_quant_params_ = other.src_dyn_quant_params_ ;
885
921
886
922
return status::success;
887
923
}
@@ -906,6 +942,7 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
906
942
input_zero_points = 1 << 13 ,
907
943
weights_zero_points = 1 << 14 ,
908
944
output_compensations = 1 << 15 ,
945
+ src_dyn_quant_params = 1u << 16 ,
909
946
};
910
947
911
948
/* * Returns true if the attributes have default values.
@@ -933,7 +970,8 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
933
970
|| (!gpu_attr_ && !rhs.gpu_attr_ ))
934
971
&& input_zero_points_ == rhs.input_zero_points_
935
972
&& weights_zero_points_ == rhs.weights_zero_points_
936
- && output_compensations_ == rhs.output_compensations_ ;
973
+ && output_compensations_ == rhs.output_compensations_
974
+ && src_dyn_quant_params_ == rhs.src_dyn_quant_params_ ;
937
975
return ret;
938
976
}
939
977
@@ -985,6 +1023,8 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible {
985
1023
dnnl::impl::legacy_zero_points_t weights_zero_points_;
986
1024
dnnl::impl::legacy_zero_points_t output_compensations_;
987
1025
1026
+ dnnl::impl::src_dyn_quant_params_t src_dyn_quant_params_;
1027
+
988
1028
dnnl_primitive_attr &operator =(const dnnl_primitive_attr &other) = delete ;
989
1029
};
990
1030
0 commit comments