Skip to content

Commit 3a05ca5

Browse files
Ryo-not-riovpirogov
authored andcommitted
src: cpu: conv: Use acl_indirect_gemm for bf16 convolutions
performance improvements: Total benchdnn tests: 57 Min: 15x Average: 131x Max: 320x
1 parent 094cc1d commit 3a05ca5

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

src/cpu/aarch64/acl_convolution_utils.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*******************************************************************************/
1616

1717
#include "cpu/aarch64/acl_convolution_utils.hpp"
18+
#include "common/convolution_pd.hpp"
1819
#include "common/utils.hpp"
1920
#include "oneapi/dnnl/dnnl.h"
2021

@@ -62,9 +63,10 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md,
6263
everyone_is(data_type::f32, src_d.data_type(),
6364
wei_d.data_type(), dst_d.data_type()),
6465
everyone_is(data_type::f16, src_d.data_type(),
66+
wei_d.data_type(), dst_d.data_type()),
67+
everyone_is(data_type::bf16, src_d.data_type(),
6568
wei_d.data_type(), dst_d.data_type())),
66-
" src, dst and wei must be fp16 or fp32");
67-
69+
" src, dst and wei must be fp16, bf16 or fp32");
6870
// batch size
6971
const int mb = src_d.dims()[0];
7072

src/cpu/aarch64/acl_indirect_gemm_convolution.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2023 Arm Ltd. and affiliates
2+
* Copyright 2021-2024 Arm Ltd. and affiliates
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -84,12 +84,15 @@ struct acl_indirect_gemm_convolution_fwd_t : public primitive_t {
8484

8585
const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef)
8686
&& attr()->has_default_values(smask_t::post_ops, f16);
87+
const bool is_bf16_ok
88+
= expect_data_types(bf16, bf16, bf16, bf16, undef)
89+
&& attr_.post_ops_.len() == 0;
8790
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
8891
&& attr()->has_default_values(
8992
smask_t::post_ops | smask_t::fpmath_mode, f32);
9093
bool ok = is_fwd()
9194
&& set_default_alg_kind(alg_kind::convolution_direct)
92-
&& utils::one_of(true, is_fp16_ok, is_fp32_ok)
95+
&& utils::one_of(true, is_fp16_ok, is_bf16_ok, is_fp32_ok)
9396
&& !has_zero_dim_memory();
9497
if (!ok) return status::unimplemented;
9598

src/cpu/cpu_convolution_list.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2019-2024 Intel Corporation
3-
* Copyright 2020-2023 Arm Ltd. and affiliates
3+
* Copyright 2020-2024 Arm Ltd. and affiliates
44
* Copyright 2020-2024 FUJITSU LIMITED
55
*
66
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -179,6 +179,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
179179
CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<bf16>)
180180
CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
181181
CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
182+
CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
182183
CPU_INSTANCE(ref_convolution_fwd_t)
183184
CPU_INSTANCE(ref_fused_convolution_fwd_t)
184185
nullptr,

0 commit comments

Comments
 (0)