Skip to content

Commit a14aff9

Browse files
committed
benchdnn: matmul: ref: decrease the number of calls to get_idx
1 parent 3643f2d commit a14aff9

File tree

1 file changed

+55
-28
lines changed

1 file changed

+55
-28
lines changed

tests/benchdnn/matmul/ref_matmul.cpp

+55-28
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
* limitations under the License.
1515
*******************************************************************************/
1616

17+
#include <algorithm>
18+
1719
#include "utils/parallel.hpp"
1820

1921
#include "matmul/matmul.hpp"
@@ -39,16 +41,26 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
3941
= args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
4042
const dnn_mem_t &dropout = args.find(DNNL_ARG_ATTR_DROPOUT_MASK);
4143

44+
const int64_t M = prb->m;
45+
const int64_t N = prb->n;
46+
const int64_t K = prb->k;
47+
const int64_t MB = prb->mb;
48+
const int batch_ndims = dst_m.ndims() - 2;
49+
4250
const bool has_src_scale = !prb->attr.scales.get(DNNL_ARG_SRC).is_def();
4351
const bool has_wei_scale = !prb->attr.scales.get(DNNL_ARG_WEIGHTS).is_def();
4452
const bool has_dst_scale = !prb->attr.scales.get(DNNL_ARG_DST).is_def();
53+
4554
const int src_scale_mask = prb->attr.scales.get_mask(
4655
DNNL_ARG_SRC, dnnl_matmul, src_m.ndims());
4756
const int wei_scale_mask = prb->attr.scales.get_mask(
4857
DNNL_ARG_WEIGHTS, dnnl_matmul, wei_m.ndims());
4958
const int dst_scale_mask = prb->attr.scales.get_mask(
5059
DNNL_ARG_DST, dnnl_matmul, dst_m.ndims());
5160

61+
const bool has_src_single_scale = has_src_scale && src_scale_mask == 0;
62+
const bool has_wei_single_scale = has_wei_scale && wei_scale_mask == 0;
63+
5264
const bool has_src_zp = !prb->attr.zero_points.get(DNNL_ARG_SRC).is_def();
5365
const bool has_wei_zp
5466
= !prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).is_def();
@@ -61,18 +73,21 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
6173
const int dst_zp_mask = attr_t::get_default_mask(
6274
prb->attr.zero_points.get(DNNL_ARG_DST).policy);
6375

76+
const bool has_src_single_zp = has_src_zp && src_zp_mask == 0;
77+
const bool has_wei_single_zp = has_wei_zp && wei_zp_mask == 0;
78+
6479
const auto &src_scale_groups = prb->attr.scales.get(DNNL_ARG_SRC).groups;
6580
const auto &wei_scale_groups
6681
= prb->attr.scales.get(DNNL_ARG_WEIGHTS).groups;
6782
const auto &src_zp_groups = prb->attr.zero_points.get(DNNL_ARG_SRC).groups;
6883
const auto &wei_zp_groups
6984
= prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).groups;
70-
71-
const int64_t M = prb->m;
72-
const int64_t N = prb->n;
73-
const int64_t K = prb->k;
74-
const int64_t MB = prb->mb;
75-
const int batch_ndims = dst_m.ndims() - 2;
85+
const auto smallest_k_group
86+
= std::min({src_scale_groups.empty() ? K : src_scale_groups[1],
87+
wei_scale_groups.empty() ? K : wei_scale_groups[0],
88+
src_zp_groups.empty() ? K : src_zp_groups[1],
89+
wei_zp_groups.empty() ? K : wei_zp_groups[0]});
90+
const auto n_k_groups = K / smallest_k_group;
7691

7792
// Fast return if any dim is zero. Common logic doesn't apply because of
7893
// broadcast semantics.
@@ -87,45 +102,57 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
87102

