/*******************************************************************************
- * Copyright 2023 Intel Corporation
+ * Copyright 2023-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"

+#include "cpu/ref_io_helper.hpp"
+
#include "cpu/matmul/ref_sparse_matmul.hpp"

namespace dnnl {
@@ -27,7 +29,7 @@ namespace matmul {

status_t ref_sparse_matmul_t::execute(const exec_ctx_t &ctx) const {
    status_t status = status::success;
-    auto dst = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DST, status);
+    auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
    CHECK(status);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
@@ -38,48 +40,161 @@ status_t ref_sparse_matmul_t::execute(const exec_ctx_t &ctx) const {
    const dim_t N = dst_d.dims()[1];
    const dim_t K = src_d.dims()[1];

-    parallel_nd(M, N, [&](dim_t i, dim_t j) { dst[i * N + j] = 0.0f; });
+    const data_type_t mm_dt = src_d.data_type();
+    auto scratchpad = ctx.get_scratchpad_grantor();
+
+    parallel_nd(M, N, [&](dim_t i, dim_t j) {
+        const dim_t dst_idx = i * N + j;
+        io::store_float_value(dst_d.data_type(), 0.0f, dst, dst_idx);
+    });

    if (weights_d.is_sparse_desc()) {
-        const auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
-        const auto wei_values = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS, 0);
-        const auto wei_indices
-                = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 1);
-        const auto wei_pointers
-                = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 2);

+        const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
+        const auto wei_values = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS, 0);
+        auto wei_buffer_1 = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 1);
+        auto wei_buffer_2 = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 2);
+
+        // Both COO- and CSR-encoded data are operated on by the CSR kernel
+        // for matrix multiplication.
+        // For COO encoding, data preparation includes using a temporary
+        // buffer to convert the data to the CSR format.
+        // Matrix multiplication is then carried out on the CSR-encoded data.
+        const int32_t *wei_indices;
+        const int32_t *wei_pointers;
+
+        if (weights_d.encoding() == sparse_encoding::csr) {
+            // For CSR encodings, pointer and indices assignment is
+            // straightforward:
+            // index 1 - index buffer, index 2 - pointer buffer.
+            wei_indices = wei_buffer_1;
+            wei_pointers = wei_buffer_2;
+        } else if (weights_d.encoding() == sparse_encoding::coo) {
+            // For COO encodings, the two index buffers hold the row and column
+            // indices respectively. For CSR conversion, the row indices are
+            // compressed to generate the CSR pointers.
+            wei_indices = wei_buffer_2;
+
+            int32_t *wei_row_pointers = scratchpad.template get<int32_t>(
+                    memory_tracking::names::key_matmul_sparse_tmp_ptr);
+
+            parallel_nd(K + 1, [&](dim_t k) {
+                io::store_float_value(
+                        weights_d.metadata_type(0), 0, wei_row_pointers, k);
+            });
+
+            cvt_coo_indices_to_csr_pointers(
+                    wei_buffer_1, wei_row_pointers, weights_d.nnz(), K);
+
+            wei_pointers = wei_row_pointers;
+        }
+
+        run_csr_kernel(src, wei_values, wei_indices, wei_pointers, dst, M, N, K,
+                mm_dt, src_d.is_sparse_desc());
+
+    } else if (src_d.is_sparse_desc()) {
+        const auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
+        const auto src_values = CTX_IN_MEM(const void *, DNNL_ARG_SRC, 0);
+        auto src_buffer_1 = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 1);
+        auto src_buffer_2 = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 2);
+
+        // Both COO- and CSR-encoded data are operated on by the CSR kernel
+        // for matrix multiplication.
+        // For COO encoding, data preparation includes using a temporary
+        // buffer to convert the data to the CSR format.
+        // Matrix multiplication is then carried out on the CSR-encoded data.
+        const int32_t *src_indices;
+        const int32_t *src_pointers;
+
+        if (src_d.encoding() == sparse_encoding::csr) {
+            // For CSR encodings, pointer and indices assignment is
+            // straightforward:
+            // index 1 - index buffer, index 2 - pointer buffer.
+            src_indices = src_buffer_1;
+            src_pointers = src_buffer_2;
+        } else if (src_d.encoding() == sparse_encoding::coo) {
+            // For COO encodings, the two index buffers hold the row and column
+            // indices respectively. For CSR conversion, the row indices are
+            // compressed to generate the CSR pointers.
+            src_indices = src_buffer_2;
+
+            int32_t *src_row_pointers = scratchpad.template get<int32_t>(
+                    memory_tracking::names::key_matmul_sparse_tmp_ptr);
+
+            parallel_nd(M + 1, [&](dim_t m) {
+                io::store_float_value(
+                        src_d.metadata_type(0), 0, src_row_pointers, m);
+            });
+
+            cvt_coo_indices_to_csr_pointers(
+                    src_buffer_1, src_row_pointers, src_d.nnz(), M);
+            src_pointers = src_row_pointers;
+        }
+
+        run_csr_kernel(weights, src_values, src_indices, src_pointers, dst, M,
+                N, K, mm_dt, src_d.is_sparse_desc());
+    }
+    return status::success;
+}
+
+void ref_sparse_matmul_t::cvt_coo_indices_to_csr_pointers(
+        const int32_t *indices, int32_t *pointers, const int nnz,
+        const int nrows) const {
+    parallel_nd(
+            nnz, [&](dim_t i) { fetch_and_add(&pointers[indices[i] + 1], 1); });
+    for (int i = 0; i < nrows; ++i) {
+        pointers[i + 1] += pointers[i];
+    }
+}
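A note on the conversion above: `cvt_coo_indices_to_csr_pointers` builds a per-row histogram of the COO row indices (via `fetch_and_add`) and then turns the counts into CSR row pointers with a prefix sum. Below is a minimal, serial sketch of the same idea; it is not part of the patch, and the `coo_rows_to_csr_pointers` name and toy data are made up for illustration.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Compress COO row indices into CSR row pointers for a matrix with nrows rows.
// Serial equivalent of the histogram + prefix-sum scheme used in the patch.
std::vector<int32_t> coo_rows_to_csr_pointers(
        const std::vector<int32_t> &row_indices, int32_t nrows) {
    std::vector<int32_t> pointers(nrows + 1, 0);
    // Histogram: count the non-zeros of each row, shifted by one slot.
    for (int32_t r : row_indices)
        pointers[r + 1]++;
    // Prefix sum: pointers[i] becomes the offset of row i's first non-zero.
    for (int32_t i = 0; i < nrows; ++i)
        pointers[i + 1] += pointers[i];
    return pointers;
}

int main() {
    // COO row indices of a 4-row matrix with 4 non-zeros.
    const std::vector<int32_t> coo_rows = {0, 0, 1, 3};
    const auto ptr = coo_rows_to_csr_pointers(coo_rows, 4);
    for (int32_t p : ptr)
        std::printf("%d ", p); // prints: 0 2 3 3 4
    std::printf("\n");
    return 0;
}
```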
+
+void ref_sparse_matmul_t::run_csr_kernel(const void *dmat, const void *values,
+        const int32_t *indices, const int32_t *pointers, void *res,
+        const dim_t M, const dim_t N, const dim_t K, const data_type_t mm_dt,
+        bool is_src_sparse) const {
+
+    if (is_src_sparse) {
+        // With a sparse source tensor, the matrix multiplication is carried out
+        // for a sparse multiplier with parallelization over the sparse rows
+        // of the multiplier matrix.
        parallel_nd(M, [&](dim_t m) {
-            for (dim_t k = 0; k < K; k++) {
-                const dim_t row_start = wei_pointers[k];
-                const dim_t row_end = wei_pointers[k + 1];
-                for (dim_t n = row_start; n < row_end; n++) {
-                    const dim_t src_idx = m * K + k;
-                    const dim_t dst_idx = m * N + wei_indices[n];
-                    dst[dst_idx] = dst[dst_idx] + src[src_idx] * wei_values[n];
+            const dim_t row_start = pointers[m];
+            const dim_t row_end = pointers[m + 1];
+
+            for (dim_t n = 0; n < N; n++) {
+                const dim_t c_idx = m * N + n;
+                float c_val = io::load_float_value(mm_dt, res, c_idx);
+
+                for (dim_t k = row_start; k < row_end; k++) {
+                    const dim_t b_idx = indices[k] * N + n;
+                    const float a_val = io::load_float_value(mm_dt, values, k);
+                    const float b_val
+                            = io::load_float_value(mm_dt, dmat, b_idx);
+                    c_val += a_val * b_val;
                }
+                io::store_float_value(mm_dt, c_val, res, c_idx);
            }
        });
-    } else if (src_d.is_sparse_desc()) {
-        const auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS);
-        const auto src_values = CTX_IN_MEM(const float *, DNNL_ARG_SRC, 0);
-        const auto src_indices = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 1);
-        const auto src_pointers = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 2);
-
+    } else {
+        // With a sparse weights tensor, the matrix multiplication is carried
+        // out for a sparse multiplicand with parallelization over the dense
+        // rows of the multiplier matrix.
        parallel_nd(M, [&](dim_t m) {
-            const dim_t row_start = src_pointers[m];
-            const dim_t row_end = src_pointers[m + 1];
-            for (dim_t k = row_start; k < row_end; k++) {
-                for (dim_t n = 0; n < N; n++) {
-                    const dim_t dst_idx = m * N + n;
-                    const dim_t wei_idx = src_indices[k] * N + n;
-                    dst[dst_idx]
-                            = dst[dst_idx] + src_values[k] * weights[wei_idx];
+            for (dim_t k = 0; k < K; k++) {
+                const dim_t row_start = pointers[k];
+                const dim_t row_end = pointers[k + 1];
+                for (dim_t n = row_start; n < row_end; n++) {
+                    const dim_t a_idx = m * K + k;
+                    const dim_t c_idx = m * N + indices[n];
+                    const float a_val
+                            = io::load_float_value(mm_dt, dmat, a_idx);
+                    const float b_val = io::load_float_value(mm_dt, values, n);
+                    float c_val = io::load_float_value(mm_dt, res, c_idx);
+                    c_val += a_val * b_val;
+                    io::store_float_value(mm_dt, c_val, res, c_idx);
                }
            }
        });
    }
-
-    return status::success;
}

} // namespace matmul
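For reference, the sparse-source path of `run_csr_kernel` computes C[m][:] += values[k] * B[indices[k]][:] for each non-zero k in row m. Here is a small standalone sketch of that CSR-times-dense product with plain float buffers; it is not part of the patch, the `csr_times_dense` name and data are illustrative only, and the loop order is simplified relative to the patch, which keeps the accumulator in the destination data type.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// C (MxN) += A (MxK, CSR: values/indices/pointers) * B (KxN).
void csr_times_dense(const std::vector<float> &values,
        const std::vector<int32_t> &indices, const std::vector<int32_t> &pointers,
        const std::vector<float> &B, std::vector<float> &C, int M, int N) {
    for (int m = 0; m < M; ++m) {
        // Non-zeros of row m live in [pointers[m], pointers[m + 1]).
        for (int32_t k = pointers[m]; k < pointers[m + 1]; ++k) {
            const float a_val = values[k];
            const int32_t b_row = indices[k]; // column of A == row of B
            for (int n = 0; n < N; ++n)
                C[m * N + n] += a_val * B[b_row * N + n];
        }
    }
}

int main() {
    // A (2x3, CSR): [[1 0 2], [0 3 0]], B (3x2): [[1 2], [3 4], [5 6]].
    const std::vector<float> values = {1.f, 2.f, 3.f};
    const std::vector<int32_t> indices = {0, 2, 1};
    const std::vector<int32_t> pointers = {0, 2, 3};
    const std::vector<float> B = {1, 2, 3, 4, 5, 6};
    std::vector<float> C(2 * 2, 0.f);
    csr_times_dense(values, indices, pointers, B, C, 2, 2);
    std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 11 14 / 9 12
    return 0;
}
```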