Skip to content

Commit 56ac8ed

Browse files
committed
Python API: allow regular Dict[str, ?] to describe expected attributes
Signed-off-by: Evgeniia Nugmanova <evgeniia.nugmanova@intel.com>
1 parent 2cc7a82 commit 56ac8ed

File tree

4 files changed

+23
-3
lines changed

4 files changed

+23
-3
lines changed

src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "openvino/pass/pattern/op/pattern.hpp"
1919
#include "openvino/pass/pattern/op/wrap_type.hpp"
2020
#include "pyopenvino/core/common.hpp"
21+
#include "pyopenvino/utils/utils.hpp"
2122

2223
static ov::NodeTypeInfo get_type(const std::string& type_name) {
2324
// Supported types: opsetX.OpName or opsetX::OpName
@@ -1014,7 +1015,9 @@ inline void reg_predicates(py::module m) {
10141015
m.def("type_matches", &ov::pass::pattern::type_matches);
10151016
m.def("type_matches_any", &ov::pass::pattern::type_matches_any);
10161017
m.def("shape_matches", &ov::pass::pattern::shape_matches);
1017-
m.def("attrs_match", &ov::pass::pattern::attrs_match);
1018+
m.def("attrs_match", [](py::object& attrs) {
1019+
return ov::pass::pattern::attrs_match(Common::utils::py_object_to_unordered_any_map(attrs));
1020+
});
10181021
}
10191022

10201023
void reg_passes_pattern_ops(py::module m) {

src/bindings/python/src/pyopenvino/utils/utils.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,21 @@ ov::AnyMap py_object_to_any_map(const py::object& py_obj) {
404404
return return_value;
405405
}
406406

407+
std::unordered_map<std::string, ov::Any> py_object_to_unordered_any_map(const py::object& py_obj) {
408+
OPENVINO_ASSERT(py_object_is_any_map(py_obj), "Unsupported attribute type.");
409+
std::unordered_map<std::string, ov::Any> return_value = {};
410+
for (auto& item : py::cast<py::dict>(py_obj)) {
411+
std::string key = py::cast<std::string>(item.first);
412+
py::object value = py::cast<py::object>(item.second);
413+
if (py_object_is_any_map(value)) {
414+
return_value[key] = Common::utils::py_object_to_any_map(value);
415+
} else {
416+
return_value[key] = Common::utils::py_object_to_any(value);
417+
}
418+
}
419+
return return_value;
420+
}
421+
407422
ov::Any py_object_to_any(const py::object& py_obj) {
408423
// Python types
409424
py::object float_32_type = py::module_::import("numpy").attr("float32");

src/bindings/python/src/pyopenvino/utils/utils.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class MemoryBuffer : public std::streambuf {
9191

9292
ov::AnyMap py_object_to_any_map(const py::object& py_obj);
9393

94+
std::unordered_map<std::string, ov::Any> py_object_to_unordered_any_map(const py::object& py_obj);
95+
9496
ov::Any py_object_to_any(const py::object& py_obj);
9597

9698
ov::pass::Serialize::Version convert_to_version(const std::string& version);

src/bindings/python/tests/test_transformations/test_pattern_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66

7-
from openvino import PartialShape, Symbol, Dimension, OVAny
7+
from openvino import PartialShape, Symbol, Dimension
88
from openvino import opset13 as ops
99
from openvino.passes import Matcher, WrapType, Or, AnyInput, Optional
1010
from openvino.passes import (
@@ -284,7 +284,7 @@ def test_attrs_match():
284284

285285
def test_shape_of_attribute(et: str):
286286
node = ops.shape_of(param, output_type=et)
287-
attr = {"output_type": OVAny(et)}
287+
attr = {"output_type": et}
288288
matcher = Matcher(AnyInput(attrs_match(attr)), "Find shape_of with attribute")
289289
assert matcher.match(node), f"Match failed for {node} with attribute"
290290

0 commit comments

Comments
 (0)