Skip to content

Commit 2382350

Browse files
committed
benchdnn: matmul: ref: switch weights to ba
1 parent 4e52972 commit 2382350

File tree

5 files changed

+45
-11
lines changed

5 files changed

+45
-11
lines changed

tests/benchdnn/dnnl_memory.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,18 @@ dnn_mem_t::dnn_mem_t(const_dnnl_memory_desc_t md, dnnl_data_type_t dt,
6868
}
6969
}
7070

71+
dnn_mem_t::dnn_mem_t(const_dnnl_memory_desc_t md, dnnl_data_type_t dt,
72+
const dnnl_dims_t strides, dnnl_engine_t engine) {
73+
const int ndims = query_md_ndims(md);
74+
if (ndims > 0) {
75+
auto status = dnnl_memory_desc_create_with_strides(
76+
&md_, ndims, query_md_dims(md), dt, strides);
77+
(void)status;
78+
assert(status == dnnl_success);
79+
active_ = (initialize(engine) == OK);
80+
}
81+
}
82+
7183
dnn_mem_t::dnn_mem_t(int ndims, const dnnl_dims_t dims, dnnl_data_type_t dt,
7284
const std::string &tag, dnnl_engine_t engine) {
7385
if (ndims > 0) {

tests/benchdnn/dnnl_memory.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ struct dnn_mem_t {
4747
dnn_mem_t() { map(); }
4848
dnn_mem_t(const_dnnl_memory_desc_t md, dnnl_engine_t engine,
4949
const handle_info_t &handle_info = handle_info_t::allocate());
50+
5051
dnn_mem_t(const_dnnl_memory_desc_t md, dnnl_data_type_t dt,
5152
const std::string &tag, dnnl_engine_t engine);
53+
dnn_mem_t(const_dnnl_memory_desc_t md, dnnl_data_type_t dt,
54+
const dnnl_dims_t strides, dnnl_engine_t engine);
5255

5356
dnn_mem_t(int ndims, const dnnl_dims_t dims, dnnl_data_type_t dt,
5457
const std::string &tag, dnnl_engine_t engine);

tests/benchdnn/matmul/matmul.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -841,9 +841,24 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,
841841
} else
842842
#endif
843843
{
844-
// Scratchpad memory relates to a primitive. If reference needs it,
845-
// use switch below to define a memory desc for it.
846-
if (exec_arg != DNNL_ARG_SCRATCHPAD) {
844+
if (exec_arg == DNNL_ARG_WEIGHTS) {
845+
// Switch the format tag from "ab" to "ba" but to handle batched
846+
// cases, use strides instead.
847+
const auto ndims = mem.ndims();
848+
const auto &dims = mem.dims();
849+
dnnl_dims_t strides {};
850+
dnnl_dim_t stride = 1;
851+
for (int d = ndims - 2; d >= 0; d--) {
852+
strides[d] = stride * dims[d + 1];
853+
stride = strides[d];
854+
}
855+
strides[ndims - 2] = 1;
856+
strides[ndims - 1] = dims[ndims - 2];
857+
ref_mem_map.emplace(exec_arg,
858+
dnn_mem_t(mem.md_, dnnl_f32, strides, ref_engine));
859+
} else if (exec_arg != DNNL_ARG_SCRATCHPAD) {
860+
// Scratchpad memory relates to a primitive. If reference needs
861+
// it, use switch below to define a memory desc for it.
847862
ref_mem_map.emplace(exec_arg,
848863
dnn_mem_t(mem.md_, dnnl_f32, tag::abx, ref_engine));
849864
}

tests/benchdnn/matmul/matmul.hpp

-4
Original file line numberDiff line numberDiff line change
@@ -280,10 +280,6 @@ inline int64_t src_off_f(const prb_t *prb, int64_t mb, int64_t m, int64_t k) {
280280
return (mb * prb->m + m) * prb->k + k;
281281
}
282282

283-
inline int64_t wei_off_f(const prb_t *prb, int64_t mb, int64_t k, int64_t n) {
284-
return (mb * prb->k + k) * prb->n + n;
285-
}
286-
287283
inline int64_t dst_off_f(const prb_t *prb, int64_t mb, int64_t m, int64_t n) {
288284
return (mb * prb->m + m) * prb->n + n;
289285
}

tests/benchdnn/matmul/ref_matmul.cpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222

2323
namespace matmul {
2424

25+
int64_t wei_ab_off_f(const prb_t *prb, int64_t mb, int64_t k, int64_t n) {
26+
return (mb * prb->k + k) * prb->n + n;
27+
}
28+
int64_t wei_ba_off_f(const prb_t *prb, int64_t mb, int64_t k, int64_t n) {
29+
return (mb * prb->n + n) * prb->k + k;
30+
}
31+
2532
void compute_ref_matmul(const prb_t *prb, const args_t &args) {
2633
const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
2734
const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
@@ -130,8 +137,9 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
130137
for (int64_t gK = 0; gK < n_k_groups; gK++) {
131138
const auto src_gK_off
132139
= src_off_f(prb, src_mb, m, gK * smallest_k_group);
140+
// Note: scales/zero-points are still always in `tag::abx` format.
133141
const auto wei_gK_off
134-
= wei_off_f(prb, wei_mb, gK * smallest_k_group, n);
142+
= wei_ab_off_f(prb, wei_mb, gK * smallest_k_group, n);
135143

136144
if (has_src_zp && !has_src_single_zp) {
137145
const auto src_zp_idx = src_m.get_idx(
@@ -158,8 +166,8 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
158166
for (int64_t k = 0; k < smallest_k_group; ++k) {
159167
const auto src_off
160168
= src_off_f(prb, src_mb, m, gK * smallest_k_group + k);
161-
const auto wei_off
162-
= wei_off_f(prb, wei_mb, gK * smallest_k_group + k, n);
169+
const auto wei_off = wei_ba_off_f(
170+
prb, wei_mb, gK * smallest_k_group + k, n);
163171

164172
auto s = src_scale * (src_m.get_elem(src_off) - src_zp);
165173
auto w = wei_scale * (wei_m.get_elem(wei_off) - wei_zp);
@@ -292,7 +300,7 @@ void compute_ref_sparse_matmul(const prb_t *prb, const args_t &args) {
292300

293301
for (int64_t k = row_start; k < row_end; k++) {
294302
const int64_t wei_idx
295-
= wei_off_f(prb, mb, src_indices[k], n);
303+
= wei_ba_off_f(prb, mb, src_indices[k], n);
296304
const float src_val = src_m.get_elem(k, 0);
297305
const float wei_val = wei_m.get_elem(wei_idx);
298306
dst_val += src_val * wei_val;

0 commit comments

Comments
 (0)