Skip to content

Commit d018779

Browse files
mvafin and eaidova
authored
[PT FE] torch.export support (openvinotoolkit#22397)
* [PT FE] torch.export support * Apply code style * Fix build * Support torch <2.2 * Support more operations * Update tests/model_hub_tests/torch_tests/torch_utils.py * Support for model operations * Support is_causal as kwarg for SDPA * Update src/frontends/pytorch/src/op/addcmul.cpp * Update tests/model_hub_tests/torch_tests/test_timm.py * Support only decoder passed to convert_model * Fix tests * Update src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py * Update src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com> * Apply suggestions from code review * Apply suggestions from code review * Improve testing with caching model * Apply suggestions from code review --------- Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com>
1 parent 688279a commit d018779

File tree

86 files changed

+1282
-283
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

86 files changed

+1282
-283
lines changed

.github/workflows/job_python_unit_tests.yml

+9
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,15 @@ jobs:
216216
TEST_DEVICE: CPU
217217
TEST_PRECISION: FP32
218218

219+
- name: PyTorch torch.export Layer Tests
220+
if: ${{ fromJSON(inputs.affected-components).PyTorch_FE.test && runner.arch != 'ARM64' }} # Ticket: 126287
221+
run: |
222+
python3 -m pytest ${LAYER_TESTS_INSTALL_DIR}/pytorch_tests -m precommit_torch_export --junitxml=${INSTALL_TEST_DIR}/TEST-pytorch.xml
223+
env:
224+
TEST_DEVICE: CPU
225+
TEST_PRECISION: FP32
226+
PYTORCH_TRACING_MODE: EXPORT
227+
219228
- name: PyTorch torch.compile TORCHFX Layer Tests
220229
if: ${{ fromJSON(inputs.affected-components).PyTorch_FE.test && runner.os != 'macOS' }}
221230
run: |

src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py

+149-70
Large diffs are not rendered by default.

src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py

+12
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,18 @@ def may_produce_alias(self, in_index: int, out_index: int) -> bool:
406406
def inlined_inputs(self, index):
407407
return []
408408

409+
def inlined_input(self, index):
410+
return []
411+
412+
def is_input_inlined(self, index):
413+
return False
414+
415+
def get_attribute(self, name):
416+
return OVAny(None)
417+
418+
def get_named_input(self, name):
419+
raise RuntimeError("There is no named inputs in TS graph")
420+
409421
@staticmethod
410422
def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph):
411423
# Function replaces prim::Constant containing List of Tensors with

src/bindings/python/src/openvino/frontend/pytorch/utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,13 @@ def graph_has_ops(graph, op_types:list) -> bool:
137137
"torch.bool": OVType.boolean,
138138
"torch.DoubleTensor": OVType.f64,
139139
"torch.FloatTensor": OVType.f32,
140+
"torch.HalfTensor": OVType.f16,
141+
"torch.BFloat16Tensor": OVType.bf16,
140142
"torch.IntTensor": OVType.i32,
141143
"torch.LongTensor": OVType.i64,
144+
"torch.ShortTensor": OVType.i16,
145+
"torch.CharTensor": OVType.i8,
146+
"torch.ByteTensor": OVType.u8,
142147
"torch.BoolTensor": OVType.boolean,
143148
"torch.quint8": OVType.u8,
144149
"torch.qint8": OVType.i8,

src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,21 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
110110
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, may_produce_alias, in_index, out_index);
111111
}
112112

113-
ov::OutputVector inlined_inputs(size_t start_index) const override {
114-
PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, inlined_inputs, start_index); }
113+
ov::OutputVector inlined_input(size_t index) const override {
114+
PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, inlined_input, index);
115+
}
116+
117+
bool is_input_inlined(size_t index) const override {
118+
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, is_input_inlined, index);
119+
}
120+
121+
ov::Any get_attribute(const std::string &name) const override{
122+
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_attribute, name);
123+
}
124+
125+
size_t get_named_input(const std::string &name) const override{
126+
PYBIND11_OVERRIDE_PURE(size_t, TorchDecoder, get_named_input, name);
127+
}
115128

116129
const std::string& decoder_type_name() const override {
117130
PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, decoder_type_name);

