File tree 1 file changed +19
-0
lines changed
src/plugins/intel_cpu/src/shape_inference/custom
1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -53,7 +53,26 @@ class SDPAShapeInfer : public ShapeInferEmptyPads {
53
53
ScaledDotProductAttentionWithKVCache::Config m_config;
54
54
};
55
55
56
+ class PAShapeInfer : public ShapeInferEmptyPads {
57
+ public:
58
+ PAShapeInfer () {}
59
+
60
+ IShapeInfer::Result infer (const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
61
+ const std::unordered_map<size_t , MemoryPtr>& data_dependency) override {
62
+ const auto & query_dims = input_shapes.front ().get ();
63
+
64
+ return {{query_dims}, ShapeInferStatus::success};
65
+ }
66
+
67
+ port_mask_t get_port_mask () const override {
68
+ return EMPTY_PORT_MASK;
69
+ }
70
+ };
71
+
56
72
ShapeInferPtr SDPAShapeInferFactory::makeShapeInfer () const {
73
+ if (m_op->get_type_name () == std::string (" PagedAttentionExtension" )) {
74
+ return std::make_shared<PAShapeInfer>();
75
+ }
57
76
if (auto sdpa = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(m_op)) {
58
77
const auto & config = sdpa->get_config ();
59
78
if (config.output_BLHxS == false )
You can’t perform that action at this time.
0 commit comments