88103
benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
89104
float dst = 0;
90-
const int64_t src_mb
91-
= dst_m.get_idx(mb, src_broadcast_mask, batch_ndims);
92-
const int64_t wei_mb
93-
= dst_m.get_idx(mb, wei_broadcast_mask, batch_ndims);
105+
int64_t src_mb = 0;
106+
int64_t wei_mb = 0;
107+
if (MB > 1) {
108+
src_mb = dst_m.get_idx(mb, src_broadcast_mask, batch_ndims);
109+
wei_mb = dst_m.get_idx(mb, wei_broadcast_mask, batch_ndims);
110+
}
111+
112+
int src_zp = has_src_single_zp ? src_zps.get_elem(0) : 0;
113+
int wei_zp = has_wei_single_zp ? wei_zps.get_elem(0) : 0;
114+
float src_scale = has_src_single_scale ? src_scales.get_elem(0) : 1.f;
115+
float wei_scale = has_wei_single_scale ? wei_scales.get_elem(0) : 1.f;
94116

95-
for (int64_t k = 0; k < K; ++k) {
96-
const auto src_off = src_off_f(prb, src_mb, m, k);
97-
const auto wei_off = wei_off_f(prb, wei_mb, k, n);
117+
for (int64_t gK = 0; gK < n_k_groups; gK++) {
118+
const auto src_gK_off
119+
= src_off_f(prb, src_mb, m, gK * smallest_k_group);
120+
const auto wei_gK_off
121+
= wei_off_f(prb, wei_mb, gK * smallest_k_group, n);
98122

99-
int src_zp = 0;
100-
if (has_src_zp) {
123+
if (has_src_zp && !has_src_single_zp) {
101124
const auto src_zp_idx = src_m.get_idx(
102-
src_off, src_zp_mask, src_m.ndims(), src_zp_groups);
125+
src_gK_off, src_zp_mask, src_m.ndims(), src_zp_groups);
103126
src_zp = src_zps.get_elem(src_zp_idx);
104127
}
105-
int wei_zp = 0;
106-
if (has_wei_zp) {
128+
if (has_wei_zp && !has_wei_single_zp) {
107129
const auto wei_zp_idx = wei_m.get_idx(
108-
wei_off, wei_zp_mask, wei_m.ndims(), wei_zp_groups);
130+
wei_gK_off, wei_zp_mask, wei_m.ndims(), wei_zp_groups);
109131
wei_zp = wei_zps.get_elem(wei_zp_idx);
110132
}
111133

112-
float src_scale = 1.f;
113-
if (has_src_scale) {
114-
const auto src_scale_idx = src_m.get_idx(src_off,
134+
if (has_src_scale && !has_src_single_scale) {
135+
const auto src_scale_idx = src_m.get_idx(src_gK_off,
115136
src_scale_mask, src_m.ndims(), src_scale_groups);
116137
src_scale = src_scales.get_elem(src_scale_idx);
117138
}
118-
float wei_scale = 1.f;
119-
if (has_wei_scale) {
120-
const auto wei_scale_idx = wei_m.get_idx(wei_off,
139+
if (has_wei_scale && !has_wei_single_scale) {
140+
const auto wei_scale_idx = wei_m.get_idx(wei_gK_off,
121141
wei_scale_mask, wei_m.ndims(), wei_scale_groups);
122142
wei_scale = wei_scales.get_elem(wei_scale_idx);
123143
}
124144

125-
auto s = src_scale * (src_m.get_elem(src_off) - src_zp);
126-
auto w = wei_scale * (wei_m.get_elem(wei_off) - wei_zp);
145+
for (int64_t k = 0; k < smallest_k_group; ++k) {
146+
const auto src_off
147+
= src_off_f(prb, src_mb, m, gK * smallest_k_group + k);
148+
const auto wei_off
149+
= wei_off_f(prb, wei_mb, gK * smallest_k_group + k, n);
127150

128-
dst += s * w;
151+
auto s = src_scale * (src_m.get_elem(src_off) - src_zp);
152+
auto w = wei_scale * (wei_m.get_elem(wei_off) - wei_zp);
153+
154+
dst += s * w;
155+
}
129156
}
130157

131158
const auto dst_off = dst_off_f(prb, mb, m, n);

0 commit comments

Comments
 (0)