@@ -63,145 +63,31 @@ namespace KokkosBlas {
63
63
// / CT/NT, NT/CT, CT/CT
64
64
// /
65
65
66
- // /
67
- // / NT/NT
68
- // /
69
-
70
- template <>
71
- template <typename ScalarType, typename AViewType, typename BViewType,
72
- typename CViewType>
73
- KOKKOS_INLINE_FUNCTION int
74
- SerialGemm<Trans::NoTranspose, Trans::NoTranspose,
75
- Algo::Gemm::Unblocked>::invoke(const ScalarType alpha,
76
- const AViewType &A,
77
- const BViewType &B,
78
- const ScalarType beta,
79
- const CViewType &C) {
80
- // C = beta C + alpha A B
81
- // C (m x n), A(m x k), B(k x n)
82
- return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke (
83
- C.extent (0 ), C.extent (1 ), A.extent (1 ), alpha, A.data (), A.stride_0 (),
84
- A.stride_1 (), B.data (), B.stride_0 (), B.stride_1 (), beta, C.data (),
85
- C.stride_0 (), C.stride_1 ());
86
- }
87
-
88
- template <>
89
- template <typename ScalarType, typename AViewType, typename BViewType,
90
- typename CViewType>
91
- KOKKOS_INLINE_FUNCTION int
92
- SerialGemm<Trans::NoTranspose, Trans::NoTranspose, Algo::Gemm::Blocked>::invoke(
93
- const ScalarType alpha, const AViewType &A, const BViewType &B,
94
- const ScalarType beta, const CViewType &C) {
95
- // C = beta C + alpha A B
96
- // C (m x n), A(m x k), B(k x n)
97
- return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke (
98
- C.extent (0 ), C.extent (1 ), A.extent (1 ), alpha, A.data (), A.stride_0 (),
99
- A.stride_1 (), B.data (), B.stride_0 (), B.stride_1 (), beta, C.data (),
100
- C.stride_0 (), C.stride_1 ());
101
- }
102
-
103
- // /
104
- // / T/NT
105
- // /
106
-
107
- template <>
66
// Single out-of-class definition of SerialGemm<ArgTransA, ArgTransB,
// ArgAlgo>::invoke, shared by every transpose-mode / algorithm combination.
//
// Parameters:
//   alpha, beta - scalars forwarded unchanged to the internal kernel.
//   A, B, C     - views; this function only uses their data(), extent() and
//                 stride_0()/stride_1() accessors (A and B through
//                 Impl::MatrixModeInfo, C directly).
// Returns: the int returned by Impl::SerialGemmInternal<ArgAlgo>::invoke.
//
// NOTE(review): Impl::MatrixModeInfo and Impl::SerialGemmInternal are defined
// outside this section; their exact semantics are assumed below, not verified.
+ template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
108
67
template <typename ScalarType, typename AViewType, typename BViewType,
109
68
typename CViewType>
110
- KOKKOS_INLINE_FUNCTION int
111
- SerialGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Unblocked>::invoke(
69
+ KOKKOS_INLINE_FUNCTION int SerialGemm<ArgTransA, ArgTransB, ArgAlgo>::invoke(
112
70
const ScalarType alpha, const AViewType &A, const BViewType &B,
113
71
const ScalarType beta, const CViewType &C) {
114
72
// C = beta C + alpha op(A) op(B), with op selected by ArgTransA / ArgTransB
115
73
// C (m x n), op(A) (m x k), op(B) (k x n); m = C.extent(0), n = C.extent(1)
116
- return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke (
117
- C.extent (0 ), C.extent (1 ), A.extent (0 ), alpha, A.data (), A.stride_1 (),
118
- A.stride_0 (), B.data (), B.stride_0 (), B.stride_1 (), beta, C.data (),
119
- C.stride_0 (), C.stride_1 ());
// Reject any algorithm tag other than the three listed at compile time.
// NOTE(review): CompactMKL is accepted here, but whether a matching
// Impl::SerialGemmInternal<Algo::Gemm::CompactMKL> exists is not visible in
// this section — confirm.
74
+ static_assert (std::is_same<ArgAlgo, Algo::Gemm::Unblocked>::value ||
75
+ std::is_same<ArgAlgo, Algo::Gemm::Blocked>::value ||
76
+ std::is_same<ArgAlgo, Algo::Gemm::CompactMKL>::value,
77
+ " Algorithm not supported" );
78
+
// Fetch the inner dimension k and the strides of A and B through the
// per-mode helpers instead of calling the view accessors directly.
// NOTE(review): the per-mode code this replaces used A.extent(0) as k and
// swapped stride_0/stride_1 for a transposed operand; this is only equivalent
// if MatrixModeInfo::extent/stride_0/stride_1 apply that swap — confirm.
79
+ using TransA = Impl::MatrixModeInfo<ArgTransA>;
80
+ using TransB = Impl::MatrixModeInfo<ArgTransB>;
81
+ const auto ae1 = TransA::extent (A, 1 );
82
+ const auto as0 = TransA::stride_0 (A);
83
+ const auto as1 = TransA::stride_1 (A);
84
+ const auto bs0 = TransB::stride_0 (B);
85
+ const auto bs1 = TransB::stride_1 (B);
86
+
// Delegate to the kernel chosen by ArgAlgo, passing sizes (m, n, k), the
// scalars, and raw pointer + stride pairs for A, B and C.
87
+ return Impl::SerialGemmInternal<ArgAlgo>::invoke (
88
+ C.extent (0 ), C.extent (1 ), ae1, alpha, A.data (), as0, as1, B.data (), bs0,
89
+ bs1, beta, C.data (), C.stride_0 (), C.stride_1 ());
120
90
}
121
-
122
- template <>
123
- template <typename ScalarType, typename AViewType, typename BViewType,
124
- typename CViewType>
125
- KOKKOS_INLINE_FUNCTION int
126
- SerialGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Blocked>::invoke(
127
- const ScalarType alpha, const AViewType &A, const BViewType &B,
128
- const ScalarType beta, const CViewType &C) {
129
- // C = beta C + alpha A B
130
- // C (m x n), A(m x k), B(k x n)
131
- return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke (
132
- C.extent (0 ), C.extent (1 ), A.extent (0 ), alpha, A.data (), A.stride_1 (),
133
- A.stride_0 (), B.data (), B.stride_0 (), B.stride_1 (), beta, C.data (),
134
- C.stride_0 (), C.stride_1 ());
135
- }
136
-
137
- // /
138
- // / NT/T
139
- // /
140
-
141
- template <>
142
- template <typename ScalarType, typename AViewType, typename BViewType,
143
- typename CViewType>
144
- KOKKOS_INLINE_FUNCTION int
145
- SerialGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Unblocked>::invoke(
146
- const ScalarType alpha, const AViewType &A, const BViewType &B,
147
- const ScalarType beta, const CViewType &C) {
148
- // C = beta C + alpha A B
149
- // C (m x n), A(m x k), B(k x n)
150
- return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke (
151
- C.extent (0 ), C.extent (1 ), A.extent (1 ), alpha, A.data (), A.stride_0 (),
152
- A.stride_1 (), B.data (), B.stride_1 (), B.stride_0 (), beta, C.data (),
153
- C.stride_0 (), C.stride_1 ());
154
- }
155
-
156
- template <>
157
- template <typename ScalarType, typename AViewType, typename BViewType,
158
- typename CViewType>
159
- KOKKOS_INLINE_FUNCTION int
160
- SerialGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Blocked>::invoke(
161
- const ScalarType alpha, const AViewType &A, const BViewType &B,
162
- const ScalarType beta, const CViewType &C) {
163
- // C = beta C + alpha A B
164
- // C (m x n), A(m x k), B(k x n)
165
- return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke (
166
- C.extent (0 ), C.extent (1 ), A.extent (1 ), alpha, A.data (), A.stride_0 (),
167
- A.stride_1 (), B.data (), B.stride_1 (), B.stride_0 (), beta, C.data (),
168
- C.stride_0 (), C.stride_1 ());
169
- }
170
-
171
- // /
172
- // / T/T
173
- // /
174
-
175
- template <>
176
- template <typename ScalarType, typename AViewType, typename BViewType,
177
- typename CViewType>
178
- KOKKOS_INLINE_FUNCTION int
179
- SerialGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Unblocked>::invoke(
180
- const ScalarType alpha, const AViewType &A, const BViewType &B,
181
- const ScalarType beta, const CViewType &C) {
182
- // C = beta C + alpha A B
183
- // C (m x n), A(m x k), B(k x n)
184
- return Impl::SerialGemmInternal<Algo::Gemm::Unblocked>::invoke (
185
- C.extent (0 ), C.extent (1 ), A.extent (0 ), alpha, A.data (), A.stride_1 (),
186
- A.stride_0 (), B.data (), B.stride_1 (), B.stride_0 (), beta, C.data (),
187
- C.stride_0 (), C.stride_1 ());
188
- }
189
-
190
- template <>
191
- template <typename ScalarType, typename AViewType, typename BViewType,
192
- typename CViewType>
193
- KOKKOS_INLINE_FUNCTION int
194
- SerialGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Blocked>::invoke(
195
- const ScalarType alpha, const AViewType &A, const BViewType &B,
196
- const ScalarType beta, const CViewType &C) {
197
- // C = beta C + alpha A B
198
- // C (m x n), A(m x k), B(k x n)
199
- return Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke (
200
- C.extent (0 ), C.extent (1 ), A.extent (0 ), alpha, A.data (), A.stride_1 (),
201
- A.stride_0 (), B.data (), B.stride_1 (), B.stride_0 (), beta, C.data (),
202
- C.stride_0 (), C.stride_1 ());
203
- }
204
-
205
91
} // namespace KokkosBlas
206
92
207
93
#endif
0 commit comments