Skip to content

Commit d54c72a

Browse files
committed
aarch64: matmul: addition of JIT int8 kernel
1 parent 7a74129 commit d54c72a

12 files changed

+2090
-7
lines changed

include/oneapi/dnnl/dnnl.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2016-2025 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -1615,6 +1615,9 @@ struct memory : public handle<dnnl_memory_t> {
16151615
BA16a32b4a = dnnl_BA16a32b4a,
16161616
BA16a48b4a = dnnl_BA16a48b4a,
16171617
BA16a64b4a = dnnl_BA16a64b4a,
1618+
BA24b8a = dnnl_BA24b8a,
1619+
aCB24c8b = dnnl_aCB24c8b,
1620+
abDC24d8c = dnnl_abDC24d8c,
16181621
decbA16a = dnnl_decbA16a,
16191622
decbA8a = dnnl_decbA8a,
16201623
defcbA16a = dnnl_defcbA16a,

include/oneapi/dnnl/dnnl_types.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2016-2025 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -1047,6 +1047,9 @@ typedef enum {
10471047
dnnl_aCBdef8b8c,
10481048
dnnl_abdEC16e4c,
10491049
dnnl_abDC16d4c,
1050+
dnnl_BA24b8a,
1051+
dnnl_aCB24c8b,
1052+
dnnl_abDC24d8c,
10501053

10511054
/// Just a sentinel, not real memory format tag. Must be changed after new
10521055
/// format tag is added.

src/common/c_types_map.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2016-2025 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -372,6 +372,9 @@ const format_tag_t aCB16b16c = dnnl_aCB16b16c;
372372
const format_tag_t aCB16b32c = dnnl_aCB16b32c;
373373
const format_tag_t aCB16b48c = dnnl_aCB16b48c;
374374
const format_tag_t aCB16b64c = dnnl_aCB16b64c;
375+
const format_tag_t BA24b8a = dnnl_BA24b8a;
376+
const format_tag_t aCB24c8b = dnnl_aCB24c8b;
377+
const format_tag_t abDC24d8c = dnnl_abDC24d8c;
375378
const format_tag_t aCB16b16c2b = dnnl_aCB16b16c2b;
376379
const format_tag_t aCB16b32c2b = dnnl_aCB16b32c2b;
377380
const format_tag_t aCB16b48c2b = dnnl_aCB16b48c2b;

src/common/dnnl_debug_autogenerated.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2018-2025 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -953,6 +953,9 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
953953
if (v == dnnl_aCBdef8b8c) return "aCBdef8b8c";
954954
if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
955955
if (v == dnnl_abDC16d4c) return "abDC16d4c";
956+
if (v == dnnl_BA24b8a) return "BA24b8a";
957+
if (v == dnnl_aCB24c8b) return "aCB24c8b";
958+
if (v == dnnl_abDC24d8c) return "abDC24d8c";
956959
if (v == dnnl_format_tag_last) return "format_tag_last";
957960
if (v == dnnl_x) return "x";
958961
if (v == dnnl_nc) return "nc";

src/common/memory_desc_wrapper.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2016-2025 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -202,6 +202,9 @@ status_t memory_desc_wrapper::compute_blocking(
202202
C(BA16a32b, {1, 0}, {16, 32}, {0, 1});
203203
C(BA16a48b, {1, 0}, {16, 48}, {0, 1});
204204
C(BA16a64b, {1, 0}, {16, 64}, {0, 1});
205+
C(BA24b8a, {1, 0}, {24, 8}, {1, 0});
206+
C(aCB24c8b, {0, 2, 1}, {24, 8}, {2, 1});
207+
C(abDC24d8c, {0, 1, 3, 2}, {24, 8}, {3, 2});
205208
C(BA16a16b2a, {1, 0}, {16, 16, 2}, {0, 1, 0});
206209
C(BA16a32b2a, {1, 0}, {16, 32, 2}, {0, 1, 0});
207210
C(BA16a48b2a, {1, 0}, {16, 48, 2}, {0, 1, 0});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*******************************************************************************
2+
* Copyright 2025 FUJITSU LIMITED
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#ifndef CPU_AARCH64_JIT_INT8_KERNEL_TYPES_HPP
18+
#define CPU_AARCH64_JIT_INT8_KERNEL_TYPES_HPP
19+
20+
namespace dnnl {
21+
namespace impl {
22+
namespace cpu {
23+
namespace aarch64 {
24+
namespace matmul {
25+
26+
typedef enum {
27+
none = 0,
28+
per_tensor = 1,
29+
per_m = 2,
30+
per_n = 3,
31+
per_k = 4,
32+
} jit_int8_broadcast_t;
33+
34+
struct dyn_vals_t {
35+
int f = 0;
36+
dim_t M = 0;
37+
dim_t K = 0;
38+
dim_t N = 0;
39+
dim_t B = 0;
40+
int is_s8 = 0, is_u8 = 0;
41+
int mtail, ktail, ntail, m_blk, k_blk, n_blk;
42+
int get_min_max = 0, reorder_a = 0, reorder_b = 0, cal_src = 0;
43+
int is_mtail = 0, is_ktail = 0;
44+
};
45+
46+
struct dyn_params_t {
47+
const float *dyn_src;
48+
const int8_t *src;
49+
int8_t *dst;
50+
float *max, *min;
51+
int *nk, *nm, *nn;
52+
int *tl, *mtl, *ntl;
53+
};
54+
55+
struct brg_int8_t {
56+
int M, K, N;
57+
const int m_blk = 8, n_blk = 4, k_blk = 8;
58+
const int ld_block = 6, rd_block = 4, bd_block = 8;
59+
int na, nb;
60+
int m_tail, n_tail, k_tail;
61+
int is_m_tail, is_k_tail, is_n_tail, is_zp_cal;
62+
int dst_dt_sz;
63+
bool is_s8;
64+
bool is_bias;
65+
bool with_scales;
66+
bool with_dst_scales;
67+
bool is_oc_scales;
68+
jit_int8_broadcast_t zp_type_a = jit_int8_broadcast_t::none;
69+
jit_int8_broadcast_t zp_type_b = jit_int8_broadcast_t::none;
70+
jit_int8_broadcast_t zp_type_c = jit_int8_broadcast_t::none;
71+
bool is_zp_b_int8 = false;
72+
bool b_reo = true;
73+
data_type_t zp_b_dt;
74+
dim_t B;
75+
};
76+
77+
struct call_params_t {
78+
const uint8_t *src, *wei;
79+
float *dst;
80+
const float *bias, *scales, *dst_scales;
81+
dim_t M, K, N;
82+
char *buf_B_ptr_;
83+
int *na, *nb;
84+
int32_t *src_zero_point, *wei_zero_point, *dst_zero_point;
85+
const int8_t *wei_zero_point_buf;
86+
float *zp_a_ptr, *zp_b_ptr;
87+
};
88+
89+
} // namespace matmul
90+
} // namespace aarch64
91+
} // namespace cpu
92+
} // namespace impl
93+
} // namespace dnnl
94+
#endif

0 commit comments

Comments
 (0)