@@ -398,25 +398,22 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
398
398
const auto nelems = mem_dt.nelems ();
399
399
if (nelems == 0 ) return OK;
400
400
401
- bool is_src_sparse_csr_coo = false ;
402
- bool is_wei_sparse_csr_coo = false ;
403
- bool is_wei_sparse_packed = false ;
401
+ bool is_sparse_packed = false ;
402
+ bool is_any_sparse = false ;
403
+ std::vector< bool > nnz_mask ;
404
404
#ifdef DNNL_EXPERIMENTAL_SPARSE
405
- auto src_encoding = prb->sparse_options .get_encoding (DNNL_ARG_SRC);
406
- auto wei_encoding = prb->sparse_options .get_encoding (DNNL_ARG_WEIGHTS);
407
- is_src_sparse_csr_coo = kind == SRC
408
- && (src_encoding == dnnl_csr || src_encoding == dnnl_coo);
409
- is_wei_sparse_csr_coo = kind == WEI
410
- && (wei_encoding == dnnl_csr || wei_encoding == dnnl_coo);
411
-
412
- if (is_src_sparse_csr_coo || is_wei_sparse_csr_coo) {
413
- return fill_sparse_data (kind, prb, mem_dt, mem_fp, res,
414
- kind == SRC ? src_encoding : wei_encoding);
405
+ const auto sparse_encoding = prb->sparse_options .get_encoding (kind);
406
+ const bool is_sparse_csr_coo
407
+ = sparse_encoding == dnnl_csr || sparse_encoding == dnnl_coo;
408
+ is_sparse_packed = sparse_encoding == dnnl_packed;
409
+ is_any_sparse = sparse_encoding != sparse_options_t ::def_encoding;
410
+
411
+ if (is_sparse_csr_coo) {
412
+ return fill_sparse_data (
413
+ kind, prb, mem_dt, mem_fp, res, sparse_encoding);
415
414
}
416
415
417
- is_wei_sparse_packed = kind == WEI && wei_encoding == dnnl_packed;
418
- std::vector<bool > nnz_mask;
419
- if (is_wei_sparse_packed) {
416
+ if (is_sparse_packed) {
420
417
nnz_mask.resize (nelems, false );
421
418
const dnnl_dim_t nnz = query_md_nnz (mem_dt.md_ );
422
419
assert (nnz > 0 );
@@ -427,11 +424,9 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
427
424
}
428
425
#endif
429
426
430
- const bool is_any_sparse = is_src_sparse_csr_coo || is_wei_sparse_csr_coo
431
- || is_wei_sparse_packed;
432
427
// Refer to modes documentation for filling principles.
433
428
// Note: sparse filling is more complex than a general one in a sense that
434
- // it requires indices in addition to data. To have reasonable bitwise
429
+ // it requires metadata in addition to data. To have reasonable bitwise
435
430
// validation for sparse, only data must be random and indices should remain
436
431
// identical between runs. So far, simply don't support bitwise mode for
437
432
// sparse problems. `CSR`/`COO` will utilize their `fill_sparse_data`
@@ -467,7 +462,7 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
467
462
std::bernoulli_distribution b_dist (density);
468
463
469
464
// make sure the first element is positive
470
- if (idx_start == 0 && !is_wei_sparse_packed ) {
465
+ if (idx_start == 0 && !is_sparse_packed ) {
471
466
float val = 0 ;
472
467
while (val <= 0 )
473
468
val = gen (int_seed);
@@ -478,19 +473,15 @@ int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
478
473
479
474
for (int64_t idx = idx_start; idx < idx_end; ++idx) {
480
475
bool is_one = density == 1 .f ? true : b_dist (b_seed);
481
- #ifdef DNNL_EXPERIMENTAL_SPARSE
482
476
float val = 0 .0f ;
483
- if (is_wei_sparse_packed && kind == WEI ) {
477
+ if (is_sparse_packed ) {
484
478
is_one = nnz_mask[idx];
485
479
while (val == 0 .0f )
486
480
val = gen (int_seed);
487
481
val *= is_one;
488
482
} else {
489
483
val = is_one * gen (int_seed);
490
484
}
491
- #else
492
- float val = is_one * gen (int_seed);
493
- #endif
494
485
mem_fp.set_elem (
495
486
idx, round_to_nearest_representable (cfg.get_dt (kind), val));
496
487
}
0 commit comments