Skip to content

Commit 7685f05

Browse files
committed
examples: add example for coo sparse matmul
1 parent 9233c5a commit 7685f05

File tree

3 files changed

+111
-2
lines changed

3 files changed

+111
-2
lines changed

examples/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ file(GLOB_RECURSE headers *.hpp *.h)
5757

5858
if(NOT DNNL_EXPERIMENTAL_SPARSE)
5959
list(REMOVE_ITEM sources ${CMAKE_CURRENT_SOURCE_DIR}/cpu_matmul_csr.cpp)
60+
list(REMOVE_ITEM sources ${CMAKE_CURRENT_SOURCE_DIR}/cpu_matmul_coo.cpp)
6061
list(REMOVE_ITEM sources ${CMAKE_CURRENT_SOURCE_DIR}/cpu_matmul_weights_compression.cpp)
6162
endif()
6263

examples/cpu_matmul_coo.cpp

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*******************************************************************************
2+
* Copyright 2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @example cpu_matmul_coo.cpp
18+
/// > Annotated version: @ref cpu_matmul_coo_cpp
19+
///
20+
/// This C++ API example demonstrates how to create and execute a
21+
/// [MatMul](@ref dev_guide_matmul) primitive that uses a source tensor
22+
/// encoded with the COO sparse encoding.
23+
///
24+
/// @page cpu_matmul_coo_cpp MatMul Primitive Example
25+
///
26+
/// @include cpu_matmul_coo.cpp
27+
28+
#include <algorithm>
29+
#include <cmath>
30+
#include <iostream>
31+
#include <string>
32+
#include <vector>
33+
34+
#include "dnnl.hpp"
35+
#include "example_utils.hpp"
36+
37+
using namespace dnnl;
38+
39+
using tag = memory::format_tag;
40+
using dt = memory::data_type;
41+
42+
// Validates the MatMul destination tensor against hard-coded reference values.
//
// @param dst_mem Destination memory produced by the sparse MatMul primitive;
//        expected to hold M*N = 12 f32 values.
// @return true when every element matches the reference within tolerance.
//
// NOTE(review): the original compared with exact vector equality
// (`expected == dst`); float results are not guaranteed bit-exact across
// implementations/ISAs, so an absolute-tolerance comparison is used instead.
bool check_result(dnnl::memory dst_mem) {
    // Reference output for the fixed src/weights defined in sparse_matmul().
    // clang-format off
    const std::vector<float> expected_result = {8.750000, 11.250000, 2.500000,
                                                6.000000, 2.250000, 3.750000,
                                                19.000000, 15.500000, 5.250000,
                                                4.000000, 7.000000, 3.000000};
    // clang-format on
    std::vector<float> dst_data(expected_result.size());
    read_from_dnnl_memory(dst_data.data(), dst_mem);
    return std::equal(expected_result.begin(), expected_result.end(),
            dst_data.begin(), [](float expected, float actual) {
                // Tolerance is far above f32 rounding error for these small
                // accumulations and far below the spacing of the results.
                return std::fabs(expected - actual) <= 1e-5f;
            });
}
53+
54+
void sparse_matmul() {
55+
dnnl::engine engine(engine::kind::cpu, 0);
56+
57+
const memory::dim M = 4;
58+
const memory::dim N = 3;
59+
const memory::dim K = 6;
60+
61+
// A sparse matrix represented in the COO format.
62+
std::vector<float> src_coo_values = {2.5f, 1.5f, 1.5f, 2.5f, 2.0f};
63+
std::vector<int32_t> src_coo_row_indices = {0, 1, 2, 2, 3};
64+
std::vector<int32_t> src_coo_col_indices = {0, 2, 0, 5, 1};
65+
66+
// clang-format off
67+
std::vector<float> weights_data = {3.5f, 4.5f, 1.0f,
68+
2.0f, 3.5f, 1.5f,
69+
4.0f, 1.5f, 2.5f,
70+
3.5f, 5.5f, 4.5f,
71+
1.5f, 2.5f, 5.5f,
72+
5.5f, 3.5f, 1.5f};
73+
// clang-format on
74+
75+
const int nnz = static_cast<int>(src_coo_values.size());
76+
77+
// Create a memory descriptor for COO format by providing information
78+
// about number of non-zero entries and data types of metadata.
79+
const auto src_coo_md = memory::desc::coo({M, K}, dt::f32, nnz, dt::s32);
80+
const auto wei_md = memory::desc({K, N}, dt::f32, tag::oi);
81+
const auto dst_md = memory::desc({M, N}, dt::f32, tag::nc);
82+
83+
// This memory is created for the given values and metadata of COO format.
84+
memory src_coo_mem(src_coo_md, engine,
85+
{src_coo_values.data(), src_coo_row_indices.data(),
86+
src_coo_col_indices.data()});
87+
memory wei_mem(wei_md, engine, weights_data.data());
88+
memory dst_mem(dst_md, engine);
89+
90+
dnnl::stream stream(engine);
91+
92+
auto sparse_matmul_pd
93+
= matmul::primitive_desc(engine, src_coo_md, wei_md, dst_md);
94+
auto sparse_matmul_prim = matmul(sparse_matmul_pd);
95+
96+
std::unordered_map<int, memory> sparse_matmul_args;
97+
sparse_matmul_args.insert({DNNL_ARG_SRC, src_coo_mem});
98+
sparse_matmul_args.insert({DNNL_ARG_WEIGHTS, wei_mem});
99+
sparse_matmul_args.insert({DNNL_ARG_DST, dst_mem});
100+
101+
sparse_matmul_prim.execute(stream, sparse_matmul_args);
102+
stream.wait();
103+
if (!check_result(dst_mem)) throw std::runtime_error("Unexpected output.");
104+
}
105+
106+
// Example entry point: runs sparse_matmul() on the CPU engine, with
// handle_example_errors() translating thrown exceptions (including
// "unimplemented" for builds without sparse support) into an exit status.
int main(int argc, char **argv) {
    return handle_example_errors({engine::kind::cpu}, sparse_matmul);
}

examples/cpu_matmul_csr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2023 Intel Corporation
2+
* Copyright 2023-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
1818
/// > Annotated version: @ref cpu_matmul_csr_cpp
1919
///
2020
/// This C++ API example demonstrates how to create and execute a
21-
/// [MatMul](@ref dev_guide_matmul) primitive that uses a weights tensor
21+
/// [MatMul](@ref dev_guide_matmul) primitive that uses a source tensor
2222
/// encoded with the CSR sparse encoding.
2323
///
2424
/// @page cpu_matmul_csr_cpp MatMul Primitive Example

0 commit comments

Comments
 (0)