14
14
* limitations under the License.
15
15
*******************************************************************************/
16
16
17
+ #include < algorithm>
18
+
17
19
#include " utils/parallel.hpp"
18
20
19
21
#include " matmul/matmul.hpp"
@@ -39,16 +41,26 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
39
41
= args.find (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
40
42
const dnn_mem_t &dropout = args.find (DNNL_ARG_ATTR_DROPOUT_MASK);
41
43
44
+ const int64_t M = prb->m ;
45
+ const int64_t N = prb->n ;
46
+ const int64_t K = prb->k ;
47
+ const int64_t MB = prb->mb ;
48
+ const int batch_ndims = dst_m.ndims () - 2 ;
49
+
42
50
const bool has_src_scale = !prb->attr .scales .get (DNNL_ARG_SRC).is_def ();
43
51
const bool has_wei_scale = !prb->attr .scales .get (DNNL_ARG_WEIGHTS).is_def ();
44
52
const bool has_dst_scale = !prb->attr .scales .get (DNNL_ARG_DST).is_def ();
53
+
45
54
const int src_scale_mask = prb->attr .scales .get_mask (
46
55
DNNL_ARG_SRC, dnnl_matmul, src_m.ndims ());
47
56
const int wei_scale_mask = prb->attr .scales .get_mask (
48
57
DNNL_ARG_WEIGHTS, dnnl_matmul, wei_m.ndims ());
49
58
const int dst_scale_mask = prb->attr .scales .get_mask (
50
59
DNNL_ARG_DST, dnnl_matmul, dst_m.ndims ());
51
60
61
+ const bool has_src_single_scale = has_src_scale && src_scale_mask == 0 ;
62
+ const bool has_wei_single_scale = has_wei_scale && wei_scale_mask == 0 ;
63
+
52
64
const bool has_src_zp = !prb->attr .zero_points .get (DNNL_ARG_SRC).is_def ();
53
65
const bool has_wei_zp
54
66
= !prb->attr .zero_points .get (DNNL_ARG_WEIGHTS).is_def ();
@@ -61,18 +73,21 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
61
73
const int dst_zp_mask = attr_t::get_default_mask (
62
74
prb->attr .zero_points .get (DNNL_ARG_DST).policy );
63
75
76
+ const bool has_src_single_zp = has_src_zp && src_zp_mask == 0 ;
77
+ const bool has_wei_single_zp = has_wei_zp && wei_zp_mask == 0 ;
78
+
64
79
const auto &src_scale_groups = prb->attr .scales .get (DNNL_ARG_SRC).groups ;
65
80
const auto &wei_scale_groups
66
81
= prb->attr .scales .get (DNNL_ARG_WEIGHTS).groups ;
67
82
const auto &src_zp_groups = prb->attr .zero_points .get (DNNL_ARG_SRC).groups ;
68
83
const auto &wei_zp_groups
69
84
= prb->attr .zero_points .get (DNNL_ARG_WEIGHTS).groups ;
70
-
71
- const int64_t M = prb-> m ;
72
- const int64_t N = prb-> n ;
73
- const int64_t K = prb-> k ;
74
- const int64_t MB = prb-> mb ;
75
- const int batch_ndims = dst_m. ndims () - 2 ;
85
+ const auto smallest_k_group
86
+ = std::min ({src_scale_groups. empty () ? K : src_scale_groups[ 1 ],
87
+ wei_scale_groups. empty () ? K : wei_scale_groups[ 0 ],
88
+ src_zp_groups. empty () ? K : src_zp_groups[ 1 ],
89
+ wei_zp_groups. empty () ? K : wei_zp_groups[ 0 ]}) ;
90
+ const auto n_k_groups = K / smallest_k_group ;
76
91
77
92
// Fast return if any dim is zero. Common logic doesn't apply because of
78
93
// broadcast semantics.
@@ -87,45 +102,57 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
87
102
88
103
benchdnn_parallel_nd (MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
89
104
float dst = 0 ;
90
- const int64_t src_mb
91
- = dst_m.get_idx (mb, src_broadcast_mask, batch_ndims);
92
- const int64_t wei_mb
93
- = dst_m.get_idx (mb, wei_broadcast_mask, batch_ndims);
105
+ int64_t src_mb = 0 ;
106
+ int64_t wei_mb = 0 ;
107
+ if (MB > 1 ) {
108
+ src_mb = dst_m.get_idx (mb, src_broadcast_mask, batch_ndims);
109
+ wei_mb = dst_m.get_idx (mb, wei_broadcast_mask, batch_ndims);
110
+ }
111
+
112
+ int src_zp = has_src_single_zp ? src_zps.get_elem (0 ) : 0 ;
113
+ int wei_zp = has_wei_single_zp ? wei_zps.get_elem (0 ) : 0 ;
114
+ float src_scale = has_src_single_scale ? src_scales.get_elem (0 ) : 1 .f ;
115
+ float wei_scale = has_wei_single_scale ? wei_scales.get_elem (0 ) : 1 .f ;
94
116
95
- for (int64_t k = 0 ; k < K; ++k) {
96
- const auto src_off = src_off_f (prb, src_mb, m, k);
97
- const auto wei_off = wei_off_f (prb, wei_mb, k, n);
117
+ for (int64_t gK = 0 ; gK < n_k_groups; gK ++) {
118
+ const auto src_gK_off
119
+ = src_off_f (prb, src_mb, m, gK * smallest_k_group);
120
+ const auto wei_gK_off
121
+ = wei_off_f (prb, wei_mb, gK * smallest_k_group, n);
98
122
99
- int src_zp = 0 ;
100
- if (has_src_zp) {
123
+ if (has_src_zp && !has_src_single_zp) {
101
124
const auto src_zp_idx = src_m.get_idx (
102
- src_off , src_zp_mask, src_m.ndims (), src_zp_groups);
125
+ src_gK_off , src_zp_mask, src_m.ndims (), src_zp_groups);
103
126
src_zp = src_zps.get_elem (src_zp_idx);
104
127
}
105
- int wei_zp = 0 ;
106
- if (has_wei_zp) {
128
+ if (has_wei_zp && !has_wei_single_zp) {
107
129
const auto wei_zp_idx = wei_m.get_idx (
108
- wei_off , wei_zp_mask, wei_m.ndims (), wei_zp_groups);
130
+ wei_gK_off , wei_zp_mask, wei_m.ndims (), wei_zp_groups);
109
131
wei_zp = wei_zps.get_elem (wei_zp_idx);
110
132
}
111
133
112
- float src_scale = 1 .f ;
113
- if (has_src_scale) {
114
- const auto src_scale_idx = src_m.get_idx (src_off,
134
+ if (has_src_scale && !has_src_single_scale) {
135
+ const auto src_scale_idx = src_m.get_idx (src_gK_off,
115
136
src_scale_mask, src_m.ndims (), src_scale_groups);
116
137
src_scale = src_scales.get_elem (src_scale_idx);
117
138
}
118
- float wei_scale = 1 .f ;
119
- if (has_wei_scale) {
120
- const auto wei_scale_idx = wei_m.get_idx (wei_off,
139
+ if (has_wei_scale && !has_wei_single_scale) {
140
+ const auto wei_scale_idx = wei_m.get_idx (wei_gK_off,
121
141
wei_scale_mask, wei_m.ndims (), wei_scale_groups);
122
142
wei_scale = wei_scales.get_elem (wei_scale_idx);
123
143
}
124
144
125
- auto s = src_scale * (src_m.get_elem (src_off) - src_zp);
126
- auto w = wei_scale * (wei_m.get_elem (wei_off) - wei_zp);
145
+ for (int64_t k = 0 ; k < smallest_k_group; ++k) {
146
+ const auto src_off
147
+ = src_off_f (prb, src_mb, m, gK * smallest_k_group + k);
148
+ const auto wei_off
149
+ = wei_off_f (prb, wei_mb, gK * smallest_k_group + k, n);
127
150
128
- dst += s * w;
151
+ auto s = src_scale * (src_m.get_elem (src_off) - src_zp);
152
+ auto w = wei_scale * (wei_m.get_elem (wei_off) - wei_zp);
153
+
154
+ dst += s * w;
155
+ }
129
156
}
130
157
131
158
const auto dst_off = dst_off_f (prb, mb, m, n);
0 commit comments