Skip to content

Commit e66697d

Browse files
committed
mode_kk_to_onemkl and trans_mode_kk_to_mkl -> trans_mode_kk_to_onemkl
1 parent f6fcdd7 commit e66697d

3 files changed

+12
-22
lines changed

blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp

+1-12
Original file line numberDiff line numberDiff line change
@@ -777,17 +777,6 @@ KOKKOSBLAS2_CGEMV_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false)
777777
namespace KokkosBlas {
778778
namespace Impl {
779779

780-
inline oneapi::mkl::transpose mode_kk_to_onemkl(char mode_kk) {
781-
switch (toupper(mode_kk)) {
782-
case 'N': return oneapi::mkl::transpose::nontrans;
783-
case 'T': return oneapi::mkl::transpose::trans;
784-
case 'C': return oneapi::mkl::transpose::conjtrans;
785-
default:;
786-
}
787-
throw std::invalid_argument(
788-
"Invalid mode for oneMKL (should be one of N, T, C)");
789-
}
790-
791780
template <typename T, bool is_complex = false>
792781
struct kokkos_to_std_type_map {
793782
using type = T;
@@ -829,7 +818,7 @@ struct kokkos_to_std_type_map<T, true> {
829818
bool row_major = std::is_same<Kokkos::LayoutRight, LAYOUT>::value; \
830819
const std::int64_t M = A.extent(0); \
831820
const std::int64_t N = A.extent(1); \
832-
oneapi::mkl::transpose trans = mode_kk_to_onemkl(kk_trans[0]); \
821+
oneapi::mkl::transpose trans = trans_mode_kk_to_onemkl(kk_trans[0]); \
833822
const std::int64_t LDA = row_major ? A.stride(0) : A.stride(1); \
834823
std::string label = "KokkosBlas::gemv[TPL_ONEMKL," + \
835824
Kokkos::ArithTraits<SCALAR>::name() + "]"; \

blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -564,8 +564,8 @@ TPL_SCALAR_TYPE is the type MKL accents for SCALAR_TYPE
564564
const int64_t cst = is_lr ? C.stride(0) : C.stride(1); \
565565
const int64_t ldc = cst == 0 ? 1 : cst; \
566566
\
567-
oneapi::mkl::transpose transa = trans_mode_kk_to_mkl(transA); \
568-
oneapi::mkl::transpose transb = trans_mode_kk_to_mkl(transB); \
567+
oneapi::mkl::transpose transa = trans_mode_kk_to_onemkl(transA); \
568+
oneapi::mkl::transpose transb = trans_mode_kk_to_onemkl(transB); \
569569
oneapi::mkl::blas::compute_mode mode = oneapi::mkl::blas::compute_mode::standard; \
570570
\
571571
if constexpr (!is_lr) { \

blas/tpls/KokkosBlas_tpl_spec.hpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,15 @@ namespace Impl {
240240

241241

242242
/// \brief This function converts KK transpose mode to MKL transpose mode
243-
inline oneapi::mkl::transpose trans_mode_kk_to_mkl(const char kkMode[]) {
244-
oneapi::mkl::transpose trans;
245-
if ((kkMode[0] == 'N') || (kkMode[0] == 'n'))
246-
return oneapi::mkl::transpose::N;
247-
else if ((kkMode[0] == 'T') || (kkMode[0] == 't'))
248-
return oneapi::mkl::transpose::T;
249-
else
250-
return oneapi::mkl::transpose::C;
243+
inline oneapi::mkl::transpose trans_mode_kk_to_onemkl(char mode_kk) {
244+
switch (toupper(mode_kk)) {
245+
case 'N': return oneapi::mkl::transpose::nontrans;
246+
case 'T': return oneapi::mkl::transpose::trans;
247+
case 'C': return oneapi::mkl::transpose::conjtrans;
248+
default:;
249+
}
250+
throw std::invalid_argument(
251+
"Invalid mode for oneMKL (should be one of N, T, C)");
251252
}
252253

253254
} // namespace Impl

0 commit comments

Comments
 (0)