1
1
/* ******************************************************************************
2
2
* Copyright 2022-2023 Intel Corporation
3
3
* Copyright 2023-2024 FUJITSU LIMITED
4
+ * Copyright 2024 Arm Ltd. and affiliates
4
5
*
5
6
* Licensed under the Apache License, Version 2.0 (the "License");
6
7
* you may not use this file except in compliance with the License.
@@ -47,15 +48,18 @@ impl::data_type_t get_accum_datatype(brgemm_t *brg) {
47
48
return brg->is_int8 ? data_type::s32 : data_type::f32;
48
49
}
49
50
50
- void init_kernel_datatype (
51
+ status_t init_kernel_datatype (
51
52
brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) {
52
- assert (dt_a != data_type::undef && dt_b != data_type::undef);
53
+ if (dt_a != data_type::undef && dt_b != data_type::undef)
54
+ return status::unimplemented;
53
55
brg->is_int8 = utils::one_of (dt_a, data_type::u8, data_type::s8)
54
56
&& utils::one_of (dt_b, data_type::u8, data_type::s8);
55
57
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
56
58
brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
57
59
brg->is_f16 = utils::one_of (data_type::f16, dt_a, dt_b);
58
- assert (brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16 );
60
+ if (brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16 )
61
+ return status::unimplemented;
62
+ return status::success;
59
63
}
60
64
61
65
void init_common_conf (brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
@@ -88,27 +92,22 @@ void maybe_try_bf32(brgemm_t *brg) {
88
92
//
89
93
}
90
94
91
- void set_isa_impl (brgemm_t *brg) {
95
+ status_t set_isa_impl (brgemm_t *brg) {
92
96
auto is_isa_ok = [&](cpu_isa_t isa) {
93
97
return mayiuse (isa) &&
94
98
// maybe IMPLICATION(brg->isa_user != isa_undef,
95
99
// is_superset(brg->isa_user, isa)), but the API is not clear.
96
100
one_of (brg->isa_user , isa_undef, isa);
97
101
};
98
102
99
- if (brg->is_bf32 ) {
100
- assert (!" unsupported case" );
101
- } else if (brg->is_f32 ) {
102
- brg->isa_impl = utils::map (true , isa_undef, is_isa_ok (sve_512), sve_512,
103
- is_isa_ok (sve_256), sve_256);
104
- } else if (brg->is_bf16 ) {
105
- assert (!" unsupported case" );
106
- } else if (brg->is_f16 ) {
107
- assert (!" unsupported case" );
108
- } else if (brg->is_int8 ) {
103
+ if (brg->is_bf32 || brg->is_bf16 || brg->is_f16 ) {
104
+ return status::unimplemented;
105
+ } else if (brg->is_f32 || brg->is_int8 ) {
109
106
brg->isa_impl = utils::map (true , isa_undef, is_isa_ok (sve_512), sve_512,
110
107
is_isa_ok (sve_256), sve_256);
108
+ return status::success;
111
109
}
110
+ return status::success;
112
111
}
113
112
114
113
void set_brg_vmm (brgemm_t *brg) {
@@ -187,7 +186,7 @@ inline size_t data_type_vnni_granularity(data_type_t data_type) {
187
186
}
188
187
status_t brgemm_blocking (brgemm_t *brg) {
189
188
190
- set_isa_impl (brg);
189
+ CHECK ( set_isa_impl (brg) );
191
190
if (brg->isa_impl == isa_undef) return status::unimplemented;
192
191
assert (!brg->is_dgmm ); // should not be called from brdgmm
193
192
set_brg_vmm (brg);
@@ -296,18 +295,19 @@ status_t brdgmm_blocking(brgemm_t *brg) {
296
295
return status::success;
297
296
}
298
297
299
- void init_brgemm_conf (brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
300
- impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
301
- float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
302
- dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) {
298
+ status_t init_brgemm_conf (brgemm_t *brg, cpu_isa_t isa,
299
+ brgemm_batch_kind_t type, impl::data_type_t dt_a,
300
+ impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
301
+ dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K,
302
+ const brgemm_strides_t *strides, bool is_bf32) {
303
303
304
304
init_common_conf (brg, type, alpha, beta, strides);
305
305
306
306
brg->layout = layout;
307
307
308
308
brg->dt_a = brg->is_row_major () ? dt_a : dt_b;
309
309
brg->dt_b = brg->is_row_major () ? dt_b : dt_a;
310
- init_kernel_datatype (brg, brg->dt_a , brg->dt_b );
310
+ CHECK ( init_kernel_datatype (brg, brg->dt_a , brg->dt_b ) );
311
311
312
312
brg->dt_c = get_accum_datatype (brg);
313
313
brg->dt_d = brg->dt_c ;
@@ -319,7 +319,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
319
319
brg->typesize_D = types::data_type_size (brg->dt_d );
320
320
321
321
brg->isa_user = isa;
322
- set_isa_impl (brg);
322
+ CHECK ( set_isa_impl (brg) );
323
323
brg->is_bf32 = false ;
324
324
325
325
brg->has_int8_vnni = true ;
@@ -352,11 +352,13 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
352
352
brg->rd_step = has_no_vnni_compute_instruction
353
353
? 1
354
354
: data_type_vnni_granularity (brg->dt_b );
355
+ return status::success;
355
356
}
356
357
357
- void init_brdgmm_conf (brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
358
- impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
359
- float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
358
+ status_t init_brdgmm_conf (brgemm_t *brg, cpu_isa_t isa,
359
+ brgemm_batch_kind_t type, impl::data_type_t dt_a,
360
+ impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
361
+ dim_t LDA, dim_t LDC, dim_t M, dim_t N,
360
362
const brgemm_strides_t *strides) {
361
363
362
364
init_common_conf (brg, type, alpha, beta, strides);
@@ -365,7 +367,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
365
367
366
368
brg->dt_a = dt_a;
367
369
brg->dt_b = dt_b;
368
- init_kernel_datatype (brg, brg->dt_a , brg->dt_b );
370
+ CHECK ( init_kernel_datatype (brg, brg->dt_a , brg->dt_b ) );
369
371
370
372
brg->dt_c = get_accum_datatype (brg);
371
373
brg->dt_d = brg->dt_c ;
@@ -394,6 +396,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
394
396
395
397
brg->bcast_dim = M;
396
398
brg->load_dim = N;
399
+ return status::success;
397
400
}
398
401
399
402
} // namespace brgemm_utils
@@ -402,4 +405,4 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
402
405
} // namespace impl
403
406
} // namespace dnnl
404
407
405
- // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
408
+ // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
0 commit comments