Skip to content

Commit 9233c5a

Browse files
committed
cpu: matmul: update reference impl for coo sparse matmul
1 parent 40dbd6d commit 9233c5a

File tree

3 files changed

+235
-52
lines changed

3 files changed

+235
-52
lines changed

src/common/memory_tracking.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ enum {
257257
key_matmul_wei_trans,
258258
key_matmul_dst_trans,
259259
key_matmul_dst_cast_acc,
260+
key_matmul_sparse_tmp_ptr,
260261
key_pool_dst_bf16cvt,
261262
key_pool_dst_plain2blocked_cvt,
262263
key_pool_ind_plain2blocked_cvt,

src/cpu/matmul/ref_sparse_matmul.cpp

+147-32
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2023 Intel Corporation
2+
* Copyright 2023-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@
1818
#include "common/math_utils.hpp"
1919
#include "common/type_helpers.hpp"
2020

21+
#include "cpu/ref_io_helper.hpp"
22+
2123
#include "cpu/matmul/ref_sparse_matmul.hpp"
2224

2325
namespace dnnl {
@@ -27,7 +29,7 @@ namespace matmul {
2729

2830
status_t ref_sparse_matmul_t::execute(const exec_ctx_t &ctx) const {
2931
status_t status = status::success;
30-
auto dst = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DST, status);
32+
auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
3133
CHECK(status);
3234

3335
const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
@@ -38,48 +40,161 @@ status_t ref_sparse_matmul_t::execute(const exec_ctx_t &ctx) const {
3840
const dim_t N = dst_d.dims()[1];
3941
const dim_t K = src_d.dims()[1];
4042

41-
parallel_nd(M, N, [&](dim_t i, dim_t j) { dst[i * N + j] = 0.0f; });
43+
const data_type_t mm_dt = src_d.data_type();
44+
auto scratchpad = ctx.get_scratchpad_grantor();
45+
46+
parallel_nd(M, N, [&](dim_t i, dim_t j) {
47+
const dim_t dst_idx = i * N + j;
48+
io::store_float_value(dst_d.data_type(), 0.0f, dst, dst_idx);
49+
});
4250

4351
if (weights_d.is_sparse_desc()) {
44-
const auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
45-
const auto wei_values = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS, 0);
46-
const auto wei_indices
47-
= CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 1);
48-
const auto wei_pointers
49-
= CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 2);
5052

53+
const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
54+
const auto wei_values = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS, 0);
55+
auto wei_buffer_1 = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 1);
56+
auto wei_buffer_2 = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 2);
57+
58+
// Both COO and CSR encoded data is operated on using CSR kernel for
59+
// matrix multiplication.
60+
// For COO encoding, data preparation includes using a temporary
61+
// buffer to convert the data to the CSR format.
62+
// Matrix multiplication is then carried out using the CSR encoded data.
63+
const int32_t *wei_indices;
64+
const int32_t *wei_pointers;
65+
66+
if (weights_d.encoding() == sparse_encoding::csr) {
67+
// For CSR encodings, pointer and indices assignment is
68+
// staightforward as,
69+
// index 1 - index buffer, index 2 - pointer buffer.
70+
wei_indices = wei_buffer_1;
71+
wei_pointers = wei_buffer_2;
72+
} else if (weights_d.encoding() == sparse_encoding::coo) {
73+
// For COO encodings, the two index buffers hold the row and column
74+
// indices respectively. For CSR conversion, the row indices are
75+
// compressed to generate the CSR pointers.
76+
wei_indices = wei_buffer_2;
77+
78+
int32_t *wei_row_pointers = scratchpad.template get<int32_t>(
79+
memory_tracking::names::key_matmul_sparse_tmp_ptr);
80+
81+
parallel_nd(K + 1, [&](dim_t k) {
82+
io::store_float_value(
83+
weights_d.metadata_type(0), 0, wei_row_pointers, k);
84+
});
85+
86+
cvt_coo_indices_to_csr_pointers(
87+
wei_buffer_1, wei_row_pointers, weights_d.nnz(), K);
88+
89+
wei_pointers = wei_row_pointers;
90+
}
91+
92+
run_csr_kernel(src, wei_values, wei_indices, wei_pointers, dst, M, N, K,
93+
mm_dt, src_d.is_sparse_desc());
94+
95+
} else if (src_d.is_sparse_desc()) {
96+
const auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
97+
const auto src_values = CTX_IN_MEM(const void *, DNNL_ARG_SRC, 0);
98+
auto src_buffer_1 = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 1);
99+
auto src_buffer_2 = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 2);
100+
101+
// Both COO and CSR encoded data is operated on using CSR kernel for
102+
// matrix multiplication.
103+
// For COO encoding, data preparation includes using a temporary
104+
// buffer to convert the data to the CSR format.
105+
// Matrix multiplication is then carried out using the CSR encoded data.
106+
const int32_t *src_indices;
107+
const int32_t *src_pointers;
108+
109+
if (src_d.encoding() == sparse_encoding::csr) {
110+
// For CSR encodings, pointer and indices assignment is
111+
// straightforward as
112+
// index 1 - index buffer, index 2 - pointer buffer.
113+
src_indices = src_buffer_1;
114+
src_pointers = src_buffer_2;
115+
} else if (src_d.encoding() == sparse_encoding::coo) {
116+
// For COO encodings, the two index buffers hold the row and column
117+
// indices respectively. For CSR conversion, the row indices are
118+
// compressed to generate the CSR pointers.
119+
src_indices = src_buffer_2;
120+
121+
int32_t *src_row_pointers = scratchpad.template get<int32_t>(
122+
memory_tracking::names::key_matmul_sparse_tmp_ptr);
123+
124+
parallel_nd(M + 1, [&](dim_t m) {
125+
io::store_float_value(
126+
src_d.metadata_type(0), 0, src_row_pointers, m);
127+
});
128+
129+
cvt_coo_indices_to_csr_pointers(
130+
src_buffer_1, src_row_pointers, src_d.nnz(), M);
131+
src_pointers = src_row_pointers;
132+
}
133+
134+
run_csr_kernel(weights, src_values, src_indices, src_pointers, dst, M,
135+
N, K, mm_dt, src_d.is_sparse_desc());
136+
}
137+
return status::success;
138+
}
139+
140+
void ref_sparse_matmul_t::cvt_coo_indices_to_csr_pointers(
141+
const int32_t *indices, int32_t *pointers, const int nnz,
142+
const int nrows) const {
143+
parallel_nd(
144+
nnz, [&](dim_t i) { fetch_and_add(&pointers[indices[i] + 1], 1); });
145+
for (int i = 0; i < nrows; ++i) {
146+
pointers[i + 1] += pointers[i];
147+
}
148+
}
149+
150+
void ref_sparse_matmul_t::run_csr_kernel(const void *dmat, const void *values,
151+
const int32_t *indices, const int32_t *pointers, void *res,
152+
const dim_t M, const dim_t N, const dim_t K, const data_type_t mm_dt,
153+
bool is_src_sparse) const {
154+
155+
if (is_src_sparse) {
156+
// With a sparse source tensor, the matrix multiplication is carried out
157+
// for a sparse multiplier with parallelization over the sparse rows
158+
// of the multiplier matrix.
51159
parallel_nd(M, [&](dim_t m) {
52-
for (dim_t k = 0; k < K; k++) {
53-
const dim_t row_start = wei_pointers[k];
54-
const dim_t row_end = wei_pointers[k + 1];
55-
for (dim_t n = row_start; n < row_end; n++) {
56-
const dim_t src_idx = m * K + k;
57-
const dim_t dst_idx = m * N + wei_indices[n];
58-
dst[dst_idx] = dst[dst_idx] + src[src_idx] * wei_values[n];
160+
const dim_t row_start = pointers[m];
161+
const dim_t row_end = pointers[m + 1];
162+
163+
for (dim_t n = 0; n < N; n++) {
164+
const dim_t c_idx = m * N + n;
165+
float c_val = io::load_float_value(mm_dt, res, c_idx);
166+
167+
for (dim_t k = row_start; k < row_end; k++) {
168+
const dim_t b_idx = indices[k] * N + n;
169+
const float a_val = io::load_float_value(mm_dt, values, k);
170+
const float b_val
171+
= io::load_float_value(mm_dt, dmat, b_idx);
172+
c_val += a_val * b_val;
59173
}
174+
io::store_float_value(mm_dt, c_val, res, c_idx);
60175
}
61176
});
62-
} else if (src_d.is_sparse_desc()) {
63-
const auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS);
64-
const auto src_values = CTX_IN_MEM(const float *, DNNL_ARG_SRC, 0);
65-
const auto src_indices = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 1);
66-
const auto src_pointers = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 2);
67-
177+
} else {
178+
// With a sparse weights tensor, the matrix multiplication is carried
179+
// out for a sparse multiplicand with parallelization over the dense
180+
// rows of the multiplier matrix.
68181
parallel_nd(M, [&](dim_t m) {
69-
const dim_t row_start = src_pointers[m];
70-
const dim_t row_end = src_pointers[m + 1];
71-
for (dim_t k = row_start; k < row_end; k++) {
72-
for (dim_t n = 0; n < N; n++) {
73-
const dim_t dst_idx = m * N + n;
74-
const dim_t wei_idx = src_indices[k] * N + n;
75-
dst[dst_idx]
76-
= dst[dst_idx] + src_values[k] * weights[wei_idx];
182+
for (dim_t k = 0; k < K; k++) {
183+
const dim_t row_start = pointers[k];
184+
const dim_t row_end = pointers[k + 1];
185+
for (dim_t n = row_start; n < row_end; n++) {
186+
const dim_t a_idx = m * K + k;
187+
const dim_t c_idx = m * N + indices[n];
188+
const float a_val
189+
= io::load_float_value(mm_dt, dmat, a_idx);
190+
const float b_val = io::load_float_value(mm_dt, values, n);
191+
float c_val = io::load_float_value(mm_dt, res, c_idx);
192+
c_val += a_val * b_val;
193+
io::store_float_value(mm_dt, c_val, res, c_idx);
77194
}
78195
}
79196
});
80197
}
81-
82-
return status::success;
83198
}
84199

85200
} // namespace matmul

