Commit e634182

dzarukin authored and luweizhou2016 committed
cpu: gemm: add ref bf16 gemm to keep functionality working
1 parent 2ead5d4 commit e634182

4 files changed: +377 -2 lines changed


src/cpu/gemm/bf16/ref_gemm_bf16.cpp

+327

@@ -0,0 +1,327 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dnnl/dnnl_types.h"

#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/gemm/bf16/ref_gemm_bf16.hpp"
#include "cpu/gemm/f32/gemm_utils_f32.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace gemm_utils;

namespace {

void copy_A(bool isTransA, dim_t K, const bfloat16_t *A, const dim_t lda,
        bfloat16_t *ws) {
    for (dim_t k = 0; k < K; k++) {
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) {
            ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
        }
        ws += unroll_factor<bfloat16_t>::m;
    }
}

template <bool isTransA, bool isTransB>
void kernel_mxn(dim_t K, const bfloat16_t *A, const dim_t lda,
        const bfloat16_t *B, const dim_t ldb, float *C, const dim_t ldc,
        const float alpha, const float beta) {
    float c[unroll_factor<bfloat16_t>::m * unroll_factor<bfloat16_t>::n]
            = {0.f};
    for (dim_t k = 0; k < K; k++) {
        for (dim_t j = 0; j < unroll_factor<bfloat16_t>::n; j++) {
            bfloat16_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
            PRAGMA_OMP_SIMD()
            for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) {
                bfloat16_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
                c[i + unroll_factor<bfloat16_t>::m * j] += a * b;
            }
        }
    }
    for (dim_t j = 0; j < unroll_factor<bfloat16_t>::n; j++) {
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) {
            C[i + j * ldc] = (beta == 0.f)
                    ? alpha * c[i + unroll_factor<bfloat16_t>::m * j]
                    : alpha * c[i + unroll_factor<bfloat16_t>::m * j]
                            + beta * C[i + j * ldc];
        }
    }
}

template <bool isTransA, bool isTransB>
void block_ker(const dim_t M, const dim_t N, const dim_t K, const bfloat16_t *A,
        const dim_t lda, const bfloat16_t *B, const dim_t ldb, float *C,
        const dim_t ldc, const float alpha, const float beta, bfloat16_t *ws,
        bool do_copy) {
    dim_t Nu = rnd_dn(N, unroll_factor<bfloat16_t>::n);
    dim_t Mu = rnd_dn(M, unroll_factor<bfloat16_t>::m);
    for (dim_t i = 0; i < Mu; i += unroll_factor<bfloat16_t>::m) {
        for (dim_t j = 0; j < Nu; j += unroll_factor<bfloat16_t>::n) {
            const bfloat16_t *b = isTransB ? &B[j] : &B[j * ldb];
            const bfloat16_t *a = isTransA ? &A[i * lda] : &A[i];
            if (do_copy) {
                if (j == 0) { copy_A(isTransA, K, a, lda, ws); }
                kernel_mxn<false, isTransB>(K, ws, unroll_factor<bfloat16_t>::m,
                        b, ldb, &C[i + j * ldc], ldc, alpha, beta);
            } else {
                kernel_mxn<isTransA, isTransB>(
                        K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
            }
        }
    }
    // tail processing
    for (dim_t i = 0; i < M; i++) {
        for (dim_t j = Nu; j < N; j++) {
            float c = beta == 0.f ? 0.f : beta * C[i + j * ldc];
            for (dim_t p = 0; p < K; p++) {
                bfloat16_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
                bfloat16_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
                c += alpha * a * b;
            }
            C[i + j * ldc] = c;
        }
    }
    for (dim_t i = Mu; i < M; i++) {
        for (dim_t j = 0; j < Nu; j++) {
            float c = beta == 0.f ? 0.f : beta * C[i + j * ldc];
            for (dim_t p = 0; p < K; p++) {
                bfloat16_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
                bfloat16_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
                c += alpha * a * b;
            }
            C[i + j * ldc] = c;
        }
    }
}

template <bool isTransA, bool isTransB>
void gemm_ithr(const dim_t M, const dim_t N, const dim_t K, const float alpha,
        const bfloat16_t *A, const dim_t lda, const bfloat16_t *B,
        const dim_t ldb, const float beta, float *C, const dim_t ldc,
        bool do_copy, bfloat16_t *ws) {
    constexpr dim_t BM = gemm_traits<bfloat16_t, isTransA, isTransB>::BM;
    constexpr dim_t BN = gemm_traits<bfloat16_t, isTransA, isTransB>::BN;
    constexpr dim_t BK = gemm_traits<bfloat16_t, isTransA, isTransB>::BK;

    const bfloat16_t *curA;
    const bfloat16_t *curB;
    float *curC;

    if ((M <= 0) || (N <= 0)) return;

    if ((K <= 0) || (alpha == 0.f)) {
        dim_t MN = N * M;
        if (beta == 0.f) {
            for (dim_t j = 0; j < MN; j++)
                C[j] = 0.f;
        } else if (beta != 1.f) {
            for (dim_t j = 0; j < MN; j++)
                C[j] *= beta;
        }
        return;
    }

    for (dim_t Bk = 0; Bk < K; Bk += BK) {
        dim_t kb = nstl::min(K - Bk, BK);
        for (dim_t Bm = 0; Bm < M; Bm += BM) {
            dim_t mb = nstl::min(M - Bm, BM);
            for (dim_t Bn = 0; Bn < N; Bn += BN) {
                dim_t nb = nstl::min(N - Bn, BN);
                curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda;
                curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
                curC = C + Bm + Bn * ldc;
                if (Bk == 0) {
                    block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB,
                            ldb, curC, ldc, alpha, beta, ws, do_copy);
                } else {
                    block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB,
                            ldb, curC, ldc, alpha, 1.f, ws, do_copy);
                }
            }
        }
    }
}

} // namespace

dnnl_status_t ref_gemm_bf16bf16f32(const char *transa_, const char *transb_,
        const dim_t *M_, const dim_t *N_, const dim_t *K_, const float *alpha_,
        const bfloat16_t *A, const dim_t *lda_, const bfloat16_t *B,
        const dim_t *ldb_, const float *beta_, float *C, const dim_t *ldc_) {

    if (!(utils::one_of(*transa_, 'n', 'N', 't', 'T')
                && utils::one_of(*transb_, 'n', 'N', 't', 'T')))
        return dnnl_unimplemented;

    bool isTransA = (*transa_ == 'T' || *transa_ == 't');
    bool isTransB = (*transb_ == 'T' || *transb_ == 't');
    const dim_t M = *M_, N = *N_, K = *K_;
    const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
    const float alpha = *alpha_, beta = *beta_;

    // early out and avoid division by zero in partitioning
    if (utils::one_of(0, M, N)) return dnnl_success;

    int max_nthr = dnnl_get_current_num_threads();
    int nthr_m, nthr_n, nthr_k;
    dim_t MB, NB, KB;
    // thread balancing over M, N, K & size of blocking dimensions
    calc_nthr_nocopy_avx(
            M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
    assert(IMPLICATION(!dnnl_thr_syncable(), nthr_k == 1));

    float *c_buffers = nullptr;
    bfloat16_t *ws_buffers = nullptr;
    if (nthr_k > 1) {
        c_buffers = (float *)malloc(
                sizeof(*c_buffers) * nthr_m * nthr_n * (nthr_k - 1) * MB * NB,
                PAGE_4K);
        if (!c_buffers) {
            nthr_k = 1;
            KB = K;
        }
    }

    bool do_copy = (NB / unroll_factor<bfloat16_t>::n > 3);
    const int nthr_mn = nthr_m * nthr_n;
    const int nthr_to_use = nthr_mn * nthr_k;
    const size_t ws_elems_per_thr = K * unroll_factor<bfloat16_t>::m;
    const size_t ws_size_per_thr
            = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
    if (do_copy) {
        ws_buffers
                = (bfloat16_t *)malloc(nthr_to_use * ws_size_per_thr, PAGE_4K);
        if (!ws_buffers) do_copy = false;
    }

    auto get_thr_block = [&](dim_t &from, dim_t &to, dim_t &myN, dim_t NB,
                                 dim_t N, int ithr) {
        from = NB * (ithr);
        to = NB * (ithr + 1);
        if (to > N) to = N;
        myN = to - from;
    };

    parallel(nthr_to_use, [&](int ithr, int nthr) {
        assert(nthr_to_use == nthr);
        MAYBE_UNUSED(nthr);

        int ithr_mn = ithr % nthr_mn;
        int ithr_m = ithr_mn % nthr_m;
        int ithr_n = ithr_mn / nthr_m;
        int ithr_k = ithr / nthr_mn;

        int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);

        bfloat16_t *ws = do_copy
                ? ws_buffers + ithr * ws_size_per_thr / sizeof(float)
                : nullptr;

        dim_t m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
              k_from = 0, k_to = 0, myK = 0;

        get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
        get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
        get_thr_block(k_from, k_to, myK, KB, K, ithr_k);

        if (myM > 0 && myN > 0) {
            float myBeta, *myC;
            dim_t ld;
            if (ithr_k == 0) {
                myC = &(C[m_from + n_from * ldc]);
                myBeta = beta;
                ld = ldc;
            } else {
                myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
                myBeta = 0.0f;
                ld = MB;
            }
            const bfloat16_t *myA = isTransA ? &(A[k_from + m_from * lda])
                                             : &(A[m_from + k_from * lda]);
            const bfloat16_t *myB = isTransB ? &(B[n_from + k_from * ldb])
                                             : &(B[k_from + n_from * ldb]);

            if (!isTransA) {
                if (!isTransB) {
                    gemm_ithr<false, false>(myM, myN, myK, alpha, myA, lda, myB,
                            ldb, myBeta, myC, ld, do_copy, ws);
                } else {
                    gemm_ithr<false, true>(myM, myN, myK, alpha, myA, lda, myB,
                            ldb, myBeta, myC, ld, do_copy, ws);
                }
            } else {
                if (!isTransB) {
                    gemm_ithr<true, false>(myM, myN, myK, alpha, myA, lda, myB,
                            ldb, myBeta, myC, ld, do_copy, ws);
                } else {
                    gemm_ithr<true, true>(myM, myN, myK, alpha, myA, lda, myB,
                            ldb, myBeta, myC, ld, do_copy, ws);
                }
            }
        }
    });

    if (nthr_k > 1) {
        parallel(nthr_to_use, [&](int ithr, int nthr) {
            assert(nthr_to_use == nthr);
            MAYBE_UNUSED(nthr);

            int ithr_mn = ithr % nthr_mn;
            int ithr_m = ithr_mn % nthr_m;
            int ithr_k = ithr / nthr_mn;
            int ithr_n = ithr_mn / nthr_m;

            dim_t n_from = 0, n_to = 0, myN = 0;
            dim_t m_from = 0, m_to = 0, myM = 0;

            int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);

            get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
            get_thr_block(m_from, m_to, myM, MB, M, ithr_m);

            // sum matrices partitioned along K dimension
            dim_t offset = 0, block = 0;
            partition_unit_diff(ithr_k, nthr_k, myN, &offset, &block);
            for (int ik = 1; ik < nthr_k; ++ik) {
                float *myC = c_buffers
                        + MB * ((dim_t)NB * (cbase + ik - 1) + offset);

                gemm_utils::sum_two_matrices(myM, block, myC, MB,
                        &C[m_from + (n_from + offset) * ldc], ldc);
            }
        });
    }

    free(ws_buffers);
    free(c_buffers);

    return dnnl_success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl
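
The new file blocks the problem over M, N, and K (gemm_ithr and block_ker), splits the blocks across threads, and, when K is partitioned across threads, reduces the per-thread partial results into C afterwards. Semantically it computes the usual column-major update C = alpha * op(A) * op(B) + beta * C with bf16 inputs and f32 accumulation. A minimal, unblocked sketch of the same math for reference (hypothetical helper names, not part of the commit):

#include <cstdint>
#include <cstring>

// Hypothetical stand-in for dnnl::impl::bfloat16_t: bf16 is the upper
// 16 bits of an IEEE-754 binary32 value.
static float bf16_to_f32(uint16_t x) {
    uint32_t bits = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Unblocked, single-threaded version of the same column-major update
// C = alpha * op(A) * op(B) + beta * C, accumulating in f32.
static void naive_gemm_bf16bf16f32(bool transA, bool transB, long M, long N,
        long K, float alpha, const uint16_t *A, long lda, const uint16_t *B,
        long ldb, float beta, float *C, long ldc) {
    for (long j = 0; j < N; j++)
        for (long i = 0; i < M; i++) {
            float acc = 0.f;
            for (long k = 0; k < K; k++) {
                float a = bf16_to_f32(transA ? A[k + i * lda] : A[i + k * lda]);
                float b = bf16_to_f32(transB ? B[j + k * ldb] : B[k + j * ldb]);
                acc += a * b;
            }
            C[i + j * ldc] = alpha * acc
                    + (beta == 0.f ? 0.f : beta * C[i + j * ldc]);
        }
}

The bit shift in bf16_to_f32 works because bfloat16 keeps exactly the sign, exponent, and top 7 mantissa bits of a binary32 value, which is also why the blocked kernel can multiply bf16 elements directly into an f32 accumulator.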

src/cpu/gemm/bf16/ref_gemm_bf16.hpp

+37

@@ -0,0 +1,37 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_GEMM_BF16_REF_GEMM_BF16_HPP
#define CPU_GEMM_BF16_REF_GEMM_BF16_HPP

#include "oneapi/dnnl/dnnl_types.h"

#include "common/c_types_map.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

dnnl_status_t ref_gemm_bf16bf16f32(const char *transa, const char *transb,
        const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha,
        const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B,
        const dim_t *ldb, const float *beta, float *C, const dim_t *ldc);

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif // CPU_GEMM_BF16_REF_GEMM_BF16_HPP
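
The declaration mirrors the BLAS-style interface of the existing f32 reference GEMM: every scalar is passed by pointer, matrices are column-major, A and B hold bf16 data, and C holds f32. A hedged example of a call site (the sizes and the A, B, C buffers are illustrative, not taken from the commit):

// Assumes dnnl::impl and dnnl::impl::cpu are in scope, and that A (M x K,
// bfloat16_t), B (K x N, bfloat16_t), and C (M x N, float) are valid
// column-major buffers.
const char transa = 'N', transb = 'N';
const dim_t M = 128, N = 64, K = 256;
const dim_t lda = M, ldb = K, ldc = M;
const float alpha = 1.f, beta = 0.f;
dnnl_status_t st = ref_gemm_bf16bf16f32(&transa, &transb, &M, &N, &K,
        &alpha, A, &lda, B, &ldb, &beta, C, &ldc);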

src/cpu/gemm/f32/gemm_utils_f32.hpp

+10 -1

@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2020 Intel Corporation
+* Copyright 2018-2023 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -45,6 +45,15 @@ struct gemm_traits<float, isTransA, isTransB> {
     static constexpr dim_t BK = isTransB ? 96 : 256;
 };
 
+template <bool isTransA, bool isTransB>
+struct gemm_traits<bfloat16_t, isTransA, isTransB> {
+    static constexpr dim_t m = 32;
+    static constexpr dim_t n = 6;
+    static constexpr dim_t BM = 4032;
+    static constexpr dim_t BN = isTransA ? 96 : 48;
+    static constexpr dim_t BK = isTransB ? 96 : 256;
+};
+
 template <typename T>
 using unroll_factor = gemm_traits<T, false, false>;
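
The bfloat16_t specialization above feeds the existing unroll_factor alias, so the new bf16 kernels pick up a 32x6 register tile and their own cache-blocking sizes at compile time. An illustrative check, not part of the diff:

// Illustrative only: how the new traits resolve through unroll_factor.
static_assert(unroll_factor<bfloat16_t>::m == 32, "bf16 m-unroll");
static_assert(unroll_factor<bfloat16_t>::n == 6, "bf16 n-unroll");
static_assert(gemm_traits<bfloat16_t, false, false>::BM == 4032, "bf16 BM");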
