Skip to content

Commit c975788

Browse files
authored
[cpu] Integrate IStaticShapeInfer with IShapeInfer (openvinotoolkit#27770)
### Details: - The `IStaticShapeInfer` interface extends `IShapeInfer`. - Remove `NgraphShapeInfer` class as its functionality is replaced by `IStaticShapeInfer`. - Refactor shape inference unit test to avoid names clashes with CPU plugin types: - use `ov::Shape` to avoid interpretation as `intel_cpu::Shape`. - rename test type `ShapeVector` to `StaticShapeVector`. ### Tickets: - CVS-118704 --------- Signed-off-by: Pawel Raasz <pawel.raasz@intel.com> Signed-off-by: Raasz, Pawel <pawel.raasz@intel.com>
1 parent 78a6ad8 commit c975788

File tree

97 files changed

+1077
-1214
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+1077
-1214
lines changed

src/plugins/intel_cpu/src/nodes/deconv.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <common/primitive_desc.hpp>
1313
#include <common/primitive_desc_iface.hpp>
1414
#include "cpu/x64/cpu_isa_traits.hpp"
15-
#include "shape_inference/shape_inference_ngraph.hpp"
15+
#include "shape_inference/shape_inference.hpp"
1616

1717
#include "eltwise.h"
1818
#include "fake_quantize.h"
@@ -128,12 +128,11 @@ bool DeconvKey::operator==(const DeconvKey &rhs) const {
128128
*/
129129
class DeconfolutionShapeInferFactory : public ShapeInferFactory {
130130
public:
131-
DeconfolutionShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
131+
DeconfolutionShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(std::move(op)) {}
132+
132133
ShapeInferPtr makeShapeInfer() const override {
133-
if (m_op->get_input_size() > 2) {
134-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), PortMask(2));
135-
}
136-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
134+
const auto port_mask = (m_op->get_input_size() > 2) ? PortMask(2) : EMPTY_PORT_MASK;
135+
return make_shape_inference(m_op, port_mask);
137136
}
138137
private:
139138
std::shared_ptr<ov::Node> m_op;

src/plugins/intel_cpu/src/nodes/eye.cpp

+3-8
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "openvino/op/eye.hpp"
77
#include <utils/bfloat16.hpp>
88
#include "openvino/core/parallel.hpp"
9-
#include "shape_inference/shape_inference_ngraph.hpp"
9+
#include "shape_inference/shape_inference.hpp"
1010
#include "utils/bfloat16.hpp"
1111

1212
#define THROW_ERROR(...) OPENVINO_THROW(NameFromType(getType()), " node with name '", getName(), "' ", __VA_ARGS__)
@@ -33,13 +33,8 @@ class EyeShapeInferFactory : public ShapeInferFactory {
3333
public:
3434
EyeShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
3535
ShapeInferPtr makeShapeInfer() const override {
36-
IShapeInfer::port_mask_t port_mask = EMPTY_PORT_MASK;
37-
if (m_op->get_input_size() == 4) {
38-
port_mask = PortMask(Eye::ROWS_NUM, Eye::COLS_NUM, Eye::DIAGONAL_INDEX, Eye::BATCH_SHAPE);
39-
} else {
40-
port_mask = PortMask(Eye::ROWS_NUM, Eye::COLS_NUM, Eye::DIAGONAL_INDEX);
41-
}
42-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
36+
return (m_op->get_input_size() == 4) ? make_shape_inference(m_op)
37+
: make_shape_inference(m_op, PortMask(Eye::ROWS_NUM, Eye::COLS_NUM));
4338
}
4439
private:
4540
std::shared_ptr<ov::Node> m_op;

src/plugins/intel_cpu/src/nodes/interpolate.cpp

+5-12
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include "openvino/opsets/opset11.hpp"
2222
#include "openvino/opsets/opset4.hpp"
2323
#include "shape_inference/shape_inference.hpp"
24-
#include "shape_inference/shape_inference_ngraph.hpp"
2524
#include "shape_inference/static_shape.hpp"
2625
#include "utils/bfloat16.hpp"
2726
#include "utils/cpu_utils.hpp"
@@ -1763,27 +1762,21 @@ class InterpolateShapeInferFactory : public ShapeInferFactory {
17631762
public:
17641763
InterpolateShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
17651764
ShapeInferPtr makeShapeInfer() const override {
1766-
IShapeInfer::port_mask_t port_mask = 0x00;
17671765
if (auto interp4 = ov::as_type_ptr<ov::opset4::Interpolate>(m_op)) {
17681766
const auto &attr = interp4->get_attrs();
1769-
1770-
if (attr.shape_calculation_mode == ngInterpShapeCalcMode::SCALES) {
1771-
port_mask = PortMask(Interpolate::SCALES_ID, Interpolate::AXES_ID);
1772-
} else if (attr.shape_calculation_mode == ngInterpShapeCalcMode::SIZES) {
1773-
port_mask = PortMask(Interpolate::TARGET_SHAPE_ID, Interpolate::AXES_ID);
1774-
} else {
1775-
OPENVINO_ASSERT(false, "Unsupported interpolate shape calculation mode");
1776-
}
1767+
const auto is_supported_mode = (attr.shape_calculation_mode == ngInterpShapeCalcMode::SCALES) ||
1768+
(attr.shape_calculation_mode == ngInterpShapeCalcMode::SIZES);
1769+
OPENVINO_ASSERT(is_supported_mode, "Unsupported interpolate shape calculation mode");
1770+
return make_shape_inference(m_op);
17771771
} else if (auto interp11 = ov::as_type_ptr<ov::op::v11::Interpolate>(m_op)) {
1778-
port_mask = PortMask(Interpolate::SIZE_OR_SCALE_ID_V11, Interpolate::AXES_ID_V11);
1772+
return make_shape_inference(m_op);
17791773
} else {
17801774
OPENVINO_THROW("Shape infer factory cannot be created for ",
17811775
m_op->get_type_name(),
17821776
" node with name: ",
17831777
m_op->get_friendly_name(),
17841778
", only versions 4 and 11 are supported.");
17851779
}
1786-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
17871780
}
17881781

17891782
private:

src/plugins/intel_cpu/src/nodes/reference.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include "reference.h"
66
#include "common/cpu_memcpy.h"
7-
#include "shape_inference/shape_inference_ngraph.hpp"
7+
#include "shape_inference/shape_inference.hpp"
88

99
namespace ov {
1010
namespace intel_cpu {
@@ -14,7 +14,7 @@ class ReferenceShapeInferFactory : public ShapeInferFactory {
1414
ReferenceShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op{std::move(op)} {}
1515

1616
ShapeInferPtr makeShapeInfer() const override {
17-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), FULL_PORT_MASK);
17+
return make_shape_inference(m_op, FULL_PORT_MASK);
1818
}
1919

2020
private:

src/plugins/intel_cpu/src/nodes/rnn.cpp

+27-15
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,22 @@
1111
#include "nodes/input.h"
1212
#include "nodes/reorder.h"
1313
#include "openvino/core/parallel.hpp"
14-
#include "shape_inference/shape_inference_ngraph.hpp"
15-
#include "transformations/utils/utils.hpp"
16-
17-
#include "ov_ops/augru_cell.hpp"
18-
#include "ov_ops/augru_sequence.hpp"
1914
#include "openvino/op/gru_cell.hpp"
2015
#include "openvino/op/gru_sequence.hpp"
2116
#include "openvino/op/lstm_sequence.hpp"
2217
#include "openvino/op/rnn_cell.hpp"
2318
#include "openvino/op/rnn_sequence.hpp"
19+
#include "ov_ops/augru_cell.hpp"
20+
#include "ov_ops/augru_sequence.hpp"
21+
#include "shape_inference/shape_inference.hpp"
22+
#include "transformations/utils/utils.hpp"
2423

2524
using namespace dnnl;
2625

2726

2827
namespace ov {
2928
namespace intel_cpu {
29+
3030
namespace node {
3131

3232
static rnn_direction ieDirection2dnnl(const std::shared_ptr<const ov::Node>& op) {
@@ -356,19 +356,17 @@ namespace {
356356
* dimensions permutation, necessary due to the mismatch between the ngraph and the oneDNN RNN node descriptions.
357357
*
358358
*/
359-
class RnnShapeInfer : public NgraphShapeInfer {
359+
class RnnShapeInfer : public IShapeInfer {
360360
public:
361-
RnnShapeInfer(std::shared_ptr<ov::Node> op) :
362-
NgraphShapeInfer(make_shape_inference(op), EMPTY_PORT_MASK) {
363-
is_sequence = !(RNN::isCell(op));
364-
365-
native_order = RNN::testNativeOrder(op);
366-
}
361+
RnnShapeInfer(std::shared_ptr<ov::Node> op)
362+
: is_sequence(!(RNN::isCell(op))),
363+
native_order(RNN::testNativeOrder(op)),
364+
m_shape_infer(make_shape_inference(std::move(op))) {}
367365

368366
Result infer(
369367
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
370368
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
371-
auto result = NgraphShapeInfer::infer(input_shapes, data_dependency);
369+
auto result = m_shape_infer->infer(input_shapes, data_dependency);
372370
if (ShapeInferStatus::success != result.status) {
373371
OPENVINO_THROW("Unexpected: Unexpected shape inference result status");
374372
}
@@ -382,10 +380,24 @@ class RnnShapeInfer : public NgraphShapeInfer {
382380
return {std::move(originOutputShapes), result.status};
383381
}
384382

383+
const ov::CoordinateDiff& get_pads_begin() override {
384+
return m_shape_infer->get_pads_begin();
385+
}
386+
387+
const ov::CoordinateDiff& get_pads_end() override {
388+
return m_shape_infer->get_pads_end();
389+
}
390+
391+
port_mask_t get_port_mask() const override {
392+
return m_shape_infer->get_port_mask();
393+
}
394+
385395
private:
386-
bool is_sequence = false;
387-
bool native_order = true;
396+
bool is_sequence;
397+
bool native_order;
398+
ShapeInferPtr m_shape_infer;
388399
};
400+
389401
class RnnShapeInferFactory final : public ShapeInferFactory {
390402
public:
391403
RnnShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}

src/plugins/intel_cpu/src/nodes/strided_slice.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "common/cpu_memcpy.h"
99
#include "input.h"
1010
#include "openvino/opsets/opset1.hpp"
11-
#include "shape_inference/shape_inference_ngraph.hpp"
1211
#include "slice_shape_inference_utils.hpp"
1312
#include "shape_inference/custom/strided_slice.hpp"
1413

src/plugins/intel_cpu/src/shape_inference/custom/matmul.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "matmul.hpp"
66
#include "utils.hpp"
77
#include "openvino/opsets/opset1.hpp"
8+
#include "shape_inference/shape_inference.hpp"
89

910
namespace ov {
1011
namespace intel_cpu {
@@ -64,17 +65,17 @@ Result MMShapeInfer::infer(
6465

6566
ShapeInferPtr MMShapeInferFactory::makeShapeInfer() const {
6667
if (const auto matmul = ov::as_type_ptr<const ov::opset1::MatMul>(m_op)) {
67-
const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
68-
const bool transpose_a = matmul->get_transpose_a();
69-
const bool transpose_b = matmul->get_transpose_b();
7068
const auto input_rank0 = matmul->get_input_partial_shape(0).rank().get_length();
7169
const auto input_rank1 = matmul->get_input_partial_shape(1).rank().get_length();
70+
7271
if (input_rank0 == input_rank1) {
72+
const auto output_rank = matmul->get_output_partial_shape(0).rank().get_length();
73+
const bool transpose_a = matmul->get_transpose_a();
74+
const bool transpose_b = matmul->get_transpose_b();
7375
return std::make_shared<MMShapeInfer>(output_rank, transpose_a, transpose_b);
7476
} else {
75-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
77+
return make_shape_inference(m_op);
7678
}
77-
7879
} else {
7980
OPENVINO_THROW("Unexpected operation type in the MatMul shape inference factory");
8081
}

src/plugins/intel_cpu/src/shape_inference/custom/matmul.hpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
//
44

55
#include <node.h>
6+
67
#include "shape_inference/shape_inference_cpu.hpp"
7-
#include "shape_inference/shape_inference_ngraph.hpp"
88

99
#pragma once
1010
namespace ov {
@@ -42,4 +42,3 @@ class MMShapeInferFactory : public ShapeInferFactory {
4242
} // namespace node
4343
} // namespace intel_cpu
4444
} // namespace ov
45-

src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
#include "scaled_attn.hpp"
66

7-
#include "shape_inference/shape_inference_cpu.hpp"
8-
#include "shape_inference/shape_inference_ngraph.hpp"
7+
#include "shape_inference/shape_inference.hpp"
98
#include "transformations/cpu_opset/common/op/sdpa.hpp"
109
#include "utils.hpp"
1110

@@ -78,7 +77,7 @@ ShapeInferPtr SDPAShapeInferFactory::makeShapeInfer() const {
7877
return std::make_shared<SDPAShapeInfer>(config);
7978
}
8079
// fallback to ngraph shape infer on non-perf-critical case
81-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), EMPTY_PORT_MASK);
80+
return make_shape_inference(m_op);
8281
}
8382

8483
} // namespace node

src/plugins/intel_cpu/src/shape_inference/custom/strided_slice.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "strided_slice.hpp"
66
#include "utils.hpp"
77
#include "slice_shape_inference.hpp"
8-
#include "shape_inference/shape_inference_ngraph.hpp"
8+
#include "shape_inference/shape_inference.hpp"
99

1010
namespace ov {
1111
namespace intel_cpu {
@@ -75,13 +75,13 @@ Result StridedSliceShapeInfer::infer(
7575

7676
ShapeInferPtr StridedSliceShapeInferFactory::makeShapeInfer() const {
7777
if (const auto Slice_op = ov::as_type_ptr<const ov::op::v8::Slice>(m_op)) {
78-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
78+
return make_shape_inference(m_op);
7979
} else if (const auto SliceScatter_op = ov::as_type_ptr<const ov::op::v15::SliceScatter>(m_op)) {
80-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), PortMask(2, 3, 4, 5));
80+
return make_shape_inference(m_op);
8181
} else if (const auto StridedSlice_op = ov::as_type_ptr<const ov::op::v1::StridedSlice>(m_op)) {
8282
const auto& ellipsis_mask = StridedSlice_op->get_ellipsis_mask();
8383
if (std::any_of(ellipsis_mask.begin(), ellipsis_mask.end(), [](int64_t x){ return x == 1; })) {
84-
return std::make_shared<NgraphShapeInfer>(make_shape_inference(m_op), port_mask);
84+
return make_shape_inference(m_op);
8585
} else {
8686
auto vec_to_set = [](const std::vector<int64_t>& vec){
8787
std::unordered_set<int64_t> to_set;

0 commit comments

Comments
 (0)