src/cpu/matmul/ref_sparse_matmul.hpp

+87-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2023 Intel Corporation
2+
* Copyright 2023-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -44,25 +44,62 @@ struct ref_sparse_matmul_t : public primitive_t {
4444
memory_desc_wrapper src_d(src_md());
4545
memory_desc_wrapper wei_d(weights_md(0));
4646

47-
const bool ok
48-
= utils::everyone_is(f32, src_type, wei_type, dst_type)
49-
&& utils::one_of(true, wei_d.is_sparse_desc(),
50-
src_d.is_sparse_desc())
51-
&& IMPLICATION(wei_d.is_sparse_desc(),
52-
wei_d.encoding() == sparse_encoding::csr)
53-
&& IMPLICATION(src_d.is_sparse_desc(),
54-
src_d.encoding() == sparse_encoding::csr)
55-
&& IMPLICATION(
56-
wei_d.is_sparse_desc(), !src_d.is_sparse_desc())
57-
&& IMPLICATION(src_d.is_sparse_desc(),
58-
utils::everyone_is(s32, src_d.metadata_type(0),
59-
src_d.metadata_type(1)))
60-
&& IMPLICATION(wei_d.is_sparse_desc(),
61-
utils::everyone_is(s32, wei_d.metadata_type(0),
62-
wei_d.metadata_type(1)))
63-
&& !with_bias() && attr()->has_default_values()
64-
&& set_default_formats() && formats_ok(src_d, wei_d);
65-
return ok ? status::success : status::unimplemented;
47+
VDISPATCH_MATMUL(wei_d.is_sparse_desc() || src_d.is_sparse_desc(),
48+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
49+
VDISPATCH_MATMUL(wei_d.is_sparse_desc() ^ src_d.is_sparse_desc(),
50+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
51+
52+
VDISPATCH_MATMUL(IMPLICATION(src_d.is_sparse_desc(),
53+
utils::one_of(src_d.encoding(),
54+
sparse_encoding::csr,
55+
sparse_encoding::coo)),
56+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
57+
VDISPATCH_MATMUL(IMPLICATION(wei_d.is_sparse_desc(),
58+
utils::one_of(wei_d.encoding(),
59+
sparse_encoding::csr,
60+
sparse_encoding::coo)),
61+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
62+
63+
VDISPATCH_MATMUL(
64+
utils::everyone_is(f16, src_type, wei_type, dst_type)
65+
|| utils::everyone_is(
66+
f32, src_type, wei_type, dst_type),
67+
VERBOSE_UNSUPPORTED_DT_CFG);
68+
69+
if (src_d.is_sparse_desc()) {
70+
sparse_mem_encoding = src_d.encoding();
71+
VDISPATCH_MATMUL(
72+
IMPLICATION(sparse_mem_encoding == sparse_encoding::coo,
73+
s32 == src_d.metadata_type(0)),
74+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
75+
VDISPATCH_MATMUL(
76+
IMPLICATION(sparse_mem_encoding == sparse_encoding::csr,
77+
utils::everyone_is(s32, src_d.metadata_type(0),
78+
src_d.metadata_type(1))),
79+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
80+
}
81+
if (wei_d.is_sparse_desc()) {
82+
sparse_mem_encoding = wei_d.encoding();
83+
VDISPATCH_MATMUL(
84+
IMPLICATION(sparse_mem_encoding == sparse_encoding::coo,
85+
s32 == wei_d.metadata_type(0)),
86+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
87+
88+
VDISPATCH_MATMUL(
89+
IMPLICATION(sparse_mem_encoding == sparse_encoding::csr,
90+
utils::everyone_is(s32, wei_d.metadata_type(0),
91+
wei_d.metadata_type(1))),
92+
VERBOSE_UNSUPPORTED_SPARSE_CFG);
93+
}
94+
95+
VDISPATCH_MATMUL(!with_bias(), VERBOSE_UNSUPPORTED_BIAS_CFG);
96+
VDISPATCH_MATMUL(
97+
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
98+
VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_ATTR);
99+
VDISPATCH_MATMUL(formats_ok(src_d, wei_d), VERBOSE_UNSUPPORTED_TAG);
100+
101+
init_scratchpad();
102+
return status::success;
66103
}
67104

68105
bool formats_ok(const memory_desc_wrapper &src_d,
@@ -76,10 +113,40 @@ struct ref_sparse_matmul_t : public primitive_t {
76113
return src_d.matches_one_of_tag(format_tag::ab);
77114
return false;
78115
}
116+
117+
private:
118+
void init_scratchpad() {
119+
using namespace memory_tracking::names;
120+
const memory_desc_wrapper src_d(src_md());
121+
const memory_desc_wrapper wei_d(weights_md());
122+
123+
if (sparse_mem_encoding == sparse_encoding::coo) {
124+
auto scratchpad = scratchpad_registry().registrar();
125+
const auto ptr_size
126+
= src_d.dims()[(int)wei_d.is_sparse_desc()] + 1;
127+
scratchpad.template book<int32_t>(
128+
key_matmul_sparse_tmp_ptr, ptr_size);
129+
}
130+
}
131+
132+
sparse_encoding_t sparse_mem_encoding = sparse_encoding::undef;
79133
};
80134

81135
ref_sparse_matmul_t(const pd_t *apd) : primitive_t(apd) {}
82136

137+
// COO sparse encodings are converted to CSR format by
138+
// compressing the respective row indices into CSR pointers.
139+
void cvt_coo_indices_to_csr_pointers(const int32_t *indices,
140+
int32_t *pointers, const int nnz, const int nrows) const;
141+
142+
// Executes the matrix multiplication, C = A x B where one of the input
143+
// matrices is dense. Operation indices are determined depending on
144+
// whether the multiplier or multiplicand is dense
145+
void run_csr_kernel(const void *dmat, const void *values,
146+
const int32_t *indices, const int32_t *pointers, void *res,
147+
const dim_t M, const dim_t N, const dim_t K,
148+
const data_type_t mm_dt, bool is_src_sparse) const;
149+
83150
status_t execute(const exec_ctx_t &ctx) const override;
84151

85152
private:

0 commit comments

Comments
 (0)