
Commit 2844463

usstqwgz and guozhong wang authored
[CPU] Fix performance regressions in Llama & Gemma (openvinotoolkit#23396)
### Details:
- The sin/cos cache was removed from `LlamaRotaryEmbedding`, so the sin/cos table used for RoPE is generated on the fly for each 2nd token and each layer. This PR matches that subgraph pattern and shares the first match across all matches.
- [There](https://github.com/openvinotoolkit/openvino/blob/5b7d7c87a8a5847f97d869473fa22e8f1290a116/src/common/transformations/src/transformations/common_optimizations/shared_ops_optimization.cpp#L64) is a generic common transformation that performs a similar optimization, but it fails here because `cpu___module.model.layers.1.self_attn.rotary_emb/aten::to/Convert` is not constant (it is constant-folded later, after which it would pass the check). Another limitation is that `SharedOpOptimization` only shares single OPs rather than whole subgraphs.
- The causal mask maintenance logic in the original Llama & Gemma models is very inefficient: a big triangular causal mask tensor is updated on every token generation:

  ```python
  def _update_causal_mask():
      ...
      min_dtype = torch.finfo(dtype).min
      causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
      causal_mask = causal_mask.to(dtype=dtype, device=device)

      mask_length = attention_mask.shape[-1]
      padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
      causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
  ```

  Due to limitations in the ngraph OP set, the PT FE generates an even more costly subgraph: a Range node producing 8K-by-8K 1D indices feeding a ScatterElementsUpdate OP. Yet only a very small slice of the resulting mask is actually used by SDPA as the combined causal/attention mask (a small Python sketch of this cost gap follows the file summary below):

  ```python
  if attention_mask is not None and cache_position is not None:
      causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
  ```

  I introduced a custom CPU ngraph op & node, `CausalMaskPreprocess`, to fuse the logic above.
- The `gen_pattern` code is also partially refactored to display better verbose logs when pattern matching fails.

### Tickets:
- *CVS-133878*
- *CVS-134592*

---------

Co-authored-by: guozhong wang <guozhong.wang@intel.com>
1 parent daa947d commit 2844463

18 files changed: +1172 -239 lines
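To make the cost gap concrete, here is a minimal Python sketch (the 8192 mask size, the positions, and all variable names are illustrative, not taken from the commit) contrasting the full mask the exported subgraph maintains with the slice SDPA actually reads:

```python
import numpy as np

MAX_POS = 8192  # illustrative maximum sequence length

# What the exported subgraph effectively maintains: a full triangular
# additive mask (~64M floats), rewritten on every generated token.
min_f32 = np.finfo(np.float32).min
full_mask = np.triu(np.full((MAX_POS, MAX_POS), min_f32, dtype=np.float32), k=1)

# What SDPA actually consumes for a 2nd-token step (q_len == 1):
cache_position = np.array([1034])           # illustrative current position
kv_len = 1035                               # keys/values accumulated so far
used = full_mask[cache_position, :kv_len]   # just a [1, kv_len] row
print(full_mask.size, "->", used.size)      # 67108864 -> 1035
```

The fused `CausalMaskPreprocess` node produces the sliced `[batch, 1, q_len, kv_len]` tensor directly, so the 8K-by-8K intermediate never materializes.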

src/plugins/intel_cpu/src/cpu_types.cpp (+2)

```diff
@@ -219,6 +219,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
         {"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
         {"PagedAttentionExtension", Type::ScaledDotProductAttention},
         {"RoPE", Type::RoPE},
+        {"CausalMaskPreprocess", Type::CausalMaskPreprocess},
     };
     return type_to_name_tbl;
 }
@@ -337,6 +338,7 @@ std::string NameFromType(const Type type) {
         CASE(Ngram);
         CASE(ScaledDotProductAttention);
         CASE(RoPE);
+        CASE(CausalMaskPreprocess);
         CASE(Unknown);
     }
 #undef CASE
```

src/plugins/intel_cpu/src/cpu_types.h (+1)

```diff
@@ -116,6 +116,7 @@ enum class Type {
     Ngram,
     ScaledDotProductAttention,
     RoPE,
+    CausalMaskPreprocess,
 };

 enum class Algorithm {
```

src/plugins/intel_cpu/src/extension.cpp (+2)

```diff
@@ -12,6 +12,7 @@
 #include "ov_ops/nms_static_shape_ie.hpp"
 #include "ov_ops/type_relaxed.hpp"
 #include "snippets/op/subgraph.hpp"
+#include "transformations/cpu_opset/common/op/causal_mask_preprocess.hpp"
 #include "transformations/cpu_opset/common/op/fully_connected.hpp"
 #include "transformations/cpu_opset/common/op/leaky_relu.hpp"
 #include "transformations/cpu_opset/common/op/ngram.hpp"
@@ -69,6 +70,7 @@ class TypeRelaxedExtension : public ov::OpExtension<ov::op::TypeRelaxed<Op>> {
     OP_EXTENSION(ov::intel_cpu::LeakyReluNode)                  \
     OP_EXTENSION(ov::intel_cpu::PowerStaticNode)                \
     OP_EXTENSION(ov::intel_cpu::RoPENode)                       \
+    OP_EXTENSION(ov::intel_cpu::CausalMaskPreprocessNode)       \
     OP_EXTENSION(ov::intel_cpu::SwishNode)                      \
     OP_EXTENSION(ov::intel_cpu::NgramNode)                      \
     OP_EXTENSION(ov::op::internal::NonMaxSuppressionIEInternal) \
```
src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.cpp (new file, +158)

```cpp
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "causal_mask_preprocess.h"

#include "common/bfloat16.hpp"
#include "common/cpu_memcpy.h"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "shape_inference/shape_inference_internal_dyn.hpp"
#include "utils/plain_tensor.hpp"

#include <chrono>
#include <string>
#include <vector>

namespace ov {
namespace intel_cpu {
namespace node {

/*
CausalMaskPreprocess:
    inputs:
        0: attention_mask           : i64[N, kv_len]
                                      0 means mask-out, 1 means attends-to
        1: batch_size (size_Gather) : i32[1]
        2: cache_positions          : i32[q_len]
        3: kvLen                    : i32[1]
    outputs:
        0: causal mask for SDPA     : f32[batch_size, 1, q_len, kvLen]

The functionality is equivalent to the following python code:

    ##### preprocess
    min_dtype = torch.finfo(dtype).min
    causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
    causal_mask = causal_mask.to(dtype=dtype, device=device)

    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
    causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

    ##### when being used, it will be further sliced
    causal_mask = attention_mask
    if attention_mask is not None and cache_position is not None:
        causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
*/
template <typename T>
struct CausalMaskPreprocess::ExecutorCausalMaskPreprocess : public CausalMaskPreprocess::Executor {
    void execute(dnnl::stream strm,
                 intel_cpu::Node* pnode,
                 const intel_cpu::CausalMaskPreprocessNode::Config& config) override {
        ov::intel_cpu::PlainTensor t_attention_mask(pnode->getSrcMemoryAtPort(0));
        ov::intel_cpu::PlainTensor t_batch_size(pnode->getSrcMemoryAtPort(1));
        ov::intel_cpu::PlainTensor t_cache_positions(pnode->getSrcMemoryAtPort(2));
        ov::intel_cpu::PlainTensor t_kvLen(pnode->getSrcMemoryAtPort(3));

        auto mask_length = t_attention_mask.size(-1);
        auto batch_size = static_cast<size_t>(*t_batch_size.ptr<int32_t>(0));
        auto kvLen = static_cast<size_t>(*t_kvLen.ptr<int32_t>(0));
        auto qLen = t_cache_positions.size(0);

        VectorDims newDims{batch_size, 1, qLen, kvLen};
        pnode->redefineOutputMemory({newDims});
        ov::intel_cpu::PlainTensor t_dst(pnode->getDstMemoryAtPort(0));

        DEBUG_LOG("CausalMaskPreprocess::execute", config.type, " batch_size=", batch_size, " qLen=", qLen, " kvLen=", kvLen);
        DEBUG_LOG("CausalMaskPreprocess::execute attention_mask=", t_attention_mask);
        DEBUG_LOG("CausalMaskPreprocess::execute cache_positions=", t_cache_positions);

        // raw_causal_mask is already ensured to be triu by the transformation
        auto* prow = t_cache_positions.ptr<int32_t>(0);
        T min_dtype = std::numeric_limits<T>::lowest();

        parallel_for2d(batch_size, qLen, [&](size_t n, size_t i) {
            auto* pamask = t_attention_mask.ptr<int32_t>(n, 0);
            auto* pdst = t_dst.ptr<T>(n, 0, i);
            auto row = static_cast<size_t>(prow[i]);
            size_t j = 0;
            for (; j < mask_length; j++) {
                bool cmask_eq0 = (j <= row);
                bool amask_eq0 = (pamask[j] == 0);
                bool padding_mask = (cmask_eq0 && amask_eq0);
                pdst[j] = (padding_mask || (!cmask_eq0)) ? min_dtype : T(0);
            }
            for (; j < kvLen; j++) {
                bool cmask_eq0 = (j <= row);
                pdst[j] = cmask_eq0 ? T(0) : min_dtype;
            }
        });
        DEBUG_LOG("CausalMaskPreprocess::execute dst=", t_dst);
    }
};

CausalMaskPreprocess::CausalMaskPreprocess(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
    : Node(op, context, InternalDynShapeInferFactory()) {
    std::string errorMessage;
    if (!isSupportedOperation(op, errorMessage)) {
        OPENVINO_THROW("CPU: " + errorMessage);
    }

    const auto node = std::dynamic_pointer_cast<const intel_cpu::CausalMaskPreprocessNode>(op);
    m_config = node->get_config();
}

bool CausalMaskPreprocess::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
    try {
        const auto node = std::dynamic_pointer_cast<const intel_cpu::CausalMaskPreprocessNode>(op);
        if (!node) {
            errorMessage = "Only CausalMaskPreprocessNode operation is supported";
            return false;
        }
    } catch (...) {
        return false;
    }
    return true;
}

void CausalMaskPreprocess::initSupportedPrimitiveDescriptors() {
    if (!supportedPrimitiveDescriptors.empty())
        return;

    std::vector<ov::element::Type> iprecs = getOriginalInputPrecisions();
    std::vector<ov::element::Type> oprecs = getOriginalOutputPrecisions();

    // precision preferences
    if (m_config.type == "CausalMaskPreprocess") {
        if (oprecs[0] == ov::element::bf16) {
            m_executor = std::make_shared<ExecutorCausalMaskPreprocess<ov::bfloat16>>();
        } else {
            // fallback to default precision
            m_executor = std::make_shared<ExecutorCausalMaskPreprocess<float>>();
            oprecs[0] = ov::element::f32;
        }
        // all input precisions must be int32
        for (auto& prec : iprecs)
            prec = ov::element::i32;
    } else {
        OPENVINO_THROW("CPU: CausalMaskPreprocess type not supported : " + m_config.type);
    }

    std::vector<PortConfigurator> inPortConfigs;
    for (size_t i = 0; i < getOriginalInputsNumber(); i++)
        inPortConfigs.emplace_back(LayoutType::ncsp, iprecs[i], getInputShapeAtPort(i), false, -1);

    std::vector<PortConfigurator> outPortConfigs;
    for (size_t i = 0; i < getOriginalOutputsNumber(); i++)
        outPortConfigs.emplace_back(LayoutType::ncsp, oprecs[i], getOutputShapeAtPort(i), false, -1);

    addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}

void CausalMaskPreprocess::execute(dnnl::stream strm) {
    m_executor->execute(strm, this, m_config);
}

}  // namespace node
}  // namespace intel_cpu
}  // namespace ov
```
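For readers who prefer the Python view, here is a NumPy sketch of what `ExecutorCausalMaskPreprocess::execute` computes per batch item and query row (the function name and the toy inputs are illustrative, not part of the commit; the C++ above is the authoritative reference):

```python
import numpy as np

def causal_mask_preprocess_ref(attention_mask, batch_size, cache_positions, kv_len):
    """Mirror of the executor's inner loops for f32 output.
    attention_mask: i32[N, mask_length], 0 = mask-out, 1 = attends-to."""
    mask_length = attention_mask.shape[-1]
    q_len = cache_positions.shape[0]
    min_dtype = np.finfo(np.float32).min
    dst = np.zeros((batch_size, 1, q_len, kv_len), dtype=np.float32)
    for n in range(batch_size):
        for i in range(q_len):
            row = cache_positions[i]
            for j in range(kv_len):
                cmask_eq0 = j <= row  # inside the causal (lower-triangular) region
                amask_eq0 = j < mask_length and attention_mask[n, j] == 0
                padding_mask = cmask_eq0 and amask_eq0  # attended-to but padded
                dst[n, 0, i, j] = min_dtype if (padding_mask or not cmask_eq0) else 0.0
    return dst

# batch 1, two query tokens at positions 2 and 3, third kv slot padded out
out = causal_mask_preprocess_ref(np.array([[1, 1, 0, 1]], dtype=np.int32), 1, np.array([2, 3]), 4)
```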
src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.h (new file, +46)

```cpp
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "node.h"
#include "transformations/cpu_opset/common/op/causal_mask_preprocess.hpp"

namespace ov {
namespace intel_cpu {
namespace node {

class CausalMaskPreprocess : public Node {
public:
    CausalMaskPreprocess(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);

    void getSupportedDescriptors() override {}
    bool created() const override {
        return getType() == Type::CausalMaskPreprocess;
    }
    bool needPrepareParams() const override {
        return false;
    }
    void executeDynamicImpl(dnnl::stream strm) override {
        execute(strm);
    }
    void initSupportedPrimitiveDescriptors() override;
    void execute(dnnl::stream strm) override;
    static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;

private:
    struct Executor {
        virtual void execute(dnnl::stream strm,
                             intel_cpu::Node* pnode,
                             const intel_cpu::CausalMaskPreprocessNode::Config& config) = 0;
    };
    template <typename T>
    struct ExecutorCausalMaskPreprocess;
    intel_cpu::CausalMaskPreprocessNode::Config m_config;
    std::shared_ptr<Executor> m_executor;
};

}  // namespace node
}  // namespace intel_cpu
}  // namespace ov
```

src/plugins/intel_cpu/src/nodes_factory.cpp (+2)

```diff
@@ -96,6 +96,7 @@
 #include "nodes/topk.h"
 #include "nodes/transpose.h"
 #include "nodes/unique.hpp"
+#include "nodes/causal_mask_preprocess.h"

 namespace ov {
 namespace intel_cpu {
@@ -183,6 +184,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") {
     INTEL_CPU_NODE(Unique, Type::Unique);
     INTEL_CPU_NODE(Ngram, Type::Ngram);
     INTEL_CPU_NODE(RoPE, Type::RoPE);
+    INTEL_CPU_NODE(CausalMaskPreprocess, Type::CausalMaskPreprocess);
     INTEL_CPU_NODE(Interpolate, Type::Interpolate);
     INTEL_CPU_NODE(Inverse, Type::Inverse);
     INTEL_CPU_NODE(RandomUniform, Type::RandomUniform);
```
src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/causal_mask_preprocess.cpp (new file, +47)

```cpp
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "causal_mask_preprocess.hpp"

#include <algorithm>

#include "transformations/itt.hpp"

ov::intel_cpu::CausalMaskPreprocessNode::CausalMaskPreprocessNode(const OutputVector& args, const Config& cfg)
    : Op(args),
      m_config(cfg) {
    constructor_validate_and_infer_types();
}

std::shared_ptr<ov::Node> ov::intel_cpu::CausalMaskPreprocessNode::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(CausalMaskPreprocessNode_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<ov::intel_cpu::CausalMaskPreprocessNode>(new_args, m_config);
}

void ov::intel_cpu::CausalMaskPreprocessNode::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(CausalMaskPreprocessNode_validate_and_infer_types);
    if (m_config.type == "CausalMaskPreprocess") {
        // inputs:
        //   0: attention_mask           : i64[N, kv_len]
        //                                 0 means mask-out, 1 means attends-to
        //   1: batch_size (size_Gather) : i32[1]
        //   2: cache_positions          : i32[q_len]
        //   3: kvLen                    : i32[1]
        // outputs:
        //   0: causal mask for SDPA     : f32[batch_size, 1, q_len, kvLen]
        auto batch_size = Dimension::dynamic();
        auto q_len = get_input_partial_shape(2)[0];
        auto kv_len = Dimension::dynamic();
        set_output_type(0, ov::element::f32, {batch_size, 1, q_len, kv_len});
        return;
    }
    NODE_VALIDATION_CHECK(this, false, "unsupported type : ", m_config.type);
}

bool ov::intel_cpu::CausalMaskPreprocessNode::visit_attributes(ov::AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(CausalMaskPreprocessNode_visit_attributes);
    visitor.start_structure("config");
    visitor.on_attribute("type", m_config.type);
    visitor.finish_structure();
    return true;
}
```
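In effect, shape inference pins only `q_len` from input 2, while the batch and kv dimensions stay dynamic until runtime. A tiny illustrative Python sketch (the helper name is hypothetical; `None` stands for a dynamic dimension):

```python
def infer_causal_mask_preprocess_shape(cache_positions_shape):
    # output: f32[batch_size, 1, q_len, kv_len]; batch_size and kv_len remain dynamic
    q_len = cache_positions_shape[0]
    return (None, 1, q_len, None)

assert infer_causal_mask_preprocess_shape((7,)) == (None, 1, 7, None)
```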
src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/causal_mask_preprocess.hpp (new file, +43)

```cpp
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace intel_cpu {

class CausalMaskPreprocessNode : public ov::op::Op {
public:
    OPENVINO_OP("CausalMaskPreprocess", "cpu_plugin_opset");

    CausalMaskPreprocessNode() = default;

    struct Config {
        std::string type;
    };

    CausalMaskPreprocessNode(const OutputVector& args, const Config& cfg);

    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    const Config& get_config() const {
        return m_config;
    }

    Config& get_config() {
        return m_config;
    }

private:
    Config m_config;
};

}  // namespace intel_cpu
}  // namespace ov
```
