Commit b8ba903
[GSOC][CPU][ARM] ACL scaled attention (#25183)
### Details:
- This PR aims to add ACL implementation for scaled attention
1 parent f579633 commit b8ba903
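
At a glance, the commit adds a thin GemmKernel wrapper around arm_compute::NEGEMM (two new files under src/nodes/kernels/acl/), a KT_ACL specialization of MHAKernel that evaluates scaled dot-product attention block by block, and dispatch logic in createPrimitive() that prefers the ACL executor for f32 whenever OV_CPU_WITH_ACL is defined.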

File tree

5 files changed: +309 -3 lines changed

- src/plugins/intel_cpu/CMakeLists.txt (+2 -1)
- src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp (new, +106)
- src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp (new, +52)
- src/plugins/intel_cpu/src/nodes/scaled_attn.cpp (+148 -1)
- src/plugins/intel_cpu/src/nodes/scaled_attn.h (+1 -1)

src/plugins/intel_cpu/CMakeLists.txt (+2 -1)

@@ -128,7 +128,8 @@ file(GLOB_RECURSE HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/src/*.h
                     ${CMAKE_CURRENT_SOURCE_DIR}/src/*.hpp)
 
 if(NOT OV_CPU_WITH_ACL)
-    list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/acl/*)
+    list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/acl/*
+                              ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/acl/*)
 endif()
 
 if(NOT X86_64)
src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp (new file, +106)

// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gemm_kernel.hpp"
#define THROW_ERROR(...) OPENVINO_THROW("ACL gemm executor Init Failure '", __VA_ARGS__)

namespace ov {
namespace intel_cpu {
GemmKernel::GemmKernel(size_t M,
                       size_t N,
                       size_t K,
                       bool b_transposed,
                       ov::element::Type inType)
    : M(M),
      N(N),
      K(K),
      b_transposed(b_transposed) {
    if (!one_of(inType, ov::element::f32, ov::element::f16, ov::element::bf16))
        THROW_ERROR("ACL gemm kernel only supports bf16, f16 and f32");

    if (inType == ov::element::f32)
        format = arm_compute::Format::F32;
    else if (inType == ov::element::f16)
        format = arm_compute::Format::F16;
    else if (inType == ov::element::bf16)
        format = arm_compute::Format::BFLOAT16;

    aclGemmKernel = std::make_unique<arm_compute::NEGEMM>();
}

arm_compute::Status GemmKernel::executeGemm(void* a,
                                            void* b,
                                            arm_compute::TensorInfo& dstInfo,
                                            arm_compute::Tensor& dstTensor,
                                            arm_compute::Strides aStrides,
                                            arm_compute::Strides bStrides,
                                            void* c,
                                            float alpha,
                                            float beta,
                                            arm_compute::Strides* outStrides,
                                            void* out) {
    aInfo.init(
        shapeCast({M, N}),
        format,
        aStrides,
        size_t(0),
        (size_t)(M * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));

    arm_compute::TensorShape bShape;
    if (b_transposed)
        bShape = shapeCast({K, N});
    else
        bShape = shapeCast({N, K});

    bInfo.init(
        bShape,
        format,
        bStrides,
        size_t(0),
        (size_t)(K * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));

    aTensor.allocator()->init(aInfo);
    bTensor.allocator()->init(bInfo);

    if (c != nullptr) {
        cInfo.init(shapeCast({M, K}), format);
        cTensor.allocator()->init(cInfo);
    }

    if (outStrides != nullptr)
        dstInfo.init(
            shapeCast({M, K}),
            format,
            *outStrides,
            size_t(0),
            (size_t)(M * K * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));
    else
        dstInfo.init(shapeCast({M, K}), format);

    dstTensor.allocator()->init(dstInfo);

    aTensor.allocator()->import_memory(reinterpret_cast<void*>(a));
    bTensor.allocator()->import_memory(reinterpret_cast<void*>(b));
    cTensor.allocator()->import_memory(reinterpret_cast<void*>(c));

    if (out == nullptr)
        dstTensor.allocator()->allocate();
    else
        dstTensor.allocator()->import_memory(out);

    if (b_transposed)
        aclGemmInfo.set_pretranspose_B(true);

    auto status = aclGemmKernel->validate(&aInfo, &bInfo, &cInfo, &dstInfo, 1.0, 0.0, aclGemmInfo);

    if (c == nullptr)
        aclGemmKernel->configure(&aTensor, &bTensor, nullptr, &dstTensor, alpha, beta, aclGemmInfo);
    else
        aclGemmKernel->configure(&aTensor, &bTensor, &cTensor, &dstTensor, alpha, beta, aclGemmInfo);
    aclGemmKernel->run();

    return status;
}
} // namespace intel_cpu
} // namespace ov
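
The wrapper drives ACL's validate → configure → run sequence on every call and binds the caller's buffers via import_memory() instead of copying; ACL allocates memory only for the destination tensor, and only when no output pointer is supplied. A usage sketch follows the header below.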
src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp (new file, +52)

// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstddef>
#include <openvino/core/type/element_type.hpp>
#include "nodes/executors/acl/acl_utils.hpp"
#include "utils/general_utils.h"

#include "arm_compute/runtime/NEON/NEFunctions.h"
#include "arm_compute/core/Types.h"

namespace ov {
namespace intel_cpu {
class GemmKernel {
public:
    GemmKernel(size_t M,
               size_t N,
               size_t K,
               bool b_transposed = false,
               ov::element::Type inType = ov::element::f32);

    arm_compute::Status executeGemm(void* a,
                                    void* b,
                                    arm_compute::TensorInfo& dstInfo,
                                    arm_compute::Tensor& dstTensor,
                                    arm_compute::Strides aStrides,
                                    arm_compute::Strides bStrides,
                                    void* c = nullptr,
                                    float alpha = 1.0f,
                                    float beta = 0.0f,
                                    arm_compute::Strides* outStrides = nullptr,
                                    void* out = nullptr);

private:
    size_t M = 0;
    size_t N = 0, K = 0;
    bool b_transposed = false;
    arm_compute::Format format;
    arm_compute::TensorInfo aInfo;
    arm_compute::TensorInfo bInfo;
    arm_compute::TensorInfo cInfo;
    arm_compute::Tensor aTensor;
    arm_compute::Tensor bTensor;
    arm_compute::Tensor cTensor;
    arm_compute::Tensor dTensor;
    std::unique_ptr<arm_compute::NEGEMM> aclGemmKernel;
    arm_compute::GEMMInfo aclGemmInfo;
};

} // namespace intel_cpu
} // namespace ov
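
Since the header gives defaults for the trailing parameters, the wrapper can be driven with just the two inputs and their strides. A minimal sketch, assuming dense row-major f32 buffers; the function name and sizes are hypothetical, not part of the commit:

// Illustrative only: C[M x K] = A[M x N] * B[N x K] in the wrapper's convention.
#include "gemm_kernel.hpp"

#include <vector>

void gemm_example() {
    const size_t M = 4, N = 8, K = 16;
    std::vector<float> a(M * N, 1.0f);
    std::vector<float> b(N * K, 1.0f);

    arm_compute::TensorInfo dstInfo;
    arm_compute::Tensor dstTensor;
    // Strides are given in bytes, innermost dimension first (as in scaled_attn.cpp).
    arm_compute::Strides aStrides({sizeof(float), N * sizeof(float)});
    arm_compute::Strides bStrides({sizeof(float), K * sizeof(float)});

    ov::intel_cpu::GemmKernel gemm(M, N, K);  // b_transposed = false, f32 by default
    arm_compute::Status status = gemm.executeGemm(a.data(), b.data(), dstInfo, dstTensor, aStrides, bStrides);

    // With no output pointer supplied, ACL allocates the M x K result itself.
    float* c = reinterpret_cast<float*>(dstTensor.buffer());
    (void)status;
    (void)c;
}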

src/plugins/intel_cpu/src/nodes/scaled_attn.cpp (+148 -1)

@@ -23,6 +23,10 @@
 # include "mlas/sgemm.hpp"
 #endif
 
+#ifdef OV_CPU_WITH_ACL
+# include "kernels/acl/gemm_kernel.hpp"
+#endif
+
 #include "utils/plain_tensor.hpp"
 #include "kernels/scaled_attn/softmax.hpp"
 #include "kernels/scaled_attn/mha_single_token.hpp"

@@ -505,6 +509,147 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
     }
 };
 
+#ifdef OV_CPU_WITH_ACL
+template <>
+struct MHAKernel<ScaledDotProductAttention::KT_ACL, float> {
+    const GraphContext::CPtr context;
+    size_t m_block_size;
+
+    MHAKernel() = delete;
+    explicit MHAKernel(GraphContext::CPtr ctx) : context(ctx) {
+        m_block_size = 512;
+        select_nfltmax_at_0 = false;
+    }
+
+    PlainTensor causal_mask;
+    bool select_nfltmax_at_0;  // set attn_score to -FLT_MAX when causal_mask[...] equals this
+    void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) {
+        causal_mask = mask;
+        select_nfltmax_at_0 = _select_nfltmax_at_0;
+    }
+
+    // Q, K, V are ready, do attention
+    // query          [B, H, q_len, S]
+    // present_key    [B, H, kv_len, S]  stride of last dim may be > 1
+    // present_value  [B, H, kv_len, S]
+    // attention_mask [B, 1, q_len, kv_len]
+    // alibi
+    // output_emb     [B, L1, H*S]
+    void operator()(dnnl::stream strm,
+                    PlainTensor& query,
+                    PlainTensor& present_key,
+                    PlainTensor& present_value,
+                    const PlainTensor& alibi_mask,
+                    const PlainTensor& attention_mask,
+                    PlainTensor& output_emb,
+                    bool has_out_transpose,
+                    bool auto_causal,
+                    float d_scale = 0.0f) {
+        auto B = query.size(0);
+        auto H = query.size(1);
+        auto q_len = query.size(2);
+        auto head_size = query.size(3);
+        auto kv_len = present_key.size(2);
+        auto h_group_num = present_key.size(1);
+        size_t h_each_group_len = H / h_group_num;
+
+        if (d_scale == 0.0f)
+            d_scale = 1.0f / sqrt(head_size);
+        auto k_stride_s = present_key.stride(3);
+
+        auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
+
+        parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
+            auto m_start = m_blk * m_block_size;
+            auto m_end = std::min(m_start + m_block_size, q_len);
+            auto m_cnt = m_end - m_start;
+
+            float* q_ptr = &query.at<float>({b, h, m_start, 0});
+            float* k_ptr = &present_key.at<float>({b, h / h_each_group_len, 0, 0});
+            float* v_ptr = &present_value.at<float>({b, h / h_each_group_len, 0, 0});
+
+            float* alibi_ptr = nullptr;
+            auto alibi_stride = 0;
+            if (alibi_mask) {
+                alibi_ptr = &alibi_mask.at<float>({b, h, 0, 0}, true);
+                if (alibi_mask.size(2) > 1)
+                    alibi_stride = alibi_mask.stride(2);
+            }
+            uint8_t* attn_mask_ptr = nullptr;
+            auto attn_mask_stride = 0;
+            if (attention_mask) {
+                attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<float>({b, h, 0, 0}, true));
+                if (attention_mask.size(2) > 1)
+                    attn_mask_stride = attention_mask.stride(2) * sizeof(float);
+            }
+            uint8_t* cmask_ptr = nullptr;
+            auto cmask_stride = 0;
+            if (causal_mask) {
+                cmask_ptr = &causal_mask.at<uint8_t>({b, h, 0, 0}, true);
+                if (causal_mask.size(2) > 1)
+                    cmask_stride = causal_mask.stride(2);
+            }
+
+            arm_compute::Tensor qkTensor;
+            arm_compute::TensorInfo qkInfo;
+
+            bool b_transpose = false;
+            if (k_stride_s == 1)
+                b_transpose = true;
+            GemmKernel qk_gemm(m_cnt, head_size, kv_len, b_transpose);
+
+            arm_compute::Strides qStrides({query.stride_bytes(3), query.stride_bytes(2)});
+            arm_compute::Strides kStrides({present_key.stride_bytes(3), present_key.stride_bytes(2)});
+            qk_gemm.executeGemm(reinterpret_cast<void*>(q_ptr),
+                                reinterpret_cast<void*>(k_ptr),
+                                qkInfo,
+                                qkTensor,
+                                qStrides,
+                                kStrides);
+
+            auto qk = reinterpret_cast<float*>(qkTensor.buffer());
+
+            for (size_t m = m_start; m < m_end; m++) {
+                // apply attention mask & softmax
+                auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
+                attn_softmax(qk + (m - m_start) * kv_len,
+                             qk + (m - m_start) * kv_len,
+                             d_scale,
+                             alibi_ptr + m * alibi_stride,
+                             attn_mask_ptr + m * attn_mask_stride,
+                             cmask_ptr + m * cmask_stride,
+                             select_nfltmax_at_0,
+                             ncausal,
+                             kv_len,
+                             ov::element::f32,
+                             ov::element::f32);
+            }
+            arm_compute::TensorInfo outInfo;
+            arm_compute::Tensor outTensor;
+
+            auto out = has_out_transpose ? &output_emb.at<float>({b, m_start, h * head_size})
+                                         : &output_emb.at<float>({b, h, m_start});
+            auto strides = arm_compute::Strides({output_emb.stride_bytes(1), output_emb.stride_bytes(2)});
+            GemmKernel out_gemm(m_cnt, kv_len, head_size);
+
+            arm_compute::Strides vStrides({present_value.stride_bytes(3), present_value.stride_bytes(2)});
+            out_gemm.executeGemm(qkTensor.buffer(),
+                                 reinterpret_cast<void*>(v_ptr),
+                                 outInfo,
+                                 outTensor,
+                                 qkInfo.strides_in_bytes(),
+                                 vStrides,
+                                 nullptr,
+                                 1.0,
+                                 0.0,
+                                 &strides,
+                                 reinterpret_cast<void*>(out));
+            qkTensor.allocator()->free();
+        });
+    }
+};
+#endif
+
 #ifdef OV_CPU_WITH_MLAS
 template <>
 struct MHAKernel<ScaledDotProductAttention::KT_MLAS, float> {

@@ -935,7 +1080,9 @@ void ScaledDotProductAttention::createPrimitive() {
         executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(context);
 #endif
     } else {
-#ifdef OV_CPU_WITH_MLAS
+#ifdef OV_CPU_WITH_ACL
+        executor = std::make_shared<AttentionExecutor<KT_ACL, float>>(context);
+#elif defined(OV_CPU_WITH_MLAS)
         executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(context);
 #elif defined(OPENVINO_ARCH_X86_64)
         if (with_cpu_x86_avx512_core()) {
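
Taken together, the KT_ACL kernel computes standard scaled dot-product attention per (batch, head, query block of up to 512 rows): the first NEGEMM forms QK^T, attn_softmax applies the scale and the alibi/attention/causal masks row by row in place, and the second NEGEMM multiplies the resulting probabilities with V. In formula form, with the default scale from the code above:

\[ \mathrm{output} = \mathrm{softmax}\big(d_{\mathrm{scale}} \cdot Q K^{\top} + \mathrm{mask}\big)\, V, \qquad d_{\mathrm{scale}} = \frac{1}{\sqrt{\mathrm{head\_size}}} \]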

src/plugins/intel_cpu/src/nodes/scaled_attn.h (+1 -1)

@@ -36,7 +36,7 @@ class ScaledDotProductAttention : public Node {
     void createPrimitive() override;
     static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
 
-    enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS};
+    enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS, KT_ACL};
 
     void assignState(const std::shared_ptr<VariableStateKVcache>& state, int idx);