@@ -74,8 +74,8 @@ struct gen_gemm_t : public gpu_gemm_t {
74
74
wei_decomp_ = (utils::one_of (d->c_type (), f32, f16, bf16, f8_e5m2,
75
75
f8_e4m3)
76
76
&& utils::one_of (d->a_type (), u8, s8, s4, u4)
77
- && utils::one_of (d->b_type (), f16, f32, bf16 ,
78
- f8_e5m2, f8_e4m3))
77
+ && utils::one_of (d->b_type (), u8, s8, s4, u4 ,
78
+ f16, f32, bf16, f8_e5m2, f8_e4m3))
79
79
&& attr ()->mayiconvert (d->a_type (), f32);
80
80
dy_quant_enabled_
81
81
= (utils::one_of (d->c_type (), f32, f16, bf16)
@@ -224,6 +224,9 @@ struct gen_gemm_t : public gpu_gemm_t {
224
224
225
225
if (!attr ()->zero_points_ .has_default_values ()) {
226
226
if (!attr_zps.has_default_values (DNNL_ARG_A)) {
227
+ // Only apply to integers inputs.
228
+ VDISPATCH_GEMM (utils::one_of (d->a_type (), s4, u4, s8, u8),
229
+ VERBOSE_UNSUPPORTED_ZP_CFG);
227
230
const int cmask_a = attr_zps.get_mask (DNNL_ARG_A);
228
231
ao_dims_ = cmask_a > 0 ;
229
232
@@ -253,10 +256,17 @@ struct gen_gemm_t : public gpu_gemm_t {
253
256
VDISPATCH_GEMM (utils::one_of (cmask_a, 0 , mask_per_oc,
254
257
mask_per_ic),
255
258
VERBOSE_UNSUPPORTED_ZP_CFG);
259
+ // Weights zp can only be performantly enabled during upconversion.
260
+ VDISPATCH_GEMM (wei_decomp_
261
+ || utils::one_of (d->b_type (), s4, u4),
262
+ VERBOSE_UNSUPPORTED_ZP_CFG);
256
263
}
257
264
}
258
265
259
266
if (!attr_zps.has_default_values (DNNL_ARG_B)) {
267
+ // Only apply to integers inputs.
268
+ VDISPATCH_GEMM (utils::one_of (d->b_type (), s4, u4, s8, u8),
269
+ VERBOSE_UNSUPPORTED_ZP_CFG);
260
270
const int cmask_b = attr_zps.get_mask (DNNL_ARG_B);
261
271
bo_dims_ = cmask_b > 0 ;
262
272
@@ -390,6 +400,7 @@ struct gen_gemm_t : public gpu_gemm_t {
390
400
: data_type::s32;
391
401
if (swap_ab_) std::swap (ao_type, bo_type);
392
402
bool int_acc = utils::one_of (eff_a_type (), s8, u8);
403
+ int_acc &= !wei_scales_2d_;
393
404
auto co_type = with_bias () ? d->bias_type ()
394
405
: with_sum_ab () ? d->sum_ab_type
395
406
: int_acc ? s32
0 commit comments