src/bindings/python/src/pyopenvino/utils/utils.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ ov::Any py_object_to_any(const py::object& py_obj) {
375375
return py::cast<ov::Affinity>(py_obj);
376376
} else if (py::isinstance<ov::Tensor>(py_obj)) {
377377
return py::cast<ov::Tensor>(py_obj);
378+
} else if (py::isinstance<ov::Output<ov::Node>>(py_obj)) {
379+
return py::cast<ov::Output<ov::Node>>(py_obj);
378380
// FrontEnd Decoder
379381
} else if (py::isinstance<ov::frontend::IDecoder>(py_obj)) {
380382
return py::cast<std::shared_ptr<ov::frontend::IDecoder>>(py_obj);

src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TorchDecoder : public IDecoder {
1919
// fundamental types like int, float etc.
2020
virtual Any const_input(size_t index) const = 0;
2121

22-
// Using size_t for input/output unuque ids are in sync with torch code, see def in
22+
// Using size_t for input/output unique ids are in sync with torch code, see def in
2323
// torch/include/torch/csrc/jit/ir/ir.h, Value::unique_
2424

2525
// TODO: set of input and output methods are not aligned; also they are not aligned with the rest of FEs
@@ -89,7 +89,7 @@ class TorchDecoder : public IDecoder {
8989
virtual size_t output(size_t index) const = 0;
9090

9191
// Embed mapping to/from the original node representation from/to node passed as a parameter
92-
// the representation of this mapping is specific for particular decored type and may be NOP
92+
// the representation of this mapping is specific for particular decorated type and may be NOP
9393
// returns the same node as syntactically convenient way to make nested sentences in code
9494
virtual std::shared_ptr<Node> mark_node(std::shared_ptr<Node> ov_node) const = 0;
9595

@@ -109,9 +109,19 @@ class TorchDecoder : public IDecoder {
109109

110110
/// Returns new nodes for inputs inlined in the op itself
111111
// Used in Torch.FX decoder
112-
virtual OutputVector inlined_inputs(size_t start_index) const = 0;
112+
virtual OutputVector inlined_input(size_t index) const = 0;
113113

114-
/// Returns the id of the deccoder type (0: TorchFX, 1: TorchScript)
114+
/// Returns if input is inlined
115+
// Used in Torch.FX decoder
116+
virtual bool is_input_inlined(size_t index) const = 0;
117+
118+
/// Returns named attribute as Any. For example kwargs input for FX graph
119+
virtual ov::Any get_attribute(const std::string& name) const = 0;
120+
121+
/// Returns index of named input. For example kwargs input for FX graph
122+
virtual size_t get_named_input(const std::string& name) const = 0;
123+
124+
/// Returns the id of the decoder type ("fx": TorchFX, "ts": TorchScript)
115125
virtual const std::string& decoder_type_name() const = 0;
116126
};
117127

src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp

+27-4
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,36 @@ class NodeContext : public frontend::NodeContext {
5151
// Search for input in tensor map and return an output port for already converted op
5252
// TODO: int due to base class uses it, but naturally it should be size_t for PT
5353
Output<Node> get_input(int index) const override {
54-
FRONT_END_GENERAL_CHECK(!input_is_none(index), "Input is none with index: ", index);
54+
size_t index_ = static_cast<size_t>(index);
55+
FRONT_END_GENERAL_CHECK(!input_is_none(index_), "Input doesn't exist with index: ", index);
5556
auto input = m_decoder_inputs.at(index);
57+
if (input == 0) {
58+
// Case when input can be inlined (possible only for fx decoder)
59+
if (m_decoder->is_input_inlined(index_)) {
60+
auto inlined_input = m_decoder->inlined_input(index_);
61+
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1, "Incorrect inlined input with index:", index);
62+
return inlined_input[0];
63+
}
64+
}
5665
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
5766
return m_tensor_map->at(input);
5867
}
5968

69+
Output<Node> get_input(const std::string& name) const override {
70+
FRONT_END_GENERAL_CHECK(has_attribute(name), "Input with name ", name, " doesn't exist");
71+
auto attr = get_attribute_as_any(name);
72+
if (attr.is<Output<Node>>()) {
73+
// Case when input is constant value
74+
return attr.as<Output<Node>>();
75+
} else if (attr.is<type::PyNone>()) {
76+
// None means input is unknown type, most likely a Node
77+
auto input = m_decoder->get_named_input(name);
78+
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
79+
return m_tensor_map->at(input);
80+
}
81+
FRONT_END_GENERAL_CHECK(false, "Input has type which can't be converted to ov::Node.");
82+
}
83+
6084
Any get_values_from_const_input(int index) const override;
6185

6286
// TODO: upstream to base class
@@ -112,9 +136,8 @@ class NodeContext : public frontend::NodeContext {
112136
return ov_output;
113137
}
114138

115-
Any get_attribute_as_any(const std::string&) const override {
116-
throw std::runtime_error(
117-
"There is no any named attributes in PyTorch node, query by attribute name is not implemented");
139+
Any get_attribute_as_any(const std::string& name) const override {
140+
return m_decoder->get_attribute(name);
118141
}
119142

120143
void mutate_input(size_t index, Output<Node> ov_output) const;

src/frontends/pytorch/src/frontend.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
194194
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
195195
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
196196
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReplacer>();
197+
// TODO: remove AtenIndexToSelect when problem with dynamic input rank is gone.
197198
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
198199
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexPutReplacer>();
199200
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();

src/frontends/pytorch/src/node_context.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ Output<Node> NodeContext::get_tensor_from_model_or_create_input(size_t index) co
109109
}
110110

111111
Output<Node> NodeContext::get_input_from_visible_context(size_t index) const {
112-
FRONT_END_GENERAL_CHECK(index < get_input_size(), "Index is lower then number of inputs.");
112+
FRONT_END_GENERAL_CHECK(index < get_input_size(), "Index ", index, " is lower then number of inputs.");
113113
auto input_tensor = get_input(static_cast<int>(index));
114114
auto input_node = input_tensor.get_node_shared_ptr();
115115
if (std::dynamic_pointer_cast<v0::Parameter>(input_node)) {

src/frontends/pytorch/src/op/adaptive_poolnd.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,21 @@ OutputVector translate_adaptive_max_pool1d(const NodeContext& context) {
117117
return translate_adaptive_max_pool_base(context, const_tile_params, const_neg_1);
118118
};
119119

120+
OutputVector translate_adaptive_max_pool3d_fx(const NodeContext& context) {
121+
auto outs = translate_adaptive_max_pool3d(context);
122+
return {context.mark_node(make_list_construct(outs))};
123+
};
124+
125+
OutputVector translate_adaptive_max_pool2d_fx(const NodeContext& context) {
126+
auto outs = translate_adaptive_max_pool2d(context);
127+
return {context.mark_node(make_list_construct(outs))};
128+
};
129+
130+
OutputVector translate_adaptive_max_pool1d_fx(const NodeContext& context) {
131+
auto outs = translate_adaptive_max_pool1d(context);
132+
return {context.mark_node(make_list_construct(outs))};
133+
};
134+
120135
} // namespace op
121136
} // namespace pytorch
122137
} // namespace frontend

src/frontends/pytorch/src/op/add.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,14 @@ OutputVector translate_add_common(const NodeContext& context, bool inplace) {
3434
} else {
3535
align_eltwise_input_types(context, lhs, rhs, true);
3636
}
37+
Output<Node> alpha;
3738
if (!context.input_is_none(2)) {
38-
auto converted_alpha = context.mark_node(std::make_shared<v1::ConvertLike>(context.get_input(2), rhs));
39+
alpha = context.get_input(2);
40+
} else if (context.has_attribute("alpha")) {
41+
alpha = context.get_attribute<Output<Node>>("alpha");
42+
}
43+
if (alpha.get_node_shared_ptr()) {
44+
auto converted_alpha = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, rhs));
3945
rhs = context.mark_node(std::make_shared<v1::Multiply>(converted_alpha, rhs));
4046
}
4147
auto add = context.mark_node(std::make_shared<v1::Add>(lhs, rhs));

