Skip to content

Commit 1a6536a

Browse files
authored
[Common, PT, TF FE] Introduce common translators to support complex tensors in PyTorch (#29150)
**Details:** It introduces new component with common translators for PyTorch and TensorFlow. This common translators allow to support complex tensors for operations: add, mul, sub, atan2, angle, complex, real and imag. It helps to support RoPE implementation based on complex tensor computation, Kokoru model. **Ticket:** 162659 --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent 138699e commit 1a6536a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+1453
-443
lines changed

cmake/developer_package/frontends/frontends.cmake

+2-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ macro(ov_add_frontend)
228228
ov_add_vs_version_file(NAME ${TARGET_NAME}
229229
FILEDESCRIPTION ${OV_FRONTEND_FILEDESCRIPTION})
230230

231-
target_link_libraries(${TARGET_NAME} PRIVATE ${OV_FRONTEND_LINK_LIBRARIES} PUBLIC openvino::runtime)
231+
target_link_libraries(${TARGET_NAME} PRIVATE ${OV_FRONTEND_LINK_LIBRARIES} openvino::frontend::common_translators
232+
PUBLIC openvino::runtime)
232233
ov_add_library_version(${TARGET_NAME})
233234

234235
if(OV_FRONTEND_PROTOBUF_REQUIRED)

src/frontends/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
add_subdirectory(common)
66

7+
add_subdirectory(common_translators)
8+
79
if(ENABLE_TESTS)
810
add_subdirectory(tests)
911
endif()

src/frontends/common/include/openvino/frontend/complex_type_mark.hpp

+68-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#pragma once
66

77
#include "openvino/core/type/element_type.hpp"
8+
#include "openvino/frontend/exception.hpp"
9+
#include "openvino/frontend/node_context.hpp"
810
#include "openvino/frontend/visibility.hpp"
911
#include "openvino/op/util/framework_node.hpp"
1012

@@ -22,11 +24,12 @@ class FRONTEND_API ComplexTypeMark : public ov::op::util::FrameworkNode {
2224
public:
2325
OPENVINO_OP("ComplexTypeMark", "util", ov::op::util::FrameworkNode);
2426

25-
ComplexTypeMark(const ov::Output<ov::Node>& input, const ov::element::Type& complex_part_type)
26-
: ov::op::util::FrameworkNode(ov::OutputVector{input}, 1),
27-
m_complex_part_type(complex_part_type) {
28-
validate_and_infer_types();
29-
}
27+
ComplexTypeMark(const ov::Output<ov::Node>& input,
28+
const ov::element::Type& complex_part_type = ov::element::dynamic);
29+
30+
ComplexTypeMark(const ov::Output<ov::Node>& real,
31+
const ov::Output<ov::Node>& imag,
32+
const ov::element::Type& complex_part_type = ov::element::dynamic);
3033

3134
~ComplexTypeMark() override;
3235

@@ -44,8 +47,68 @@ class FRONTEND_API ComplexTypeMark : public ov::op::util::FrameworkNode {
4447
return m_complex_part_type;
4548
}
4649

50+
// Get a real part of the complex tensor
51+
ov::Output<ov::Node> get_real(bool squeezed = true);
52+
53+
// Get an imaginary part of the complex tensor
54+
ov::Output<ov::Node> get_imag(bool squeezed = true);
55+
56+
// Get floating-point representation of the complex tensor
57+
ov::Output<ov::Node> get_data();
58+
59+
// Compute summation of two operands that can be of complex type
60+
// if operand is of complex type, complex type will be indicated by bool flag
61+
// complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
62+
// types of both operands must be aligned prior to the call
63+
static ov::Output<ov::Node> add(const NodeContext& context,
64+
const ov::Output<ov::Node>& lhs,
65+
const ov::Output<ov::Node>& rhs);
66+
67+
// Compute subtraction of two operands that can be of complex type
68+
// if operand is of complex type, complex type will be indicated by bool flag
69+
// complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
70+
// types of both operands must be aligned prior to the call
71+
static ov::Output<ov::Node> sub(const NodeContext& context,
72+
const ov::Output<ov::Node>& lhs,
73+
const ov::Output<ov::Node>& rhs);
74+
75+
// Compute multiplication of two operands that can be of complex type
76+
// if operand is of complex type, complex type will be indicated by bool flag
77+
// complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
78+
// types of both operands must be aligned prior to the call
79+
static ov::Output<ov::Node> mul(const NodeContext& context,
80+
const ov::Output<ov::Node>& lhs,
81+
const ov::Output<ov::Node>& rhs);
82+
83+
// Compute inverse of operand that can be of complex type
84+
// if operand is of complex type, complex type will be indicated by bool flag
85+
// complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
86+
static ov::Output<ov::Node> inv(const NodeContext& context, const ov::Output<ov::Node>& data);
87+
88+
// Compute division of two operands that can be of complex type
89+
// if operand is of complex type, complex type will be indicated by bool flag
90+
// complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
91+
// types of both operands must be aligned prior to the call
92+
static ov::Output<ov::Node> div(const NodeContext& context,
93+
const ov::Output<ov::Node>& lhs,
94+
const ov::Output<ov::Node>& rhs);
95+
96+
// Convert type of real and imaginary parts of input to like type
97+
static ov::Output<ov::Node> convert_like(const NodeContext& context,
98+
const ov::Output<ov::Node>& input,
99+
const ov::Output<ov::Node>& like);
100+
47101
private:
48102
ov::element::Type m_complex_part_type;
103+
104+
// floating-point tensor that represents complex tensor
105+
ov::Output<ov::Node> m_data;
106+
107+
// real part of the complex tensor in squeezed form (no auxiliary dimension)
108+
ov::Output<ov::Node> m_real;
109+
110+
// imaginary part of the complex tensor in squeezed form (no auxiliary dimension)
111+
ov::Output<ov::Node> m_imag;
49112
};
50113

51114
} // namespace frontend

src/frontends/common/include/openvino/frontend/node_context.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,21 @@ class FRONTEND_API NodeContext {
120120
FRONT_END_NOT_IMPLEMENTED(get_subgraph);
121121
}
122122

123+
/// \brief Returns Node object that can be with updated attributes
124+
/// such node name, runtime info, etc.
125+
/// By default, it returns the same node without update
126+
virtual std::shared_ptr<Node> mark_node(std::shared_ptr<Node> ov_node) const {
127+
return ov_node;
128+
}
129+
130+
/// \brief PyTorch may have None inputs coming to operations
131+
/// Other frontends do not have it per our observation
132+
virtual bool input_is_none(size_t index) const {
133+
auto num_inputs = get_input_size();
134+
FRONT_END_GENERAL_CHECK(index < num_inputs, "Input index is out of allowed indices range");
135+
return false;
136+
}
137+
123138
private:
124139
virtual ov::Any apply_additional_conversion_rules(const ov::Any& data, const std::type_info& type_info) const {
125140
return data;

0 commit comments

Comments
 (0)