diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp index c62ed1ce909..de18f6d3488 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp @@ -15,6 +15,7 @@ *******************************************************************************/ #include "cpu/aarch64/acl_convolution_utils.hpp" +#include "common/convolution_pd.hpp" #include "common/utils.hpp" #include "oneapi/dnnl/dnnl.h" @@ -62,9 +63,10 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, everyone_is(data_type::f32, src_d.data_type(), wei_d.data_type(), dst_d.data_type()), everyone_is(data_type::f16, src_d.data_type(), + wei_d.data_type(), dst_d.data_type()), + everyone_is(data_type::bf16, src_d.data_type(), wei_d.data_type(), dst_d.data_type())), - " src, dst and wei must be fp16 or fp32"); - + " src, dst and wei must be fp16, bf16 or fp32"); // batch size const int mb = src_d.dims()[0]; diff --git a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp index 762878ad7d1..a48d48ea30b 100644 --- a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp +++ b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2023 Arm Ltd. and affiliates +* Copyright 2021-2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -84,12 +84,15 @@ struct acl_indirect_gemm_convolution_fwd_t : public primitive_t { const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) && attr()->has_default_values(smask_t::post_ops, f16); + const bool is_bf16_ok + = expect_data_types(bf16, bf16, bf16, bf16, undef) + && attr_.post_ops_.len() == 0; const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) && attr()->has_default_values( smask_t::post_ops | smask_t::fpmath_mode, f32); bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) - && utils::one_of(true, is_fp16_ok, is_fp32_ok) + && utils::one_of(true, is_fp16_ok, is_bf16_ok, is_fp32_ok) && !has_zero_dim_memory(); if (!ok) return status::unimplemented; diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp index dcdddedfac8..67a0093cde9 100644 --- a/src/cpu/cpu_convolution_list.cpp +++ b/src/cpu/cpu_convolution_list.cpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright 2019-2024 Intel Corporation -* Copyright 2020-2023 Arm Ltd. and affiliates +* Copyright 2020-2024 Arm Ltd. and affiliates * Copyright 2020-2024 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -179,6 +179,7 @@ const std::map> &impl_list_map() CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t) CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) CPU_INSTANCE(ref_convolution_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) nullptr,