@@ -242,8 +242,17 @@ struct runtime_scales_t : public c_compatible {
242
242
return status::success;
243
243
}
244
244
245
+ status_t set (const dims_t dims, int ndims) {
246
+ is_set_ = true ;
247
+ ndims_ = ndims;
248
+ mask_ = 1 ;
249
+ utils::array_copy (dims_, dims, ndims_);
250
+ return status::success;
251
+ }
252
+
245
253
bool operator ==(const runtime_scales_t &rhs) const {
246
- return mask_ == rhs.mask_ && is_set_ == rhs.is_set_ ;
254
+ return mask_ == rhs.mask_ && is_set_ == rhs.is_set_ &&
255
+ ndims_ == rhs.ndims_ && utils::array_cmp (dims_, rhs.dims_ , ndims_);
247
256
}
248
257
249
258
bool has_default_values () const { return !is_set_; }
@@ -259,6 +268,9 @@ struct runtime_scales_t : public c_compatible {
259
268
// Hide `mask_` under `private:` to force interface usage.
260
269
int mask_ = 0 ;
261
270
bool is_set_ = false ;
271
+
272
+ int ndims_ = 0 ;
273
+ dnnl::impl::dims_t dims_;
262
274
};
263
275
264
276
struct arg_scales_t : public c_compatible {
@@ -295,6 +307,10 @@ struct arg_scales_t : public c_compatible {
295
307
if (!check_arg (arg)) return status::invalid_arguments;
296
308
return scales_[arg].set (mask);
297
309
}
310
+ status_t set (int arg, const dims_t dims, int ndims) {
311
+ if (!check_arg (arg)) return status::invalid_arguments;
312
+ return scales_[arg].set (dims, ndims);
313
+ }
298
314
299
315
status_t get (int arg, int *mask, bool *is_set) const {
300
316
if (!check_arg (arg)) return status::invalid_arguments;
@@ -354,7 +370,8 @@ struct zero_points_t : public c_compatible {
354
370
bool operator ==(const zero_points_t &rhs) const {
355
371
return mask_src == rhs.mask_src && mask_wei == rhs.mask_wei
356
372
&& mask_dst == rhs.mask_dst && is_set_src == rhs.is_set_src
357
- && is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst ;
373
+ && is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst
374
+ && IMPLICATION (ndims_wei > 0 , ndims_wei == rhs.ndims_wei && utils::array_cmp (dims_wei, rhs.dims_wei , ndims_wei));
358
375
}
359
376
360
377
// arg-specific checks
@@ -373,12 +390,26 @@ struct zero_points_t : public c_compatible {
373
390
int get (int arg) const ; // Returns 0 if dimension is unset
374
391
375
392
status_t set (int arg, int mask);
393
+ status_t set (int arg, const dims_t dims, int ndims);
376
394
status_t set (int arg) { return set (arg, 0 ); }
377
395
396
+ const dims_t & get_dims (int /* arg*/ ) const {
397
+ return dims_wei;
398
+ }
399
+ int get_ndims (int arg) const {
400
+ switch (arg) {
401
+ case DNNL_ARG_WEIGHTS: return ndims_wei; break ;
402
+ default : return 0 ;
403
+ }
404
+ }
405
+
378
406
private:
379
407
bool is_set_src = false , is_set_wei = false , is_set_dst = false ;
380
408
int mask_src = 0 , mask_wei = 0 , mask_dst = 0 ;
381
409
410
+ int ndims_wei = 0 ;
411
+ dnnl::impl::dims_t dims_wei;
412
+
382
413
int get_mask (int arg) const {
383
414
int mask = 0 ;
384
415
switch (arg) {
0 commit comments