Skip to content

Commit 2c4dd7e

Browse files
vqd8aVinh Quang Dang (-EXP)
and
Vinh Quang Dang (-EXP)
authored
Add MAGMA TPL support for GESV on HIP backend (#2326)
* Small changes for MAGMA GESV on HIP * Apply clang-format * Relax eps to 1e-8 for multi-rhs tests --------- Co-authored-by: Vinh Quang Dang (-EXP) <vqdang@kokkos-dev-2.sandia.gov>
1 parent f26fbca commit 2c4dd7e

4 files changed

+119
-57
lines changed

lapack/src/KokkosLapack_gesv.hpp

+16
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,15 @@ void gesv(const ExecutionSpace& space, const AMatrix& A, const BXMV& B, const IP
6363
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename AMatrix::memory_space>::accessible);
6464
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename BXMV::memory_space>::accessible);
6565
#if defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA)
66+
#if defined(KOKKOS_ENABLE_CUDA)
6667
if constexpr (!std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
6768
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename IPIVV::memory_space>::accessible);
6869
}
70+
#elif defined(KOKKOS_ENABLE_HIP)
71+
if constexpr (!std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
72+
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename IPIVV::memory_space>::accessible);
73+
}
74+
#endif
6975
#else
7076
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename IPIVV::memory_space>::accessible);
7177
#endif
@@ -96,13 +102,23 @@ void gesv(const ExecutionSpace& space, const AMatrix& A, const BXMV& B, const IP
96102
// Check for no pivoting case. Only MAGMA supports no pivoting interface
97103
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA // have MAGMA TPL
98104
#ifdef KOKKOSKERNELS_ENABLE_TPL_LAPACK // and have LAPACK TPL
105+
#if defined(KOKKOS_ENABLE_CUDA)
99106
if ((!std::is_same<typename AMatrix::device_type::memory_space, Kokkos::CudaSpace>::value) && (IPIV0 == 0) &&
100107
(IPIV.data() == nullptr)) {
101108
std::ostringstream os;
102109
os << "KokkosLapack::gesv: IPIV: " << IPIV0 << ". "
103110
<< "LAPACK TPL does not support no pivoting.";
104111
KokkosKernels::Impl::throw_runtime_exception(os.str());
105112
}
113+
#elif defined(KOKKOS_ENABLE_HIP)
114+
if ((!std::is_same<typename AMatrix::device_type::memory_space, Kokkos::HIPSpace>::value) && (IPIV0 == 0) &&
115+
(IPIV.data() == nullptr)) {
116+
std::ostringstream os;
117+
os << "KokkosLapack::gesv: IPIV: " << IPIV0 << ". "
118+
<< "LAPACK TPL does not support no pivoting.";
119+
KokkosKernels::Impl::throw_runtime_exception(os.str());
120+
}
121+
#endif
106122
#endif
107123
#else // not have MAGMA TPL
108124
#ifdef KOKKOSKERNELS_ENABLE_TPL_LAPACK // but have LAPACK TPL

lapack/tpls/KokkosLapack_gesv_tpl_spec_avail.hpp

+21-16
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,28 @@ KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_LAPACK(Kokkos::complex<float>, Kokkos::LayoutLe
5252

5353
namespace KokkosLapack {
5454
namespace Impl {
55-
#define KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(SCALAR, LAYOUT, MEMSPACE) \
56-
template <> \
57-
struct gesv_tpl_spec_avail< \
58-
Kokkos::Cuda, \
59-
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEMSPACE>, \
60-
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
61-
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEMSPACE>, \
62-
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
63-
Kokkos::View<magma_int_t*, LAYOUT, Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
64-
Kokkos::MemoryTraits<Kokkos::Unmanaged> > > { \
65-
enum : bool { value = true }; \
55+
#define KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
56+
template <> \
57+
struct gesv_tpl_spec_avail< \
58+
EXECSPACE, \
59+
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
60+
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
61+
Kokkos::View<magma_int_t*, LAYOUT, Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
62+
Kokkos::MemoryTraits<Kokkos::Unmanaged> > > { \
63+
enum : bool { value = true }; \
6664
};
67-
68-
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(double, Kokkos::LayoutLeft, Kokkos::CudaSpace)
69-
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(float, Kokkos::LayoutLeft, Kokkos::CudaSpace)
70-
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::CudaSpace)
71-
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::CudaSpace)
65+
#if defined(KOKKOS_ENABLE_CUDA)
66+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(double, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
67+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(float, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
68+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
69+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
70+
#endif
71+
#if defined(KOKKOS_ENABLE_HIP)
72+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(double, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
73+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(float, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
74+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
75+
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
76+
#endif
7277
} // namespace Impl
7378
} // namespace KokkosLapack
7479
#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA

lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp

+40-34
Original file line numberDiff line numberDiff line change
@@ -197,42 +197,48 @@ void magmaGesvWrapper(const ExecSpace& space, const AViewType& A, const BViewTyp
197197
Kokkos::Profiling::popRegion();
198198
}
199199

200-
#define KOKKOSLAPACK_GESV_MAGMA(SCALAR, LAYOUT, MEM_SPACE) \
201-
template <> \
202-
struct GESV<Kokkos::Cuda, \
203-
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEM_SPACE>, \
204-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
205-
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEM_SPACE>, \
206-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
207-
Kokkos::View<magma_int_t*, LAYOUT, Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
208-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
209-
true, \
210-
gesv_eti_spec_avail<Kokkos::Cuda, \
211-
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEM_SPACE>, \
212-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
213-
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEM_SPACE>, \
214-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
215-
Kokkos::View<magma_int_t*, LAYOUT, \
216-
Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
217-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>>::value> { \
218-
using AViewType = Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEM_SPACE>, \
219-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
220-
using BViewType = Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::Cuda, MEM_SPACE>, \
221-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
222-
using PViewType = \
223-
Kokkos::View<magma_int_t*, LAYOUT, Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
224-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
225-
\
226-
static void gesv(const Kokkos::Cuda& space, const AViewType& A, const BViewType& B, const PViewType& IPIV) { \
227-
magmaGesvWrapper(space, A, B, IPIV); \
228-
} \
200+
#define KOKKOSLAPACK_GESV_MAGMA(SCALAR, LAYOUT, EXEC_SPACE, MEM_SPACE) \
201+
template <> \
202+
struct GESV< \
203+
EXEC_SPACE, \
204+
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
205+
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
206+
Kokkos::View<magma_int_t*, LAYOUT, Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
207+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
208+
true, \
209+
gesv_eti_spec_avail< \
210+
EXEC_SPACE, \
211+
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
212+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
213+
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
214+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
215+
Kokkos::View<magma_int_t*, LAYOUT, Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
216+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>>::value> { \
217+
using AViewType = Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
218+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
219+
using BViewType = Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
220+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
221+
using PViewType = \
222+
Kokkos::View<magma_int_t*, LAYOUT, Kokkos::Device<Kokkos::DefaultHostExecutionSpace, Kokkos::HostSpace>, \
223+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
224+
\
225+
static void gesv(const EXEC_SPACE& space, const AViewType& A, const BViewType& B, const PViewType& IPIV) { \
226+
magmaGesvWrapper(space, A, B, IPIV); \
227+
} \
229228
};
230229

231-
KOKKOSLAPACK_GESV_MAGMA(float, Kokkos::LayoutLeft, Kokkos::CudaSpace)
232-
KOKKOSLAPACK_GESV_MAGMA(double, Kokkos::LayoutLeft, Kokkos::CudaSpace)
233-
KOKKOSLAPACK_GESV_MAGMA(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::CudaSpace)
234-
KOKKOSLAPACK_GESV_MAGMA(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::CudaSpace)
235-
230+
#if defined(KOKKOS_ENABLE_CUDA)
231+
KOKKOSLAPACK_GESV_MAGMA(float, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
232+
KOKKOSLAPACK_GESV_MAGMA(double, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
233+
KOKKOSLAPACK_GESV_MAGMA(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
234+
KOKKOSLAPACK_GESV_MAGMA(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::Cuda, Kokkos::CudaSpace)
235+
#endif
236+
#if defined(KOKKOS_ENABLE_HIP)
237+
KOKKOSLAPACK_GESV_MAGMA(float, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
238+
KOKKOSLAPACK_GESV_MAGMA(double, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
239+
KOKKOSLAPACK_GESV_MAGMA(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
240+
KOKKOSLAPACK_GESV_MAGMA(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace)
241+
#endif
236242
} // namespace Impl
237243
} // namespace KokkosLapack
238244
#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA

0 commit comments

Comments
 (0)