Skip to content

Commit 8ce44b4

Browse files
changed for comments
1 parent 345692c commit 8ce44b4

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

src/bindings/python/src/pyopenvino/utils/utils.cpp

+17-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ using Version = ov::pass::Serialize::Version;
2424
namespace Common {
2525
namespace utils {
2626

27-
PY_TYPE check_list_element_type(const py::list& list) {
27+
template <typename T>
28+
PY_TYPE check_list_element_type(const T& list) {
2829
PY_TYPE detected_type = PY_TYPE::UNKNOWN;
2930

3031
auto check_type = [&](PY_TYPE type) {
@@ -47,6 +48,8 @@ PY_TYPE check_list_element_type(const py::list& list) {
4748
check_type(PY_TYPE::BOOL);
4849
} else if (py::isinstance<ov::PartialShape>(it)) {
4950
check_type(PY_TYPE::PARTIAL_SHAPE);
51+
} else if (py::isinstance<ov::hint::ModelDistributionPolicy>(it)) {
52+
check_type(PY_TYPE::ModelDistributionPolicy);
5053
}
5154
}
5255

@@ -413,7 +416,7 @@ ov::Any py_object_to_any(const py::object& py_obj) {
413416
} else if (py::isinstance<py::list>(py_obj)) {
414417
auto _list = py_obj.cast<py::list>();
415418

416-
PY_TYPE detected_type = check_list_element_type(_list);
419+
PY_TYPE detected_type = check_list_element_type<py::list>(_list);
417420

418421
if (_list.empty())
419422
return ov::Any(EmptyList());
@@ -450,13 +453,19 @@ ov::Any py_object_to_any(const py::object& py_obj) {
450453
} else if (py::isinstance<ov::hint::SchedulingCoreType>(py_obj)) {
451454
return py::cast<ov::hint::SchedulingCoreType>(py_obj);
452455
} else if (py::isinstance<py::set>(py_obj)) {
453-
std::set<ov::hint::ModelDistributionPolicy> model_set;
454-
for (auto item = py_obj.begin(); item != py_obj.end(); item++) {
455-
if (py::isinstance<ov::hint::ModelDistributionPolicy>(*item)) {
456-
model_set.insert(py::cast<ov::hint::ModelDistributionPolicy>(*item));
457-
}
456+
auto _set = py_obj.cast<py::set>();
457+
458+
PY_TYPE detected_type = check_list_element_type<py::set>(_set);
459+
460+
if (_set.empty())
461+
return ov::Any(EmptyList());
462+
463+
switch (detected_type) {
464+
case PY_TYPE::ModelDistributionPolicy:
465+
return _set.cast<std::set<ov::hint::ModelDistributionPolicy>>();
466+
default:
467+
OPENVINO_ASSERT(false, "Unsupported attribute type.");
458468
}
459-
return model_set;
460469
} else if (py::isinstance<ov::hint::ExecutionMode>(py_obj)) {
461470
return py::cast<ov::hint::ExecutionMode>(py_obj);
462471
} else if (py::isinstance<ov::log::Level>(py_obj)) {

src/bindings/python/src/pyopenvino/utils/utils.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ class MemoryBuffer : public std::streambuf {
6363
}
6464
};
6565

66-
enum class PY_TYPE : int { UNKNOWN = 0, STR, INT, FLOAT, BOOL, PARTIAL_SHAPE };
66+
enum class PY_TYPE : int { UNKNOWN = 0, STR, INT, FLOAT, BOOL, PARTIAL_SHAPE, ModelDistributionPolicy };
6767

6868
struct EmptyList {};
6969

70-
PY_TYPE check_list_element_type(const py::list& list);
70+
template <typename T>
71+
PY_TYPE check_list_element_type(const T& list);
7172

7273
py::object from_ov_any_no_leaves(const ov::Any& any);
7374

0 commit comments

Comments
 (0)