Skip to content

Commit 1468630

Browse files
[CPU] optimize PagedAttention's shape inference (openvinotoolkit#23603)
### Details:
- *Specific shape inference for PagedAttention*
- *...*

### Tickets:
- *ticket-id*
1 parent f514412 commit 1468630

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,26 @@ class SDPAShapeInfer : public ShapeInferEmptyPads {
5353
ScaledDotProductAttentionWithKVCache::Config m_config;
5454
};
5555

56+
class PAShapeInfer : public ShapeInferEmptyPads {
57+
public:
58+
PAShapeInfer() {}
59+
60+
IShapeInfer::Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
61+
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
62+
const auto& query_dims = input_shapes.front().get();
63+
64+
return {{query_dims}, ShapeInferStatus::success};
65+
}
66+
67+
port_mask_t get_port_mask() const override {
68+
return EMPTY_PORT_MASK;
69+
}
70+
};
71+
5672
ShapeInferPtr SDPAShapeInferFactory::makeShapeInfer() const {
73+
if (m_op->get_type_name() == std::string("PagedAttentionExtension")) {
74+
return std::make_shared<PAShapeInfer>();
75+
}
5776
if (auto sdpa = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(m_op)) {
5877
const auto& config = sdpa->get_config();
5978
if (config.output_BLHxS == false)

0 commit comments

Comments
 (0)