-
Notifications
You must be signed in to change notification settings - Fork 102
/
Copy pathKokkosSparse_spgemm_symbolic.hpp
224 lines (195 loc) · 10.3 KB
/
KokkosSparse_spgemm_symbolic.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef _KOKKOS_SPGEMM_SYMBOLIC_HPP
#define _KOKKOS_SPGEMM_SYMBOLIC_HPP
#include "KokkosKernels_helpers.hpp"
#include "KokkosSparse_spgemm_symbolic_spec.hpp"
#include "KokkosSparse_Utils.hpp"
namespace KokkosSparse {
namespace Experimental {
/// Symbolic phase of the sparse matrix-matrix product C = A * B.
///
/// Wraps the row maps and column-index views of A, B and C in unmanaged views
/// that share a unified layout and device (HandleExecSpace +
/// HandleTempMemorySpace of the kernel handle), validates the SpGEMM handle
/// attached to \p handle, and dispatches to
/// KokkosSparse::Impl::SPGEMM_SYMBOLIC. The row map of C is written through
/// \p row_mapC.
///
/// \param handle         Kernel handle; must already carry an SpGEMM handle.
/// \param m, n, k        Problem dimensions, forwarded unchanged to the
///                       implementation (presumably rows of A, columns of A /
///                       rows of B, and columns of B -- confirm in the Impl).
/// \param row_mapA       Row map of A.
/// \param entriesA       Column indices of A.
/// \param transposeA     Must be false; transposes are not implemented.
/// \param row_mapB       Row map of B.
/// \param entriesB       Column indices of B.
/// \param transposeB     Must be false; transposes are not implemented.
/// \param row_mapC       Output row map of C; its value type must be non-const.
/// \param computeRowptrs Forwarded unchanged to the implementation.
///
/// \throws std::runtime_error if \p transposeA or \p transposeB is true, or,
///         in builds without NDEBUG and only when a TPL specialization is
///         available for these view types, if A or B has unsorted entries
///         within a row.
/// \throws std::invalid_argument if \p handle has no SpGEMM handle, or if the
///         SpGEMM handle was already used with a different sparsity pattern
///         for A and B.
template <typename KernelHandle, typename alno_row_view_t_,
          typename alno_nnz_view_t_, typename blno_row_view_t_,
          typename blno_nnz_view_t_, typename clno_row_view_t_>
void spgemm_symbolic(KernelHandle *handle,
                     typename KernelHandle::const_nnz_lno_t m,
                     typename KernelHandle::const_nnz_lno_t n,
                     typename KernelHandle::const_nnz_lno_t k,
                     alno_row_view_t_ row_mapA, alno_nnz_view_t_ entriesA,
                     bool transposeA, blno_row_view_t_ row_mapB,
                     blno_nnz_view_t_ entriesB, bool transposeB,
                     clno_row_view_t_ row_mapC, bool computeRowptrs = false) {
  // Compile-time checks: the output row map must be writable, and all index
  // types must agree with the kernel handle's size_type / nnz_lno_t.
  static_assert(
      std::is_same<typename clno_row_view_t_::value_type,
                   typename clno_row_view_t_::non_const_value_type>::value,
      "KokkosSparse::spgemm_symbolic: Output matrix rowmap must be non-const.");
  static_assert(
      std::is_same<typename KernelHandle::const_size_type,
                   typename alno_row_view_t_::const_value_type>::value,
      "KokkosSparse::spgemm_symbolic: Size type of left handside matrix should "
      "be same as kernelHandle sizetype.");
  static_assert(
      std::is_same<typename KernelHandle::const_size_type,
                   typename blno_row_view_t_::const_value_type>::value,
      "KokkosSparse::spgemm_symbolic: Size type of right handside matrix "
      "should be same as kernelHandle sizetype.");
  static_assert(
      std::is_same<typename KernelHandle::const_size_type,
                   typename clno_row_view_t_::const_value_type>::value,
      "KokkosSparse::spgemm_symbolic: Size type of output matrix should be "
      "same as kernelHandle sizetype.");
  static_assert(
      std::is_same<typename KernelHandle::const_nnz_lno_t,
                   typename alno_nnz_view_t_::const_value_type>::value,
      "KokkosSparse::spgemm_symbolic: lno type of left handside matrix should "
      "be same as kernelHandle lno_t.");
  static_assert(
      std::is_same<typename KernelHandle::const_nnz_lno_t,
                   typename blno_nnz_view_t_::const_value_type>::value,
      "KokkosSparse::spgemm_symbolic: lno type of right handside matrix should "
      "be same as kernelHandle lno_t.");

  if (transposeA || transposeB) {
    throw std::runtime_error(
        "SpGEMM is not implemented for Transposes yet. "
        "If you need this case please let kokkos-kernels developers know.\n");
  }

  typedef typename KernelHandle::const_size_type c_size_t;
  typedef typename KernelHandle::const_nnz_lno_t c_lno_t;
  typedef typename KernelHandle::const_nnz_scalar_t c_scalar_t;
  typedef typename KernelHandle::HandleExecSpace c_exec_t;
  typedef typename KernelHandle::HandleTempMemorySpace c_temp_t;
  typedef typename KernelHandle::HandlePersistentMemorySpace c_persist_t;
  typedef typename Kokkos::Device<c_exec_t, c_temp_t> UniformDevice_t;

  // Copy of the caller's handle, retyped with const ordinal/scalar types so
  // that a single set of implementation instantiations serves all callers.
  typedef typename KokkosKernels::Experimental::KokkosKernelsHandle<
      c_size_t, c_lno_t, c_scalar_t, c_exec_t, c_temp_t, c_persist_t>
      const_handle_type;
  const_handle_type tmp_handle(*handle);

  // Unmanaged views with a unified layout on UniformDevice_t. They alias the
  // caller's data (no copy), which limits the number of distinct view types
  // the implementation is instantiated for.
  typedef Kokkos::View<typename alno_row_view_t_::const_value_type *,
                       typename KokkosKernels::Impl::GetUnifiedLayout<
                           alno_row_view_t_>::array_layout,
                       UniformDevice_t,  // typename
                                         // alno_row_view_t_::device_type,
                       Kokkos::MemoryTraits<Kokkos::Unmanaged> >
      Internal_alno_row_view_t_;

  typedef Kokkos::View<typename alno_nnz_view_t_::const_value_type *,
                       typename KokkosKernels::Impl::GetUnifiedLayout<
                           alno_nnz_view_t_>::array_layout,
                       UniformDevice_t,  // typename
                                         // alno_nnz_view_t_::device_type,
                       Kokkos::MemoryTraits<Kokkos::Unmanaged> >
      Internal_alno_nnz_view_t_;

  typedef Kokkos::View<typename blno_row_view_t_::const_value_type *,
                       typename KokkosKernels::Impl::GetUnifiedLayout<
                           blno_row_view_t_>::array_layout,
                       UniformDevice_t,  // typename
                                         // blno_row_view_t_::device_type,
                       Kokkos::MemoryTraits<Kokkos::Unmanaged> >
      Internal_blno_row_view_t_;

  typedef Kokkos::View<typename blno_nnz_view_t_::const_value_type *,
                       typename KokkosKernels::Impl::GetUnifiedLayout<
                           blno_nnz_view_t_>::array_layout,
                       UniformDevice_t,  // typename
                                         // blno_nnz_view_t_::device_type,
                       Kokkos::MemoryTraits<Kokkos::Unmanaged> >
      Internal_blno_nnz_view_t_;

  // clno_row_view_t_ is statically asserted above to be non-const.
  typedef Kokkos::View<typename clno_row_view_t_::non_const_value_type *,
                       typename KokkosKernels::Impl::GetUnifiedLayout<
                           clno_row_view_t_>::array_layout,
                       UniformDevice_t,  // typename
                                         // clno_row_view_t_::device_type,
                       Kokkos::MemoryTraits<Kokkos::Unmanaged> >
      Internal_clno_row_view_t_;

  Internal_alno_row_view_t_ const_a_r(row_mapA.data(), row_mapA.extent(0));
  Internal_alno_nnz_view_t_ const_a_l(entriesA.data(), entriesA.extent(0));
  Internal_blno_row_view_t_ const_b_r(row_mapB.data(), row_mapB.extent(0));
  Internal_blno_nnz_view_t_ const_b_l(entriesB.data(), entriesB.extent(0));
  Internal_clno_row_view_t_ c_r(row_mapC.data(), row_mapC.extent(0));

  // Verify that graphs A and B are sorted.
  // This test is designed to be as efficient as possible, but still skip
  // it in a release build.
  //
  // Temporary fix for Trilinos issue #11655: Only perform this check if a TPL
  // is to be called. The KokkosKernels (non-TPL) implementation does not
  // actually require sorted indices yet. And Tpetra uses size_type = size_t, so
  // it will (currently) not be calling a TPL path.
#ifndef NDEBUG
  if constexpr (KokkosSparse::Impl::spgemm_symbolic_tpl_spec_avail<
                    const_handle_type, Internal_alno_row_view_t_,
                    Internal_alno_nnz_view_t_, Internal_blno_row_view_t_,
                    Internal_blno_nnz_view_t_,
                    Internal_clno_row_view_t_>::value) {
    if (!KokkosSparse::Impl::isCrsGraphSorted(const_a_r, const_a_l))
      throw std::runtime_error(
          "KokkosSparse::spgemm_symbolic: entries of A are not sorted within "
          "rows. May use KokkosSparse::sort_crs_matrix to sort it.");
    if (!KokkosSparse::Impl::isCrsGraphSorted(const_b_r, const_b_l))
      throw std::runtime_error(
          "KokkosSparse::spgemm_symbolic: entries of B are not sorted within "
          "rows. May use KokkosSparse::sort_crs_matrix to sort it.");
  }
#endif

  auto spgemmHandle = tmp_handle.get_spgemm_handle();
  if (!spgemmHandle) {
    throw std::invalid_argument(
        "KokkosSparse::spgemm_symbolic: the given KernelHandle does not have "
        "an SpGEMM handle associated with it.");
  }

  if (!spgemmHandle->checkMatrixIdentitiesSymbolic(const_a_r, const_a_l,
                                                   const_b_r, const_b_l)) {
    throw std::invalid_argument(
        "KokkosSparse::spgemm_symbolic: once used, an spgemm handle cannot be "
        "reused for a product with a different sparsity pattern.\n"
        "The rowptrs and entries of A and B must be identical to those "
        "passed to the first spgemm_symbolic call.");
  }

  // Pairs pushRegion/popRegion so that the profiling region is closed even if
  // the implementation below throws.
  struct ScopedProfilingRegion {
    explicit ScopedProfilingRegion(const char *name) {
      Kokkos::Profiling::pushRegion(name);
    }
    ~ScopedProfilingRegion() { Kokkos::Profiling::popRegion(); }
    ScopedProfilingRegion(const ScopedProfilingRegion &)            = delete;
    ScopedProfilingRegion &operator=(const ScopedProfilingRegion &) = delete;
  };

  const auto algo = spgemmHandle->get_algorithm_type();

  if (algo == SPGEMM_DEBUG || algo == SPGEMM_SERIAL) {
    // Never call a TPL if serial/debug is requested (this is needed for
    // testing): the trailing `false` template argument selects the non-TPL
    // specialization.
    ScopedProfilingRegion region(
        "KokkosSparse: spgemm_symbolic [serial/debug]");
    KokkosSparse::Impl::SPGEMM_SYMBOLIC<
        const_handle_type,  // KernelHandle,
        Internal_alno_row_view_t_, Internal_alno_nnz_view_t_,
        Internal_blno_row_view_t_, Internal_blno_nnz_view_t_,
        Internal_clno_row_view_t_,
        false>::spgemm_symbolic(&tmp_handle,  // handle,
                                m, n, k, const_a_r, const_a_l, transposeA,
                                const_b_r, const_b_l, transposeB, c_r,
                                computeRowptrs);
  } else {
    ScopedProfilingRegion region("KokkosSparse: spgemm_symbolic []");
    KokkosSparse::Impl::SPGEMM_SYMBOLIC<
        const_handle_type,  // KernelHandle,
        Internal_alno_row_view_t_, Internal_alno_nnz_view_t_,
        Internal_blno_row_view_t_, Internal_blno_nnz_view_t_,
        Internal_clno_row_view_t_>::spgemm_symbolic(&tmp_handle,  // handle,
                                                    m, n, k, const_a_r,
                                                    const_a_l, transposeA,
                                                    const_b_r, const_b_l,
                                                    transposeB, c_r,
                                                    computeRowptrs);
  }
}
} // namespace Experimental
} // namespace KokkosSparse
#endif