Skip to content

Commit 2105dd9

Browse files
author
Mikołaj Zuzek
committed
GEMM: refactor crossing of A/B matrix transposes
1 parent de3d26c commit 2105dd9

12 files changed

+417
-1532
lines changed

blas/impl/KokkosBlas3_serial_gemm_impl.hpp

+18-132
Original file line numberDiff line numberDiff line change
@@ -63,145 +63,31 @@ namespace KokkosBlas {
6363
/// CT/NT, NT/CT, CT/CT
6464
///
6565

66-
///
67-
/// NT/NT
68-
///
69-
70-
template <>
71-
template <typename ScalarType, typename AViewType, typename BViewType,
72-
typename CViewType>
73-
KOKKOS_INLINE_FUNCTION int
74-
SerialGemm<Trans::NoTranspose, Trans::NoTranspose,
75-
Algo::Gemm::Unblocked>::invoke(const ScalarType alpha,
76-
const AViewType &A,
77-
const BViewType &B,
78-
const ScalarType beta,
79-
const CViewType &C) {
80-
// C = beta C + alpha A B
81-
// C (m x n), A(m x k), B(k x n)
82-
return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
83-
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
84-
A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
85-
C.stride_0(), C.stride_1());
86-
}
87-
88-
template <>
89-
template <typename ScalarType, typename AViewType, typename BViewType,
90-
typename CViewType>
91-
KOKKOS_INLINE_FUNCTION int
92-
SerialGemm<Trans::NoTranspose, Trans::NoTranspose, Algo::Gemm::Blocked>::invoke(
93-
const ScalarType alpha, const AViewType &A, const BViewType &B,
94-
const ScalarType beta, const CViewType &C) {
95-
// C = beta C + alpha A B
96-
// C (m x n), A(m x k), B(k x n)
97-
return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
98-
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
99-
A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
100-
C.stride_0(), C.stride_1());
101-
}
102-
103-
///
104-
/// T/NT
105-
///
106-
107-
template <>
66+
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
10867
template <typename ScalarType, typename AViewType, typename BViewType,
10968
typename CViewType>
110-
KOKKOS_INLINE_FUNCTION int
111-
SerialGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Unblocked>::invoke(
69+
KOKKOS_INLINE_FUNCTION int SerialGemm<ArgTransA, ArgTransB, ArgAlgo>::invoke(
11270
const ScalarType alpha, const AViewType &A, const BViewType &B,
11371
const ScalarType beta, const CViewType &C) {
11472
// C = beta C + alpha op(A) op(B), where op() applies ArgTransA / ArgTransB
11573
// C (m x n), op(A) (m x k), op(B) (k x n)
116-
return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
117-
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
118-
A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
119-
C.stride_0(), C.stride_1());
74+
static_assert(std::is_same<ArgAlgo, Algo::Gemm::Unblocked>::value ||
75+
std::is_same<ArgAlgo, Algo::Gemm::Blocked>::value ||
76+
std::is_same<ArgAlgo, Algo::Gemm::CompactMKL>::value,
77+
"Algorithm not supported");
78+
79+
using TransA = Impl::MatrixModeInfo<ArgTransA>;
80+
using TransB = Impl::MatrixModeInfo<ArgTransB>;
81+
const auto ae1 = TransA::extent(A, 1);
82+
const auto as0 = TransA::stride_0(A);
83+
const auto as1 = TransA::stride_1(A);
84+
const auto bs0 = TransB::stride_0(B);
85+
const auto bs1 = TransB::stride_1(B);
86+
87+
return Impl::SerialGemmInternal<ArgAlgo>::invoke(
88+
C.extent(0), C.extent(1), ae1, alpha, A.data(), as0, as1, B.data(), bs0,
89+
bs1, beta, C.data(), C.stride_0(), C.stride_1());
12090
}
121-
122-
template <>
123-
template <typename ScalarType, typename AViewType, typename BViewType,
124-
typename CViewType>
125-
KOKKOS_INLINE_FUNCTION int
126-
SerialGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Blocked>::invoke(
127-
const ScalarType alpha, const AViewType &A, const BViewType &B,
128-
const ScalarType beta, const CViewType &C) {
129-
// C = beta C + alpha A B
130-
// C (m x n), A(m x k), B(k x n)
131-
return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
132-
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
133-
A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
134-
C.stride_0(), C.stride_1());
135-
}
136-
137-
///
138-
/// NT/T
139-
///
140-
141-
template <>
142-
template <typename ScalarType, typename AViewType, typename BViewType,
143-
typename CViewType>
144-
KOKKOS_INLINE_FUNCTION int
145-
SerialGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Unblocked>::invoke(
146-
const ScalarType alpha, const AViewType &A, const BViewType &B,
147-
const ScalarType beta, const CViewType &C) {
148-
// C = beta C + alpha A B
149-
// C (m x n), A(m x k), B(k x n)
150-
return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
151-
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
152-
A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
153-
C.stride_0(), C.stride_1());
154-
}
155-
156-
template <>
157-
template <typename ScalarType, typename AViewType, typename BViewType,
158-
typename CViewType>
159-
KOKKOS_INLINE_FUNCTION int
160-
SerialGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Blocked>::invoke(
161-
const ScalarType alpha, const AViewType &A, const BViewType &B,
162-
const ScalarType beta, const CViewType &C) {
163-
// C = beta C + alpha A B
164-
// C (m x n), A(m x k), B(k x n)
165-
return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
166-
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
167-
A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
168-
C.stride_0(), C.stride_1());
169-
}
170-
171-
///
172-
/// T/T
173-
///
174-
175-
template <>
176-
template <typename ScalarType, typename AViewType, typename BViewType,
177-
typename CViewType>
178-
KOKKOS_INLINE_FUNCTION int
179-
SerialGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Unblocked>::invoke(
180-
const ScalarType alpha, const AViewType &A, const BViewType &B,
181-
const ScalarType beta, const CViewType &C) {
182-
// C = beta C + alpha A B
183-
// C (m x n), A(m x k), B(k x n)
184-
return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
185-
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
186-
A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
187-
C.stride_0(), C.stride_1());
188-
}
189-
190-
template <>
191-
template <typename ScalarType, typename AViewType, typename BViewType,
192-
typename CViewType>
193-
KOKKOS_INLINE_FUNCTION int
194-
SerialGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Blocked>::invoke(
195-
const ScalarType alpha, const AViewType &A, const BViewType &B,
196-
const ScalarType beta, const CViewType &C) {
197-
// C = beta C + alpha A B
198-
// C (m x n), A(m x k), B(k x n)
199-
return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
200-
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
201-
A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
202-
C.stride_0(), C.stride_1());
203-
}
204-
20591
} // namespace KokkosBlas
20692

20793
#endif

0 commit comments

Comments
 (0)