|
| 1 | +/******************************************************************************* |
| 2 | +* Copyright 2023 Intel Corporation |
| 3 | +* |
| 4 | +* Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +* you may not use this file except in compliance with the License. |
| 6 | +* You may obtain a copy of the License at |
| 7 | +* |
| 8 | +* http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +* |
| 10 | +* Unless required by applicable law or agreed to in writing, software |
| 11 | +* distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +* See the License for the specific language governing permissions and |
| 14 | +* limitations under the License. |
| 15 | +*******************************************************************************/ |
| 16 | + |
| 17 | +#include "oneapi/dnnl/dnnl_types.h" |
| 18 | + |
| 19 | +#include "common/dnnl_thread.hpp" |
| 20 | +#include "common/nstl.hpp" |
| 21 | +#include "common/utils.hpp" |
| 22 | + |
| 23 | +#include "cpu/platform.hpp" |
| 24 | + |
| 25 | +#include "cpu/gemm/bf16/ref_gemm_bf16.hpp" |
| 26 | +#include "cpu/gemm/f32/gemm_utils_f32.hpp" |
| 27 | + |
| 28 | +namespace dnnl { |
| 29 | +namespace impl { |
| 30 | +namespace cpu { |
| 31 | + |
| 32 | +using namespace dnnl::impl::utils; |
| 33 | +using namespace gemm_utils; |
| 34 | + |
| 35 | +namespace { |
| 36 | + |
| 37 | +void copy_A(bool isTransA, dim_t K, const bfloat16_t *A, const dim_t lda, |
| 38 | + bfloat16_t *ws) { |
| 39 | + for (dim_t k = 0; k < K; k++) { |
| 40 | + PRAGMA_OMP_SIMD() |
| 41 | + for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) { |
| 42 | + ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; |
| 43 | + } |
| 44 | + ws += unroll_factor<bfloat16_t>::m; |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +template <bool isTransA, bool isTransB> |
| 49 | +void kernel_mxn(dim_t K, const bfloat16_t *A, const dim_t lda, |
| 50 | + const bfloat16_t *B, const dim_t ldb, float *C, const dim_t ldc, |
| 51 | + const float alpha, const float beta) { |
| 52 | + float c[unroll_factor<bfloat16_t>::m * unroll_factor<bfloat16_t>::n] |
| 53 | + = {0.f}; |
| 54 | + for (dim_t k = 0; k < K; k++) { |
| 55 | + for (dim_t j = 0; j < unroll_factor<bfloat16_t>::n; j++) { |
| 56 | + bfloat16_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; |
| 57 | + PRAGMA_OMP_SIMD() |
| 58 | + for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) { |
| 59 | + bfloat16_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; |
| 60 | + c[i + unroll_factor<bfloat16_t>::m * j] += a * b; |
| 61 | + } |
| 62 | + } |
| 63 | + } |
| 64 | + for (dim_t j = 0; j < unroll_factor<bfloat16_t>::n; j++) { |
| 65 | + PRAGMA_OMP_SIMD() |
| 66 | + for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) { |
| 67 | + C[i + j * ldc] = (beta == 0.f) |
| 68 | + ? alpha * c[i + unroll_factor<bfloat16_t>::m * j] |
| 69 | + : alpha * c[i + unroll_factor<bfloat16_t>::m * j] |
| 70 | + + beta * C[i + j * ldc]; |
| 71 | + } |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +template <bool isTransA, bool isTransB> |
| 76 | +void block_ker(const dim_t M, const dim_t N, const dim_t K, const bfloat16_t *A, |
| 77 | + const dim_t lda, const bfloat16_t *B, const dim_t ldb, float *C, |
| 78 | + const dim_t ldc, const float alpha, const float beta, bfloat16_t *ws, |
| 79 | + bool do_copy) { |
| 80 | + dim_t Nu = rnd_dn(N, unroll_factor<bfloat16_t>::n); |
| 81 | + dim_t Mu = rnd_dn(M, unroll_factor<bfloat16_t>::m); |
| 82 | + for (dim_t i = 0; i < Mu; i += unroll_factor<bfloat16_t>::m) { |
| 83 | + for (dim_t j = 0; j < Nu; j += unroll_factor<bfloat16_t>::n) { |
| 84 | + const bfloat16_t *b = isTransB ? &B[j] : &B[j * ldb]; |
| 85 | + const bfloat16_t *a = isTransA ? &A[i * lda] : &A[i]; |
| 86 | + if (do_copy) { |
| 87 | + if (j == 0) { copy_A(isTransA, K, a, lda, ws); } |
| 88 | + kernel_mxn<false, isTransB>(K, ws, unroll_factor<bfloat16_t>::m, |
| 89 | + b, ldb, &C[i + j * ldc], ldc, alpha, beta); |
| 90 | + } else { |
| 91 | + kernel_mxn<isTransA, isTransB>( |
| 92 | + K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); |
| 93 | + } |
| 94 | + } |
| 95 | + } |
| 96 | + // tail processing |
| 97 | + for (dim_t i = 0; i < M; i++) { |
| 98 | + for (dim_t j = Nu; j < N; j++) { |
| 99 | + float c = beta == 0.f ? 0.f : beta * C[i + j * ldc]; |
| 100 | + for (dim_t p = 0; p < K; p++) { |
| 101 | + bfloat16_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; |
| 102 | + bfloat16_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; |
| 103 | + c += alpha * a * b; |
| 104 | + } |
| 105 | + C[i + j * ldc] = c; |
| 106 | + } |
| 107 | + } |
| 108 | + for (dim_t i = Mu; i < M; i++) { |
| 109 | + for (dim_t j = 0; j < Nu; j++) { |
| 110 | + float c = beta == 0.f ? 0.f : beta * C[i + j * ldc]; |
| 111 | + for (dim_t p = 0; p < K; p++) { |
| 112 | + bfloat16_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; |
| 113 | + bfloat16_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; |
| 114 | + c += alpha * a * b; |
| 115 | + } |
| 116 | + C[i + j * ldc] = c; |
| 117 | + } |
| 118 | + } |
| 119 | +} |
| 120 | + |
| 121 | +template <bool isTransA, bool isTransB> |
| 122 | +void gemm_ithr(const dim_t M, const dim_t N, const dim_t K, const float alpha, |
| 123 | + const bfloat16_t *A, const dim_t lda, const bfloat16_t *B, |
| 124 | + const dim_t ldb, const float beta, float *C, const dim_t ldc, |
| 125 | + bool do_copy, bfloat16_t *ws) { |
| 126 | + constexpr dim_t BM = gemm_traits<bfloat16_t, isTransA, isTransB>::BM; |
| 127 | + constexpr dim_t BN = gemm_traits<bfloat16_t, isTransA, isTransB>::BN; |
| 128 | + constexpr dim_t BK = gemm_traits<bfloat16_t, isTransA, isTransB>::BK; |
| 129 | + |
| 130 | + const bfloat16_t *curA; |
| 131 | + const bfloat16_t *curB; |
| 132 | + float *curC; |
| 133 | + |
| 134 | + if ((M <= 0) || (N <= 0)) return; |
| 135 | + |
| 136 | + if ((K <= 0) || (alpha == 0.f)) { |
| 137 | + dim_t MN = N * M; |
| 138 | + if (beta == 0.f) { |
| 139 | + for (dim_t j = 0; j < MN; j++) |
| 140 | + C[j] = 0.f; |
| 141 | + } else if (beta != 1.f) { |
| 142 | + for (dim_t j = 0; j < MN; j++) |
| 143 | + C[j] *= beta; |
| 144 | + } |
| 145 | + return; |
| 146 | + } |
| 147 | + |
| 148 | + for (dim_t Bk = 0; Bk < K; Bk += BK) { |
| 149 | + dim_t kb = nstl::min(K - Bk, BK); |
| 150 | + for (dim_t Bm = 0; Bm < M; Bm += BM) { |
| 151 | + dim_t mb = nstl::min(M - Bm, BM); |
| 152 | + for (dim_t Bn = 0; Bn < N; Bn += BN) { |
| 153 | + dim_t nb = nstl::min(N - Bn, BN); |
| 154 | + curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; |
| 155 | + curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb; |
| 156 | + curC = C + Bm + Bn * ldc; |
| 157 | + if (Bk == 0) { |
| 158 | + block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB, |
| 159 | + ldb, curC, ldc, alpha, beta, ws, do_copy); |
| 160 | + } else { |
| 161 | + block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB, |
| 162 | + ldb, curC, ldc, alpha, 1.f, ws, do_copy); |
| 163 | + } |
| 164 | + } |
| 165 | + } |
| 166 | + } |
| 167 | +} |
| 168 | + |
| 169 | +} // namespace |
| 170 | + |
| 171 | +dnnl_status_t ref_gemm_bf16bf16f32(const char *transa_, const char *transb_, |
| 172 | + const dim_t *M_, const dim_t *N_, const dim_t *K_, const float *alpha_, |
| 173 | + const bfloat16_t *A, const dim_t *lda_, const bfloat16_t *B, |
| 174 | + const dim_t *ldb_, const float *beta_, float *C, const dim_t *ldc_) { |
| 175 | + |
| 176 | + if (!(utils::one_of(*transa_, 'n', 'N', 't', 'T') |
| 177 | + && utils::one_of(*transb_, 'n', 'N', 't', 'T'))) |
| 178 | + return dnnl_unimplemented; |
| 179 | + |
| 180 | + bool isTransA = (*transa_ == 'T' || *transa_ == 't'); |
| 181 | + bool isTransB = (*transb_ == 'T' || *transb_ == 't'); |
| 182 | + const dim_t M = *M_, N = *N_, K = *K_; |
| 183 | + const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; |
| 184 | + const float alpha = *alpha_, beta = *beta_; |
| 185 | + |
| 186 | + // early out and avoid division by zero in partitioning |
| 187 | + if (utils::one_of(0, M, N)) return dnnl_success; |
| 188 | + |
| 189 | + int max_nthr = dnnl_get_current_num_threads(); |
| 190 | + int nthr_m, nthr_n, nthr_k; |
| 191 | + dim_t MB, NB, KB; |
| 192 | + // thread balancing over M, N, K & size of blocking dimensions |
| 193 | + calc_nthr_nocopy_avx( |
| 194 | + M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); |
| 195 | + assert(IMPLICATION(!dnnl_thr_syncable(), nthr_k == 1)); |
| 196 | + |
| 197 | + float *c_buffers = nullptr; |
| 198 | + bfloat16_t *ws_buffers = nullptr; |
| 199 | + if (nthr_k > 1) { |
| 200 | + c_buffers = (float *)malloc( |
| 201 | + sizeof(*c_buffers) * nthr_m * nthr_n * (nthr_k - 1) * MB * NB, |
| 202 | + PAGE_4K); |
| 203 | + if (!c_buffers) { |
| 204 | + nthr_k = 1; |
| 205 | + KB = K; |
| 206 | + } |
| 207 | + } |
| 208 | + |
| 209 | + bool do_copy = (NB / unroll_factor<bfloat16_t>::n > 3); |
| 210 | + const int nthr_mn = nthr_m * nthr_n; |
| 211 | + const int nthr_to_use = nthr_mn * nthr_k; |
| 212 | + const size_t ws_elems_per_thr = K * unroll_factor<bfloat16_t>::m; |
| 213 | + const size_t ws_size_per_thr |
| 214 | + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); |
| 215 | + if (do_copy) { |
| 216 | + ws_buffers |
| 217 | + = (bfloat16_t *)malloc(nthr_to_use * ws_size_per_thr, PAGE_4K); |
| 218 | + if (!ws_buffers) do_copy = false; |
| 219 | + } |
| 220 | + |
| 221 | + auto get_thr_block = [&](dim_t &from, dim_t &to, dim_t &myN, dim_t NB, |
| 222 | + dim_t N, int ithr) { |
| 223 | + from = NB * (ithr); |
| 224 | + to = NB * (ithr + 1); |
| 225 | + if (to > N) to = N; |
| 226 | + myN = to - from; |
| 227 | + }; |
| 228 | + |
| 229 | + parallel(nthr_to_use, [&](int ithr, int nthr) { |
| 230 | + assert(nthr_to_use == nthr); |
| 231 | + MAYBE_UNUSED(nthr); |
| 232 | + |
| 233 | + int ithr_mn = ithr % nthr_mn; |
| 234 | + int ithr_m = ithr_mn % nthr_m; |
| 235 | + int ithr_n = ithr_mn / nthr_m; |
| 236 | + int ithr_k = ithr / nthr_mn; |
| 237 | + |
| 238 | + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); |
| 239 | + |
| 240 | + bfloat16_t *ws = do_copy |
| 241 | + ? ws_buffers + ithr * ws_size_per_thr / sizeof(float) |
| 242 | + : nullptr; |
| 243 | + |
| 244 | + dim_t m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, |
| 245 | + k_from = 0, k_to = 0, myK = 0; |
| 246 | + |
| 247 | + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); |
| 248 | + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); |
| 249 | + get_thr_block(k_from, k_to, myK, KB, K, ithr_k); |
| 250 | + |
| 251 | + if (myM > 0 && myN > 0) { |
| 252 | + float myBeta, *myC; |
| 253 | + dim_t ld; |
| 254 | + if (ithr_k == 0) { |
| 255 | + myC = &(C[m_from + n_from * ldc]); |
| 256 | + myBeta = beta; |
| 257 | + ld = ldc; |
| 258 | + } else { |
| 259 | + myC = c_buffers + MB * NB * (cbase + ithr_k - 1); |
| 260 | + myBeta = 0.0f; |
| 261 | + ld = MB; |
| 262 | + } |
| 263 | + const bfloat16_t *myA = isTransA ? &(A[k_from + m_from * lda]) |
| 264 | + : &(A[m_from + k_from * lda]); |
| 265 | + const bfloat16_t *myB = isTransB ? &(B[n_from + k_from * ldb]) |
| 266 | + : &(B[k_from + n_from * ldb]); |
| 267 | + |
| 268 | + if (!isTransA) { |
| 269 | + if (!isTransB) { |
| 270 | + gemm_ithr<false, false>(myM, myN, myK, alpha, myA, lda, myB, |
| 271 | + ldb, myBeta, myC, ld, do_copy, ws); |
| 272 | + } else { |
| 273 | + gemm_ithr<false, true>(myM, myN, myK, alpha, myA, lda, myB, |
| 274 | + ldb, myBeta, myC, ld, do_copy, ws); |
| 275 | + } |
| 276 | + } else { |
| 277 | + if (!isTransB) { |
| 278 | + gemm_ithr<true, false>(myM, myN, myK, alpha, myA, lda, myB, |
| 279 | + ldb, myBeta, myC, ld, do_copy, ws); |
| 280 | + } else { |
| 281 | + gemm_ithr<true, true>(myM, myN, myK, alpha, myA, lda, myB, |
| 282 | + ldb, myBeta, myC, ld, do_copy, ws); |
| 283 | + } |
| 284 | + } |
| 285 | + } |
| 286 | + }); |
| 287 | + |
| 288 | + if (nthr_k > 1) { |
| 289 | + parallel(nthr_to_use, [&](int ithr, int nthr) { |
| 290 | + assert(nthr_to_use == nthr); |
| 291 | + MAYBE_UNUSED(nthr); |
| 292 | + |
| 293 | + int ithr_mn = ithr % nthr_mn; |
| 294 | + int ithr_m = ithr_mn % nthr_m; |
| 295 | + int ithr_k = ithr / nthr_mn; |
| 296 | + int ithr_n = ithr_mn / nthr_m; |
| 297 | + |
| 298 | + dim_t n_from = 0, n_to = 0, myN = 0; |
| 299 | + dim_t m_from = 0, m_to = 0, myM = 0; |
| 300 | + |
| 301 | + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); |
| 302 | + |
| 303 | + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); |
| 304 | + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); |
| 305 | + |
| 306 | + // sum matrices partitioned along K dimension |
| 307 | + dim_t offset = 0, block = 0; |
| 308 | + partition_unit_diff(ithr_k, nthr_k, myN, &offset, &block); |
| 309 | + for (int ik = 1; ik < nthr_k; ++ik) { |
| 310 | + float *myC = c_buffers |
| 311 | + + MB * ((dim_t)NB * (cbase + ik - 1) + offset); |
| 312 | + |
| 313 | + gemm_utils::sum_two_matrices(myM, block, myC, MB, |
| 314 | + &C[m_from + (n_from + offset) * ldc], ldc); |
| 315 | + } |
| 316 | + }); |
| 317 | + } |
| 318 | + |
| 319 | + free(ws_buffers); |
| 320 | + free(c_buffers); |
| 321 | + |
| 322 | + return dnnl_success; |
| 323 | +} |
| 324 | + |
| 325 | +} // namespace cpu |
| 326 | +} // namespace impl |
| 327 | +} // namespace dnnl |
0 commit comments