Skip to content

Commit e2b240a

Browse files
authored
Merge pull request #2050 from seanofthemillers/rocm6_deprecation_fixes
ROCm 6 deprecation fixes for rocsparse
2 parents c15b51e + d1bf499 commit e2b240a

3 files changed

+55
-5
lines changed

sparse/src/KokkosSparse_Utils_rocsparse.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121
#include <sstream>
2222

2323
#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
24+
#if __has_include(<rocm-core/rocm_version.h>)
25+
#include <rocm-core/rocm_version.h>
26+
#else
2427
#include <rocm_version.h>
25-
#include "rocsparse/rocsparse.h"
28+
#endif
29+
#include <rocsparse/rocsparse.h>
2630

2731
namespace KokkosSparse {
2832
namespace Impl {

sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp

+39-1
Original file line numberDiff line numberDiff line change
@@ -869,8 +869,46 @@ void spmv_block_impl_rocsparse(
869869
rocsparse_mat_info info;
870870
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_create_mat_info(&info));
871871

872+
// *_ex* functions deprecated in introduced in 6+
873+
#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 60000
874+
if constexpr (std::is_same_v<value_type, float>) {
875+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv_analysis(
876+
handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr,
877+
bsr_col_ind, block_dim, info));
878+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv(
879+
handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr,
880+
bsr_col_ind, block_dim, info, x_, beta_, y_));
881+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info));
882+
} else if constexpr (std::is_same_v<value_type, double>) {
883+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_dbsrmv_analysis(
884+
handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr,
885+
bsr_col_ind, block_dim, info));
886+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_dbsrmv(
887+
handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr,
888+
bsr_col_ind, block_dim, info, x_, beta_, y_));
889+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info));
890+
} else if constexpr (std::is_same_v<value_type, Kokkos::complex<float>>) {
891+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_cbsrmv_analysis(
892+
handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr,
893+
bsr_col_ind, block_dim, info));
894+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_cbsrmv(
895+
handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr,
896+
bsr_col_ind, block_dim, info, x_, beta_, y_));
897+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info));
898+
} else if constexpr (std::is_same_v<value_type, Kokkos::complex<double>>) {
899+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_zbsrmv_analysis(
900+
handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr,
901+
bsr_col_ind, block_dim, info));
902+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_zbsrmv(
903+
handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr,
904+
bsr_col_ind, block_dim, info, x_, beta_, y_));
905+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info));
906+
} else {
907+
static_assert(KokkosKernels::Impl::always_false_v<value_type>,
908+
"unsupported value type for rocsparse_*bsrmv");
909+
}
872910
// *_ex* functions introduced in 5.4.0
873-
#if KOKKOSSPARSE_IMPL_ROCM_VERSION < 50400
911+
#elif KOKKOSSPARSE_IMPL_ROCM_VERSION < 50400
874912
if constexpr (std::is_same_v<value_type, float>) {
875913
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv(
876914
handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr,

sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp

+11-3
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,6 @@ KOKKOSSPARSE_SPMV_CUSPARSE(Kokkos::complex<float>, int64_t, size_t,
359359

360360
// rocSPARSE
361361
#if defined(KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE)
362-
#include <rocsparse/rocsparse.h>
363-
#include <rocm_version.h>
364362
#include "KokkosSparse_Utils_rocsparse.hpp"
365363

366364
namespace KokkosSparse {
@@ -443,7 +441,17 @@ void spmv_rocsparse(const Kokkos::HIP& exec,
443441
alg = rocsparse_spmv_alg_csr_stream;
444442
}
445443

446-
#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400
444+
#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 60000
445+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(
446+
rocsparse_spmv(handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta,
447+
vecY, compute_type, alg, rocsparse_spmv_stage_buffer_size,
448+
&buffer_size, tmp_buffer));
449+
KOKKOS_IMPL_HIP_SAFE_CALL(hipMalloc(&tmp_buffer, buffer_size));
450+
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(
451+
rocsparse_spmv(handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta,
452+
vecY, compute_type, alg, rocsparse_spmv_stage_compute,
453+
&buffer_size, tmp_buffer));
454+
#elif KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400
447455
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_spmv_ex(
448456
handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, vecY,
449457
compute_type, alg, rocsparse_spmv_stage_auto, &buffer_size, tmp_buffer));

0 commit comments

Comments
 (0)