|
23 | 23 | # include "mlas/sgemm.hpp"
|
24 | 24 | #endif
|
25 | 25 |
|
| 26 | +#ifdef OV_CPU_WITH_ACL |
| 27 | +# include "kernels/acl/gemm_kernel.hpp" |
| 28 | +#endif |
| 29 | + |
26 | 30 | #include "utils/plain_tensor.hpp"
|
27 | 31 | #include "kernels/scaled_attn/softmax.hpp"
|
28 | 32 | #include "kernels/scaled_attn/mha_single_token.hpp"
|
@@ -505,6 +509,147 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
|
505 | 509 | }
|
506 | 510 | };
|
507 | 511 |
|
| 512 | +#ifdef OV_CPU_WITH_ACL |
| 513 | +template <> |
| 514 | +struct MHAKernel<ScaledDotProductAttention::KT_ACL, float> { |
| 515 | + const GraphContext::CPtr context; |
| 516 | + size_t m_block_size; |
| 517 | + |
| 518 | + MHAKernel() = delete; |
| 519 | + explicit MHAKernel(GraphContext::CPtr ctx): context(ctx) { |
| 520 | + m_block_size = 512; |
| 521 | + select_nfltmax_at_0 = false; |
| 522 | + } |
| 523 | + |
| 524 | + PlainTensor causal_mask; |
| 525 | + bool select_nfltmax_at_0; // set attn_score to -FLT_MAX when causal_mask[...] equals this |
| 526 | + void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) { |
| 527 | + causal_mask = mask; |
| 528 | + select_nfltmax_at_0 = _select_nfltmax_at_0; |
| 529 | + } |
| 530 | + |
| 531 | + // Q, K, V are ready, do attention |
| 532 | + // query [B, H, q_len, S] |
| 533 | + // present_key [B, H, kv_len, S] stride of last dim may be > 1 |
| 534 | + // present_value [B, H, kv_len, S] |
| 535 | + // attention_mask [B, 1, q_len, kv_len] |
| 536 | + // alibi |
| 537 | + // output_emb [B, L1, H*S] |
| 538 | + void operator()(dnnl::stream strm, |
| 539 | + PlainTensor& query, |
| 540 | + PlainTensor& present_key, |
| 541 | + PlainTensor& present_value, |
| 542 | + const PlainTensor& alibi_mask, |
| 543 | + const PlainTensor& attention_mask, |
| 544 | + PlainTensor& output_emb, |
| 545 | + bool has_out_transpose, |
| 546 | + bool auto_causal, |
| 547 | + float d_scale = 0.0f) { |
| 548 | + auto B = query.size(0); |
| 549 | + auto H = query.size(1); |
| 550 | + auto q_len = query.size(2); |
| 551 | + auto head_size = query.size(3); |
| 552 | + auto kv_len = present_key.size(2); |
| 553 | + auto h_group_num = present_key.size(1); |
| 554 | + size_t h_each_group_len = H / h_group_num; |
| 555 | + |
| 556 | + if (d_scale == 0.0f) |
| 557 | + d_scale = 1.0f / sqrt(head_size); |
| 558 | + auto k_stride_s = present_key.stride(3); |
| 559 | + |
| 560 | + auto m_blocks = (q_len + m_block_size - 1) / m_block_size; |
| 561 | + |
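| | + // split the q_len query rows into blocks of m_block_size and process each (batch, head, block) in parallel |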
| 562 | + parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) { |
| 563 | + auto m_start = m_blk * m_block_size; |
| 564 | + auto m_end = std::min(m_start + m_block_size, q_len); |
| 565 | + auto m_cnt = m_end - m_start; |
| 566 | + |
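| | + // h / h_each_group_len maps query head h to its shared KV head (grouped-query attention) |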
| 567 | + float* q_ptr = &query.at<float>({b, h, m_start, 0}); |
| 568 | + float* k_ptr = &present_key.at<float>({b, h / h_each_group_len, 0, 0}); |
| 569 | + float* v_ptr = &present_value.at<float>({b, h / h_each_group_len, 0, 0}); |
| 570 | + |
| 571 | + float* alibi_ptr = nullptr; |
| 572 | + auto alibi_stride = 0; |
| 573 | + if (alibi_mask) { |
| 574 | + alibi_ptr = &alibi_mask.at<float>({b, h, 0, 0}, true); |
| 575 | + if (alibi_mask.size(2) > 1) |
| 576 | + alibi_stride = alibi_mask.stride(2); |
| 577 | + } |
| 578 | + uint8_t* attn_mask_ptr = nullptr; |
| 579 | + auto attn_mask_stride = 0; |
| 580 | + if (attention_mask) { |
| 581 | + attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<float>({b, h, 0, 0}, true)); |
| 582 | + if (attention_mask.size(2) > 1) |
| 583 | + attn_mask_stride = attention_mask.stride(2) * sizeof(float); |
| 584 | + } |
| 585 | + uint8_t* cmask_ptr = nullptr; |
| 586 | + auto cmask_stride = 0; |
| 587 | + if (causal_mask) { |
| 588 | + cmask_ptr = &causal_mask.at<uint8_t>({b, h, 0, 0}, true); |
| 589 | + if (causal_mask.size(2) > 1) |
| 590 | + cmask_stride = causal_mask.stride(2); |
| 591 | + } |
| 592 | + |
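| | + // first GEMM: scores = Q_block * K^T for this head, producing an [m_cnt, kv_len] score matrix |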
| 593 | + arm_compute::Tensor qkTensor; |
| 594 | + arm_compute::TensorInfo qkInfo; |
| 595 | + |
| 596 | + bool b_transpose = false; |
| 597 | + if (k_stride_s == 1) |
| 598 | + b_transpose = true; |
| 599 | + GemmKernel qk_gemm(m_cnt, head_size, kv_len, b_transpose); |
| 600 | + |
| 601 | + arm_compute::Strides qStrides({query.stride_bytes(3), query.stride_bytes(2)}); |
| 602 | + arm_compute::Strides kStrides({present_key.stride_bytes(3), present_key.stride_bytes(2)}); |
| 603 | + qk_gemm.executeGemm(reinterpret_cast<void *>(q_ptr), |
| 604 | + reinterpret_cast<void *>(k_ptr), |
| 605 | + qkInfo, |
| 606 | + qkTensor, |
| 607 | + qStrides, |
| 608 | + kStrides); |
| 609 | + |
| 610 | + auto qk = reinterpret_cast<float*>(qkTensor.buffer()); |
| 611 | + |
| 612 | + |
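| | + // scale the scores, apply alibi / attention / causal masks, then softmax each row in place |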
| 613 | + for (size_t m = m_start; m < m_end; m++) { |
| 614 | + // apply attention mask & softmax |
| 615 | + auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len; |
| 616 | + attn_softmax(qk + (m - m_start) * kv_len, |
| 617 | + qk + (m - m_start) * kv_len, |
| 618 | + d_scale, |
| 619 | + alibi_ptr + m * alibi_stride, |
| 620 | + attn_mask_ptr + m * attn_mask_stride, |
| 621 | + cmask_ptr + m * cmask_stride, |
| 622 | + select_nfltmax_at_0, |
| 623 | + ncausal, |
| 624 | + kv_len, |
| 625 | + ov::element::f32, |
| 626 | + ov::element::f32); |
| 627 | + } |
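| | + // second GEMM: out = softmax(scores) * V, written straight into output_emb via the provided output strides |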
| 628 | + arm_compute::TensorInfo outInfo; |
| 629 | + arm_compute::Tensor outTensor; |
| 630 | + |
| 631 | + auto out = has_out_transpose ? &output_emb.at<float>({b, m_start, h * head_size}) : &output_emb.at<float>({b, h, m_start}); |
| 632 | + auto strides = arm_compute::Strides({output_emb.stride_bytes(1), output_emb.stride_bytes(2)}); |
| 633 | + GemmKernel out_gemm(m_cnt, kv_len, head_size); |
| 634 | + |
| 635 | + arm_compute::Strides vStrides({present_value.stride_bytes(3), present_value.stride_bytes(2)}); |
| 636 | + out_gemm.executeGemm(qkTensor.buffer(), |
| 637 | + reinterpret_cast<void *>(v_ptr), |
| 638 | + outInfo, |
| 639 | + outTensor, |
| 640 | + qkInfo.strides_in_bytes(), |
| 641 | + vStrides, |
| 642 | + nullptr, |
| 643 | + 1.0, |
| 644 | + 0.0, |
| 645 | + &strides, |
| 646 | + reinterpret_cast<void*>(out)); |
| 647 | + qkTensor.allocator()->free(); |
| 648 | + }); |
| 649 | + } |
| 650 | +}; |
| 651 | +#endif |
| 652 | + |
508 | 653 | #ifdef OV_CPU_WITH_MLAS
|
509 | 654 | template <>
|
510 | 655 | struct MHAKernel<ScaledDotProductAttention::KT_MLAS, float> {
|
@@ -935,7 +1080,9 @@ void ScaledDotProductAttention::createPrimitive() {
|
935 | 1080 | executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(context);
|
936 | 1081 | #endif
|
937 | 1082 | } else {
|
938 |
| -#ifdef OV_CPU_WITH_MLAS |
| 1083 | +#ifdef OV_CPU_WITH_ACL |
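| | + // prefer the ACL-based kernel when built with Compute Library; otherwise fall back to MLAS or the x86 paths below |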
| 1084 | + executor = std::make_shared<AttentionExecutor<KT_ACL, float>>(context); |
| 1085 | +#elif defined(OV_CPU_WITH_MLAS) |
939 | 1086 | executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(context);
|
940 | 1087 | #elif defined(OPENVINO_ARCH_X86_64)
|
941 | 1088 | if (with_cpu_x86_avx512_core()) {
|
|