Skip to content

Commit 74426dd

Browse files
committed
benchdnn: matmul: get encoding per kind
1 parent 7d93411 commit 74426dd

File tree

2 files changed

+27
-25
lines changed

2 files changed

+27
-25
lines changed

tests/benchdnn/dnn_types.hpp

+11
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,17 @@ struct sparse_options_t {
503503
if (options_.count(arg) == 0) return dnnl_sparse_encoding_undef;
504504
return options_.at(arg).first;
505505
}
506+
dnnl_sparse_encoding_t get_encoding(data_kind_t kind) const {
507+
// Note: the commented code doesn't work as `arg` returned is a
508+
// backward exec_arg. See the function comment.
509+
// const auto arg = data_kind2exec_arg(kind);
510+
// return get_encoding(arg);
511+
switch (kind) {
512+
case SRC: return get_encoding(DNNL_ARG_SRC);
513+
case WEI: return get_encoding(DNNL_ARG_WEIGHTS);
514+
default: return def_encoding;
515+
}
516+
}
506517

507518
float get_sparsity(int arg) const {
508519
if (options_.count(arg) == 0) return 0.0f;

tests/benchdnn/matmul/matmul.cpp

+16-25
Original file line numberDiff line numberDiff line change
@@ -398,25 +398,22 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
398398
const auto nelems = mem_dt.nelems();
399399
if (nelems == 0) return OK;
400400

401-
bool is_src_sparse_csr_coo = false;
402-
bool is_wei_sparse_csr_coo = false;
403-
bool is_wei_sparse_packed = false;
401+
bool is_sparse_packed = false;
402+
bool is_any_sparse = false;
403+
std::vector<bool> nnz_mask;
404404
#ifdef DNNL_EXPERIMENTAL_SPARSE
405-
auto src_encoding = prb->sparse_options.get_encoding(DNNL_ARG_SRC);
406-
auto wei_encoding = prb->sparse_options.get_encoding(DNNL_ARG_WEIGHTS);
407-
is_src_sparse_csr_coo = kind == SRC
408-
&& (src_encoding == dnnl_csr || src_encoding == dnnl_coo);
409-
is_wei_sparse_csr_coo = kind == WEI
410-
&& (wei_encoding == dnnl_csr || wei_encoding == dnnl_coo);
411-
412-
if (is_src_sparse_csr_coo || is_wei_sparse_csr_coo) {
413-
return fill_sparse_data(kind, prb, mem_dt, mem_fp, res,
414-
kind == SRC ? src_encoding : wei_encoding);
405+
const auto sparse_encoding = prb->sparse_options.get_encoding(kind);
406+
const bool is_sparse_csr_coo
407+
= sparse_encoding == dnnl_csr || sparse_encoding == dnnl_coo;
408+
is_sparse_packed = sparse_encoding == dnnl_packed;
409+
is_any_sparse = sparse_encoding != sparse_options_t::def_encoding;
410+
411+
if (is_sparse_csr_coo) {
412+
return fill_sparse_data(
413+
kind, prb, mem_dt, mem_fp, res, sparse_encoding);
415414
}
416415

417-
is_wei_sparse_packed = kind == WEI && wei_encoding == dnnl_packed;
418-
std::vector<bool> nnz_mask;
419-
if (is_wei_sparse_packed) {
416+
if (is_sparse_packed) {
420417
nnz_mask.resize(nelems, false);
421418
const dnnl_dim_t nnz = query_md_nnz(mem_dt.md_);
422419
assert(nnz > 0);
@@ -427,11 +424,9 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
427424
}
428425
#endif
429426

430-
const bool is_any_sparse = is_src_sparse_csr_coo || is_wei_sparse_csr_coo
431-
|| is_wei_sparse_packed;
432427
// Refer to modes documentation for filling principles.
433428
// Note: sparse filling is more complex than a general one in a sense that
434-
// it requires indices in addition to data. To have reasonable bitwise
429+
// it requires metadata in addition to data. To have reasonable bitwise
435430
// validation for sparse, only data must be random and indices should remain
436431
// identical between runs. So far, simply don't support bitwise mode for
437432
// sparse problems. `CSR`/`COO` will utilize their `fill_sparse_data`
@@ -467,7 +462,7 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
467462
std::bernoulli_distribution b_dist(density);
468463

469464
// make sure the first element is positive
470-
if (idx_start == 0 && !is_wei_sparse_packed) {
465+
if (idx_start == 0 && !is_sparse_packed) {
471466
float val = 0;
472467
while (val <= 0)
473468
val = gen(int_seed);
@@ -478,19 +473,15 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
478473

479474
for (int64_t idx = idx_start; idx < idx_end; ++idx) {
480475
bool is_one = density == 1.f ? true : b_dist(b_seed);
481-
#ifdef DNNL_EXPERIMENTAL_SPARSE
482476
float val = 0.0f;
483-
if (is_wei_sparse_packed && kind == WEI) {
477+
if (is_sparse_packed) {
484478
is_one = nnz_mask[idx];
485479
while (val == 0.0f)
486480
val = gen(int_seed);
487481
val *= is_one;
488482
} else {
489483
val = is_one * gen(int_seed);
490484
}
491-
#else
492-
float val = is_one * gen(int_seed);
493-
#endif
494485
mem_fp.set_elem(
495486
idx, round_to_nearest_representable(cfg.get_dt(kind), val));
496487
}

0 commit comments

Comments
 (0)