src/frontends/pytorch/src/op/addcmul.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,30 @@ namespace op {
1717

1818
using namespace ov::op;
1919

20-
OutputVector translate_addcmul(const NodeContext& context) {
21-
num_inputs_check(context, 4, 4);
20+
namespace {
21+
OutputVector addcmul_common(const NodeContext& context, const Output<Node>& value) {
2222
const auto eltwise_mult = std::make_shared<v1::Multiply>(context.get_input(1), context.get_input(2));
23-
const auto value = context.get_input(3);
2423
const auto converted_value = std::make_shared<v1::ConvertLike>(value, context.get_input(1));
2524
const auto scalar_mult = std::make_shared<v1::Multiply>(eltwise_mult, converted_value);
2625
context.mark_nodes({eltwise_mult, converted_value, scalar_mult});
2726
return {context.mark_node(std::make_shared<v1::Add>(context.get_input(0), scalar_mult))};
2827
};
28+
} // namespace
29+
30+
OutputVector translate_addcmul(const NodeContext& context) {
31+
num_inputs_check(context, 4, 4);
32+
const auto value = context.get_input(3);
33+
return addcmul_common(context, value);
34+
};
35+
36+
OutputVector translate_addcmul_fx(const NodeContext& context) {
37+
num_inputs_check(context, 3, 3);
38+
Output<Node> value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
39+
if (context.has_attribute("value")) {
40+
value = context.get_input("value");
41+
}
42+
return addcmul_common(context, value);
43+
};
2944

3045
} // namespace op
3146
} // namespace pytorch

