/*******************************************************************************
* Copyright 2025 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
+
17
+ #ifndef CPU_AARCH64_JIT_INT8_KERNEL_TYPES_HPP
18
+ #define CPU_AARCH64_JIT_INT8_KERNEL_TYPES_HPP
19
+
20
+ namespace dnnl {
21
+ namespace impl {
22
+ namespace cpu {
23
+ namespace aarch64 {
24
+ namespace matmul {
25
+
// Broadcast kind selector used by the int8 matmul kernel types in this file.
// `none` is the default value of the zp_type_a/b/c members of brg_int8_t.
// NOTE(review): the per_* enumerators presumably name the dimension along
// which a quantization parameter varies (whole tensor, M, N or K) -- this is
// inferred from the names only; confirm against the kernel implementation.
typedef enum {
    none = 0,
    per_tensor = 1,
    per_m = 2,
    per_n = 3,
    per_k = 4,
} jit_int8_broadcast_t;
33
+
34
+ struct dyn_vals_t {
35
+ int f = 0 ;
36
+ dim_t M = 0 ;
37
+ dim_t K = 0 ;
38
+ dim_t N = 0 ;
39
+ dim_t B = 0 ;
40
+ int is_s8 = 0 , is_u8 = 0 ;
41
+ int mtail, ktail, ntail, m_blk, k_blk, n_blk;
42
+ int get_min_max = 0 , reorder_a = 0 , reorder_b = 0 , cal_src = 0 ;
43
+ int is_mtail = 0 , is_ktail = 0 ;
44
+ };
45
+
// Bundle of raw pointers with no default values and no ownership logic.
// Only `dyn_src` and `src` are pointers-to-const; everything else is
// writable through the pointer.
// NOTE(review): presumably the argument block handed to a generated kernel;
// if it is read by field offset, member order and types must stay stable --
// confirm against the caller.
struct dyn_params_t {
    const float *dyn_src;
    const int8_t *src;
    int8_t *dst;
    float *max, *min;
    int *nk, *nm, *nn;
    int *tl, *mtl, *ntl;
};
54
+
55
+ struct brg_int8_t {
56
+ int M, K, N;
57
+ const int m_blk = 8 , n_blk = 4 , k_blk = 8 ;
58
+ const int ld_block = 6 , rd_block = 4 , bd_block = 8 ;
59
+ int na, nb;
60
+ int m_tail, n_tail, k_tail;
61
+ int is_m_tail, is_k_tail, is_n_tail, is_zp_cal;
62
+ int dst_dt_sz;
63
+ bool is_s8;
64
+ bool is_bias;
65
+ bool with_scales;
66
+ bool with_dst_scales;
67
+ bool is_oc_scales;
68
+ jit_int8_broadcast_t zp_type_a = jit_int8_broadcast_t ::none;
69
+ jit_int8_broadcast_t zp_type_b = jit_int8_broadcast_t ::none;
70
+ jit_int8_broadcast_t zp_type_c = jit_int8_broadcast_t ::none;
71
+ bool is_zp_b_int8 = false ;
72
+ bool b_reo = true ;
73
+ data_type_t zp_b_dt;
74
+ dim_t B;
75
+ };
76
+
77
+ struct call_params_t {
78
+ const uint8_t *src, *wei;
79
+ float *dst;
80
+ const float *bias, *scales, *dst_scales;
81
+ dim_t M, K, N;
82
+ char *buf_B_ptr_;
83
+ int *na, *nb;
84
+ int32_t *src_zero_point, *wei_zero_point, *dst_zero_point;
85
+ const int8_t *wei_zero_point_buf;
86
+ float *zp_a_ptr, *zp_b_ptr;
87
+ };
88
+
89
+ } // namespace matmul
90
+ } // namespace aarch64
91
+ } // namespace cpu
92
+ } // namespace impl
93
+ } // namespace dnnl
94
+ #endif
0 commit comments