@@ -24,7 +24,8 @@ using Version = ov::pass::Serialize::Version;
24
24
namespace Common {
25
25
namespace utils {
26
26
27
- PY_TYPE check_list_element_type (const py::list& list) {
27
+ template <typename T>
28
+ PY_TYPE check_list_element_type (const T& list) {
28
29
PY_TYPE detected_type = PY_TYPE::UNKNOWN;
29
30
30
31
auto check_type = [&](PY_TYPE type) {
@@ -47,6 +48,8 @@ PY_TYPE check_list_element_type(const py::list& list) {
47
48
check_type (PY_TYPE::BOOL);
48
49
} else if (py::isinstance<ov::PartialShape>(it)) {
49
50
check_type (PY_TYPE::PARTIAL_SHAPE);
51
+ } else if (py::isinstance<ov::hint::ModelDistributionPolicy>(it)) {
52
+ check_type (PY_TYPE::ModelDistributionPolicy);
50
53
}
51
54
}
52
55
@@ -413,7 +416,7 @@ ov::Any py_object_to_any(const py::object& py_obj) {
413
416
} else if (py::isinstance<py::list>(py_obj)) {
414
417
auto _list = py_obj.cast <py::list>();
415
418
416
- PY_TYPE detected_type = check_list_element_type (_list);
419
+ PY_TYPE detected_type = check_list_element_type<py::list> (_list);
417
420
418
421
if (_list.empty ())
419
422
return ov::Any (EmptyList ());
@@ -450,13 +453,19 @@ ov::Any py_object_to_any(const py::object& py_obj) {
450
453
} else if (py::isinstance<ov::hint::SchedulingCoreType>(py_obj)) {
451
454
return py::cast<ov::hint::SchedulingCoreType>(py_obj);
452
455
} else if (py::isinstance<py::set>(py_obj)) {
453
- std::set<ov::hint::ModelDistributionPolicy> model_set;
454
- for (auto item = py_obj.begin (); item != py_obj.end (); item++) {
455
- if (py::isinstance<ov::hint::ModelDistributionPolicy>(*item)) {
456
- model_set.insert (py::cast<ov::hint::ModelDistributionPolicy>(*item));
457
- }
456
+ auto _set = py_obj.cast <py::set>();
457
+
458
+ PY_TYPE detected_type = check_list_element_type<py::set>(_set);
459
+
460
+ if (_set.empty ())
461
+ return ov::Any (EmptyList ());
462
+
463
+ switch (detected_type) {
464
+ case PY_TYPE::ModelDistributionPolicy:
465
+ return _set.cast <std::set<ov::hint::ModelDistributionPolicy>>();
466
+ default :
467
+ OPENVINO_ASSERT (false , " Unsupported attribute type." );
458
468
}
459
- return model_set;
460
469
} else if (py::isinstance<ov::hint::ExecutionMode>(py_obj)) {
461
470
return py::cast<ov::hint::ExecutionMode>(py_obj);
462
471
} else if (py::isinstance<ov::log ::Level>(py_obj)) {
0 commit comments