Skip to content

Commit d989d26

Browse files
author
Mikołaj Zuzek
committed
GEMV: refactor A matrix transpose
1 parent 4e145e3 commit d989d26

File tree

2 files changed

+60
-239
lines changed

2 files changed

+60
-239
lines changed

blas/impl/KokkosBlas2_serial_gemv_impl.hpp

+16-78
Original file line numberDiff line numberDiff line change
@@ -57,88 +57,26 @@ namespace KokkosBlas {
5757
/// Serial Impl
5858
/// ===========
5959

60-
///
61-
/// NT
62-
///
63-
64-
template <>
65-
template <typename ScalarType, typename AViewType, typename xViewType,
66-
typename yViewType>
67-
KOKKOS_INLINE_FUNCTION int
68-
SerialGemv<Trans::NoTranspose, Algo::Gemv::Unblocked>::invoke(
69-
const ScalarType alpha, const AViewType &A, const xViewType &x,
70-
const ScalarType beta, const yViewType &y) {
71-
return Impl::SerialGemvInternal<Algo::Gemv::Unblocked>::invoke(
72-
A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(), A.stride_1(),
73-
x.data(), x.stride_0(), beta, y.data(), y.stride_0());
74-
}
75-
76-
template <>
77-
template <typename ScalarType, typename AViewType, typename xViewType,
78-
typename yViewType>
79-
KOKKOS_INLINE_FUNCTION int
80-
SerialGemv<Trans::NoTranspose, Algo::Gemv::Blocked>::invoke(
81-
const ScalarType alpha, const AViewType &A, const xViewType &x,
82-
const ScalarType beta, const yViewType &y) {
83-
return Impl::SerialGemvInternal<Algo::Gemv::Blocked>::invoke(
84-
A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(), A.stride_1(),
85-
x.data(), x.stride_0(), beta, y.data(), y.stride_0());
86-
}
87-
88-
///
89-
/// T
90-
///
91-
92-
template <>
93-
template <typename ScalarType, typename AViewType, typename xViewType,
94-
typename yViewType>
95-
KOKKOS_INLINE_FUNCTION int
96-
SerialGemv<Trans::Transpose, Algo::Gemv::Unblocked>::invoke(
97-
const ScalarType alpha, const AViewType &A, const xViewType &x,
98-
const ScalarType beta, const yViewType &y) {
99-
return Impl::SerialGemvInternal<Algo::Gemv::Unblocked>::invoke(
100-
A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(),
101-
x.data(), x.stride_0(), beta, y.data(), y.stride_0());
102-
}
103-
104-
template <>
105-
template <typename ScalarType, typename AViewType, typename xViewType,
106-
typename yViewType>
107-
KOKKOS_INLINE_FUNCTION int
108-
SerialGemv<Trans::Transpose, Algo::Gemv::Blocked>::invoke(
109-
const ScalarType alpha, const AViewType &A, const xViewType &x,
110-
const ScalarType beta, const yViewType &y) {
111-
return Impl::SerialGemvInternal<Algo::Gemv::Blocked>::invoke(
112-
A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(),
113-
x.data(), x.stride_0(), beta, y.data(), y.stride_0());
114-
}
115-
116-
///
117-
/// CT
118-
///
119-
120-
template <>
121-
template <typename ScalarType, typename AViewType, typename xViewType,
122-
typename yViewType>
123-
KOKKOS_INLINE_FUNCTION int
124-
SerialGemv<Trans::ConjTranspose, Algo::Gemv::Unblocked>::invoke(
125-
const ScalarType alpha, const AViewType &A, const xViewType &x,
126-
const ScalarType beta, const yViewType &y) {
127-
return Impl::SerialGemvInternal<Algo::Gemv::Unblocked>::invoke(
128-
Impl::OpConj(), A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
129-
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
130-
}
131-
132-
template <>
60+
template <typename ArgTrans, typename ArgAlgo>
13361
template <typename ScalarType, typename AViewType, typename xViewType,
13462
typename yViewType>
135-
KOKKOS_INLINE_FUNCTION int
136-
SerialGemv<Trans::ConjTranspose, Algo::Gemv::Blocked>::invoke(
63+
KOKKOS_INLINE_FUNCTION int SerialGemv<ArgTrans, ArgAlgo>::invoke(
13764
const ScalarType alpha, const AViewType &A, const xViewType &x,
13865
const ScalarType beta, const yViewType &y) {
139-
return Impl::SerialGemvInternal<Algo::Gemv::Blocked>::invoke(
140-
Impl::OpConj(), A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
141-
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
66+
static_assert(std::is_same<ArgAlgo, Algo::Gemv::Unblocked>::value ||
67+
std::is_same<ArgAlgo, Algo::Gemv::Blocked>::value ||
68+
std::is_same<ArgAlgo, Algo::Gemv::CompactMKL>::value,
69+
"Algorithm not supported");
70+
71+
using TransA = Impl::MatrixModeInfo<ArgTrans>;
72+
const auto ae0 = TransA::extent(A, 0);
73+
const auto ae1 = TransA::extent(A, 1);
74+
const auto as0 = TransA::stride_0(A);
75+
const auto as1 = TransA::stride_1(A);
76+
77+
return Impl::SerialGemvInternal<ArgAlgo>::invoke(
78+
ae0, ae1, alpha, A.data(), as0, as1, x.data(), x.stride_0(), beta,
79+
y.data(), y.stride_0());
14280
}
14381

14482
} // namespace KokkosBlas

blas/impl/KokkosBlas2_team_gemv_impl.hpp

+44-161
Original file line numberDiff line numberDiff line change
@@ -49,167 +49,50 @@
4949

5050
namespace KokkosBlas {
5151

52-
///
53-
/// NT
54-
///
55-
56-
template <>
57-
struct TeamGemv<Trans::NoTranspose, Algo::Gemv::Unblocked> {
58-
template <typename MemberType, typename ScalarType, typename AViewType,
59-
typename xViewType, typename yViewType>
60-
KOKKOS_INLINE_FUNCTION static int invoke(
61-
const MemberType& member, const ScalarType alpha, const AViewType& A,
62-
const xViewType& x, const ScalarType beta, const yViewType& y) {
63-
static_assert(AViewType::Rank == 2,
64-
"Blas TeamGemv requires rank-2 A matrix");
65-
return Impl::TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
66-
member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(),
67-
A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
68-
}
69-
};
70-
71-
template <>
72-
struct TeamGemv<Trans::NoTranspose, Algo::Gemv::Blocked> {
73-
template <typename MemberType, typename ScalarType, typename AViewType,
74-
typename xViewType, typename yViewType>
75-
KOKKOS_INLINE_FUNCTION static int invoke(
76-
const MemberType& member, const ScalarType alpha, const AViewType& A,
77-
const xViewType& x, const ScalarType beta, const yViewType& y) {
78-
static_assert(AViewType::Rank == 2,
79-
"Blas TeamGemv requires rank-2 A matrix");
80-
return Impl::TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
81-
member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(),
82-
A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
83-
}
84-
};
85-
86-
///
87-
/// T
88-
///
89-
90-
template <>
91-
struct TeamGemv<Trans::Transpose, Algo::Gemv::Unblocked> {
92-
template <typename MemberType, typename ScalarType, typename AViewType,
93-
typename xViewType, typename yViewType>
94-
KOKKOS_INLINE_FUNCTION static int invoke(
95-
const MemberType& member, const ScalarType alpha, const AViewType& A,
96-
const xViewType& x, const ScalarType beta, const yViewType& y) {
97-
static_assert(AViewType::Rank == 2,
98-
"BLAS TeamGemv requires rank-2 A matrix");
99-
return Impl::TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
100-
member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
101-
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
102-
}
103-
};
104-
105-
template <>
106-
struct TeamGemv<Trans::Transpose, Algo::Gemv::Blocked> {
107-
template <typename MemberType, typename ScalarType, typename AViewType,
108-
typename xViewType, typename yViewType>
109-
KOKKOS_INLINE_FUNCTION static int invoke(
110-
const MemberType& member, const ScalarType alpha, const AViewType& A,
111-
const xViewType& x, const ScalarType beta, const yViewType& y) {
112-
static_assert(AViewType::Rank == 2,
113-
"BLAS TeamGemv requires rank-2 A matrix");
114-
return Impl::TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
115-
member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
116-
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
117-
}
118-
};
119-
120-
///
121-
/// CT
122-
///
123-
124-
template <>
125-
struct TeamGemv<Trans::ConjTranspose, Algo::Gemv::Unblocked> {
126-
template <typename MemberType, typename ScalarType, typename AViewType,
127-
typename xViewType, typename yViewType>
128-
KOKKOS_INLINE_FUNCTION static int invoke(
129-
const MemberType& member, const ScalarType alpha, const AViewType& A,
130-
const xViewType& x, const ScalarType beta, const yViewType& y) {
131-
static_assert(AViewType::Rank == 2,
132-
"BLAS TeamGemv requires rank-2 A matrix");
133-
return Impl::TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
134-
member, Impl::OpConj{}, A.extent(1), A.extent(0), alpha, A.data(),
135-
A.stride_1(), A.stride_0(), x.data(), x.stride_0(), beta, y.data(),
136-
y.stride_0());
137-
}
138-
};
139-
140-
template <>
141-
struct TeamGemv<Trans::ConjTranspose, Algo::Gemv::Blocked> {
142-
template <typename MemberType, typename ScalarType, typename AViewType,
143-
typename xViewType, typename yViewType>
144-
KOKKOS_INLINE_FUNCTION static int invoke(
145-
const MemberType& member, const ScalarType alpha, const AViewType& A,
146-
const xViewType& x, const ScalarType beta, const yViewType& y) {
147-
static_assert(AViewType::Rank == 2,
148-
"BLAS TeamGemv requires rank-2 A matrix");
149-
return Impl::TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
150-
member, Impl::OpConj{}, A.extent(1), A.extent(0), alpha, A.data(),
151-
A.stride_1(), A.stride_0(), x.data(), x.stride_0(), beta, y.data(),
152-
y.stride_0());
153-
}
154-
};
155-
156-
///
157-
/// NT
158-
///
159-
160-
template <>
161-
struct TeamVectorGemv<Trans::NoTranspose, Algo::Gemv::Unblocked> {
162-
template <typename MemberType, typename ScalarType, typename AViewType,
163-
typename xViewType, typename yViewType>
164-
KOKKOS_INLINE_FUNCTION static int invoke(
165-
const MemberType& member, const ScalarType alpha, const AViewType& A,
166-
const xViewType& x, const ScalarType beta, const yViewType& y) {
167-
static_assert(AViewType::Rank == 2,
168-
"Blas TeamVectorGemv requires rank-2 A matrix");
169-
return Impl::TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
170-
member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(),
171-
A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
172-
}
173-
};
174-
175-
///
176-
/// T
177-
///
178-
179-
template <>
180-
struct TeamVectorGemv<Trans::Transpose, Algo::Gemv::Unblocked> {
181-
template <typename MemberType, typename ScalarType, typename AViewType,
182-
typename xViewType, typename yViewType>
183-
KOKKOS_INLINE_FUNCTION static int invoke(
184-
const MemberType& member, const ScalarType alpha, const AViewType& A,
185-
const xViewType& x, const ScalarType beta, const yViewType& y) {
186-
static_assert(AViewType::Rank == 2,
187-
"Blas TeamVectorGemv requires rank-2 A matrix");
188-
return Impl::TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
189-
member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
190-
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
191-
}
192-
};
193-
194-
///
195-
/// CT
196-
///
197-
198-
template <>
199-
struct TeamVectorGemv<Trans::ConjTranspose, Algo::Gemv::Unblocked> {
200-
template <typename MemberType, typename ScalarType, typename AViewType,
201-
typename xViewType, typename yViewType>
202-
KOKKOS_INLINE_FUNCTION static int invoke(
203-
const MemberType& member, const ScalarType alpha, const AViewType& A,
204-
const xViewType& x, const ScalarType beta, const yViewType& y) {
205-
static_assert(AViewType::Rank == 2,
206-
"Blas TeamVectorGemv requires rank-2 A matrix");
207-
return Impl::TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
208-
member, Impl::OpConj{}, A.extent(1), A.extent(0), alpha, A.data(),
209-
A.stride_1(), A.stride_0(), x.data(), x.stride_0(), beta, y.data(),
210-
y.stride_0());
211-
}
212-
};
52+
template <typename ArgTrans, typename ArgAlgo>
53+
template <typename MemberType, typename ScalarType, typename AViewType,
54+
typename xViewType, typename yViewType>
55+
KOKKOS_INLINE_FUNCTION int TeamGemv<ArgTrans, ArgAlgo>::invoke(
56+
const MemberType& member, const ScalarType alpha, const AViewType& A,
57+
const xViewType& x, const ScalarType beta, const yViewType& y) {
58+
static_assert(std::is_same<ArgAlgo, Algo::Gemv::Unblocked>::value ||
59+
std::is_same<ArgAlgo, Algo::Gemv::Blocked>::value,
60+
"Algorithm not supported");
61+
static_assert(AViewType::Rank == 2,
62+
"KokkosBlas::TeamGemv requires rank-2 A matrix");
63+
64+
using TransA = Impl::MatrixModeInfo<ArgTrans>;
65+
const auto ae0 = TransA::extent(A, 0);
66+
const auto ae1 = TransA::extent(A, 1);
67+
const auto as0 = TransA::stride_0(A);
68+
const auto as1 = TransA::stride_1(A);
69+
70+
return Impl::TeamGemvInternal<ArgAlgo>::invoke(
71+
member, ae0, ae1, alpha, A.data(), as0, as1, x.data(), x.stride_0(), beta,
72+
y.data(), y.stride_0());
73+
}
74+
75+
template <typename ArgTrans, typename ArgAlgo>
76+
template <typename MemberType, typename ScalarType, typename AViewType,
77+
typename xViewType, typename yViewType>
78+
KOKKOS_INLINE_FUNCTION int TeamVectorGemv<ArgTrans, ArgAlgo>::invoke(
79+
const MemberType& member, const ScalarType alpha, const AViewType& A,
80+
const xViewType& x, const ScalarType beta, const yViewType& y) {
81+
static_assert(std::is_same<ArgAlgo, Algo::Gemv::Unblocked>::value,
82+
"Algorithm not supported");
83+
static_assert(AViewType::Rank == 2,
84+
"KokkosBlas::TeamVectorGemv requires rank-2 A matrix");
85+
86+
using TransA = Impl::MatrixModeInfo<ArgTrans>;
87+
const auto ae0 = TransA::extent(A, 0);
88+
const auto ae1 = TransA::extent(A, 1);
89+
const auto as0 = TransA::stride_0(A);
90+
const auto as1 = TransA::stride_1(A);
91+
92+
return Impl::TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
93+
member, ae0, ae1, alpha, A.data(), as0, as1, x.data(), x.stride_0(), beta,
94+
y.data(), y.stride_0());
95+
}
21396

21497
} // namespace KokkosBlas
21598

0 commit comments

Comments
 (0)