Skip to content

Commit 27a3805

Browse files
kcin96praaszakuporos
authored
Add get_sink_index method to Python API. (#25415)
### Details: - Adds `get_sink_index` method to the Python API. - Adds tests. ### Tickets: - 131043 This PR links issue #25116 --------- Co-authored-by: Pawel Raasz <pawel.raasz@intel.com> Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>
1 parent 8c58aa1 commit 27a3805

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

src/bindings/python/src/pyopenvino/graph/model.cpp

+81
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,18 @@ static std::unordered_map<std::string, ov::PartialShape> get_variables_shapes(co
128128
return variables_shape_map;
129129
}
130130

131+
template <typename T>
132+
static int64_t find_sink_position(const ov::SinkVector& sinks, const std::shared_ptr<T>& sink) {
133+
int64_t pos = 0;
134+
for (const auto& s : sinks) {
135+
if (s == sink) {
136+
return pos;
137+
}
138+
pos++;
139+
};
140+
return -1;
141+
}
142+
131143
void regclass_graph_Model(py::module m) {
132144
py::class_<ov::Model, std::shared_ptr<ov::Model>> model(m, "Model", py::module_local());
133145
model.doc() = "openvino.runtime.Model wraps ov::Model";
@@ -750,6 +762,75 @@ void regclass_graph_Model(py::module m) {
750762
:rtype: int
751763
)");
752764

765+
model.def(
766+
"get_sink_index",
767+
[](ov::Model& self, const ov::Output<ov::Node>& value) -> int64_t {
768+
auto node = value.get_node_shared_ptr();
769+
if (ov::is_type<ov::op::v6::Assign>(node)) {
770+
return find_sink_position(self.get_sinks(), std::dynamic_pointer_cast<ov::op::Sink>(node));
771+
} else {
772+
throw py::type_error("Incorrect argument type. Output sink node is expected as argument.");
773+
}
774+
},
775+
py::arg("value"),
776+
R"(
777+
Return index of sink.
778+
779+
Return -1 if `value` not matched.
780+
781+
:param value: Output sink node handle
782+
:type value: openvino.runtime.Output
783+
:return: Index of sink node referenced by output handle.
784+
:rtype: int
785+
)");
786+
787+
model.def(
788+
"get_sink_index",
789+
[](ov::Model& self, const ov::Output<const ov::Node>& value) -> int64_t {
790+
auto node = value.get_node_shared_ptr();
791+
if (ov::is_type<ov::op::v6::Assign>(node)) {
792+
return find_sink_position(self.get_sinks(), std::dynamic_pointer_cast<const ov::op::Sink>(node));
793+
} else {
794+
throw py::type_error("Incorrect argument type. Output sink node is expected as argument.");
795+
}
796+
},
797+
py::arg("value"),
798+
R"(
799+
Return index of sink.
800+
801+
Return -1 if `value` not matched.
802+
803+
:param value: Output sink node handle
804+
:type value: openvino.runtime.Output
805+
:return: Index of sink node referenced by output handle.
806+
:rtype: int
807+
)");
808+
809+
model.def(
810+
"get_sink_index",
811+
[](ov::Model& self, const py::object& node) -> int64_t {
812+
if (py::isinstance<ov::op::v6::Assign>(node)) {
813+
auto sink = std::dynamic_pointer_cast<ov::op::Sink>(node.cast<std::shared_ptr<ov::op::v6::Assign>>());
814+
return find_sink_position(self.get_sinks(), sink);
815+
} else if (py::isinstance<ov::Node>(node)) {
816+
auto sink = std::dynamic_pointer_cast<ov::op::Sink>(node.cast<std::shared_ptr<ov::Node>>());
817+
return find_sink_position(self.get_sinks(), sink);
818+
} else {
819+
throw py::type_error("Incorrect argument type. Sink node is expected as argument.");
820+
}
821+
},
822+
py::arg("sink"),
823+
R"(
824+
Return index of sink node.
825+
826+
Return -1 if `sink` not matched.
827+
828+
:param sink: Sink node.
829+
:type sink: openvino.runtime.Node
830+
:return: Index of sink node.
831+
:rtype: int
832+
)");
833+
753834
model.def("get_name",
754835
&ov::Model::get_name,
755836
R"(

src/bindings/python/tests/test_runtime/test_model.py

+29
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,35 @@ def test_replace_parameter():
204204
assert model.get_parameter_index(param1) == -1
205205

206206

207+
def test_get_sink_index(device):
208+
input_shape = PartialShape([2, 2])
209+
param = ops.parameter(input_shape, dtype=np.float64, name="data")
210+
relu1 = ops.relu(param, name="relu1")
211+
relu1.get_output_tensor(0).set_names({"relu_t1"})
212+
model = Model(relu1, [param], "TestModel")
213+
214+
# test get_sink_index with openvino.runtime.Node argument
215+
assign = ops.assign()
216+
assign2 = ops.assign()
217+
assign3 = ops.assign()
218+
model.add_sinks([assign, assign2, assign3])
219+
assign_nodes = model.sinks
220+
assert model.get_sink_index(assign_nodes[2]) == 2
221+
assert model.get_sink_index(relu1) == -1
222+
223+
# test get_sink_index with openvino.runtime.Output argument
224+
assign4 = ops.assign(relu1, "assign4")
225+
model.add_sinks([assign4])
226+
assert model.get_sink_index(assign4.output(0)) == 3
227+
228+
# check exceptions
229+
with pytest.raises(TypeError) as e:
230+
model.get_sink_index(0)
231+
assert (
232+
"Incorrect argument type. Sink node is expected as argument." in str(e.value)
233+
)
234+
235+
207236
@pytest.mark.parametrize(("args1", "args2", "expectation", "raise_msg"), [
208237
(Tensor("float32", Shape([2, 1])),
209238
[Tensor(np.array([2, 1], dtype=np.float32).reshape(2, 1)),

0 commit comments

Comments
 (0)