5
5
#pragma once
6
6
7
7
#include " openvino/core/type/element_type.hpp"
8
+ #include " openvino/frontend/exception.hpp"
9
+ #include " openvino/frontend/node_context.hpp"
8
10
#include " openvino/frontend/visibility.hpp"
9
11
#include " openvino/op/util/framework_node.hpp"
10
12
@@ -22,11 +24,12 @@ class FRONTEND_API ComplexTypeMark : public ov::op::util::FrameworkNode {
22
24
public:
23
25
OPENVINO_OP (" ComplexTypeMark" , " util" , ov::op::util::FrameworkNode);
24
26
25
- ComplexTypeMark (const ov::Output<ov::Node>& input, const ov::element::Type& complex_part_type)
26
- : ov::op::util::FrameworkNode(ov::OutputVector{input}, 1 ),
27
- m_complex_part_type (complex_part_type) {
28
- validate_and_infer_types ();
29
- }
27
+ ComplexTypeMark (const ov::Output<ov::Node>& input,
28
+ const ov::element::Type& complex_part_type = ov::element::dynamic);
29
+
30
+ ComplexTypeMark (const ov::Output<ov::Node>& real,
31
+ const ov::Output<ov::Node>& imag,
32
+ const ov::element::Type& complex_part_type = ov::element::dynamic);
30
33
31
34
~ComplexTypeMark () override ;
32
35
@@ -44,8 +47,68 @@ class FRONTEND_API ComplexTypeMark : public ov::op::util::FrameworkNode {
44
47
return m_complex_part_type;
45
48
}
46
49
50
+ // Get a real part of the complex tensor
51
+ ov::Output<ov::Node> get_real (bool squeezed = true );
52
+
53
+ // Get an imaginary part of the complex tensor
54
+ ov::Output<ov::Node> get_imag (bool squeezed = true );
55
+
56
+ // Get floating-point representation of the complex tensor
57
+ ov::Output<ov::Node> get_data ();
58
+
59
+ // Compute summation of two operands that can be of complex type
60
+ // if operand is of complex type, complex type will be indicated by bool flag
61
+ // complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
62
+ // types of both operands must be aligned prior to the call
63
+ static ov::Output<ov::Node> add (const NodeContext& context,
64
+ const ov::Output<ov::Node>& lhs,
65
+ const ov::Output<ov::Node>& rhs);
66
+
67
+ // Compute subtraction of two operands that can be of complex type
68
+ // if operand is of complex type, complex type will be indicated by bool flag
69
+ // complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
70
+ // types of both operands must be aligned prior to the call
71
+ static ov::Output<ov::Node> sub (const NodeContext& context,
72
+ const ov::Output<ov::Node>& lhs,
73
+ const ov::Output<ov::Node>& rhs);
74
+
75
+ // Compute multiplication of two operands that can be of complex type
76
+ // if operand is of complex type, complex type will be indicated by bool flag
77
+ // complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
78
+ // types of both operands must be aligned prior to the call
79
+ static ov::Output<ov::Node> mul (const NodeContext& context,
80
+ const ov::Output<ov::Node>& lhs,
81
+ const ov::Output<ov::Node>& rhs);
82
+
83
+ // Compute inverse of operand that can be of complex type
84
+ // if operand is of complex type, complex type will be indicated by bool flag
85
+ // complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
86
+ static ov::Output<ov::Node> inv (const NodeContext& context, const ov::Output<ov::Node>& data);
87
+
88
+ // Compute division of two operands that can be of complex type
89
+ // if operand is of complex type, complex type will be indicated by bool flag
90
+ // complex tensor is represented as a real tensor with auxiliary dimension 2 in the tail
91
+ // types of both operands must be aligned prior to the call
92
+ static ov::Output<ov::Node> div (const NodeContext& context,
93
+ const ov::Output<ov::Node>& lhs,
94
+ const ov::Output<ov::Node>& rhs);
95
+
96
+ // Convert type of real and imaginary parts of input to like type
97
+ static ov::Output<ov::Node> convert_like (const NodeContext& context,
98
+ const ov::Output<ov::Node>& input,
99
+ const ov::Output<ov::Node>& like);
100
+
47
101
private:
48
102
ov::element::Type m_complex_part_type;
103
+
104
+ // floating-point tensor that represents complex tensor
105
+ ov::Output<ov::Node> m_data;
106
+
107
+ // real part of the complex tensor in squeezed form (no auxiliary dimension)
108
+ ov::Output<ov::Node> m_real;
109
+
110
+ // imaginary part of the complex tensor in squeezed form (no auxiliary dimension)
111
+ ov::Output<ov::Node> m_imag;
49
112
};
50
113
51
114
} // namespace frontend
0 commit comments