src/frontends/pytorch/src/op/addmm.cpp

+27-7
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,22 @@ namespace op {
1616

1717
using namespace ov::op;
1818

19-
OutputVector translate_addmm(const NodeContext& context) {
20-
num_inputs_check(context, 3, 5);
19+
namespace {
20+
OutputVector translate_addmm_common(const NodeContext& context, const Output<Node> beta, const Output<Node> alpha) {
2121
auto input = context.get_input(0);
2222
auto m1 = context.get_input(1);
2323
auto m2 = context.get_input(2);
2424
auto mm = context.mark_node(std::make_shared<v0::MatMul>(m1, m2));
25+
auto beta_converted = context.mark_node(std::make_shared<v1::ConvertLike>(beta, input));
26+
auto alpha_converted = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, mm));
27+
auto input_beta = context.mark_node(std::make_shared<v1::Multiply>(input, beta_converted));
28+
auto mm_alpha = context.mark_node(std::make_shared<v1::Multiply>(mm, alpha_converted));
29+
return {context.mark_node(std::make_shared<v1::Add>(input_beta, mm_alpha))};
30+
};
31+
} // namespace
32+
33+
OutputVector translate_addmm(const NodeContext& context) {
34+
num_inputs_check(context, 3, 5);
2535
auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
2636
ov::Output<Node> alpha = one;
2737
ov::Output<Node> beta = one;
@@ -31,11 +41,21 @@ OutputVector translate_addmm(const NodeContext& context) {
3141
if (!context.input_is_none(4)) {
3242
alpha = context.get_input(4);
3343
}
34-
auto beta_converted = context.mark_node(std::make_shared<v1::ConvertLike>(beta, input));
35-
auto alpha_converted = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, mm));
36-
auto input_beta = context.mark_node(std::make_shared<v1::Multiply>(input, beta_converted));
37-
auto mm_alpha = context.mark_node(std::make_shared<v1::Multiply>(mm, alpha_converted));
38-
return {context.mark_node(std::make_shared<v1::Add>(input_beta, mm_alpha))};
44+
return {translate_addmm_common(context, beta, alpha)};
45+
};
46+
47+
OutputVector translate_addmm_fx(const NodeContext& context) {
48+
num_inputs_check(context, 3, 3);
49+
auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
50+
ov::Output<Node> alpha = one;
51+
ov::Output<Node> beta = one;
52+
if (context.has_attribute("beta")) {
53+
beta = context.get_input("beta");
54+
}
55+
if (context.has_attribute("alpha")) {
56+
alpha = context.get_input("alpha");
57+
}
58+
return {translate_addmm_common(context, beta, alpha)};
3959
};
4060

4161
} // namespace op

0 commit comments

Comments
 (0)