Commit 8b79292

[PT FE] Support multiple FX operations (#23112)
### Details:
- *Unpack nested tuples at model outputs*
- *Support for `aten.cat.default` with default axis*
- *Support `aten::reciprocal_` and `aten::abs_` for TorchScript case*
- *Support:*
  * `aten.adaptive_max_pool3d.default`
  * `aten.adaptive_max_pool{N}d.default`
  * `aten.amin.default`
  * `aten.argmin.default`
  * `aten.bitwise_{XXX}.default`
  * `aten.clamp_max.default`
  * `aten.clamp_{max/min}.Tensor`
  * `aten.fill.Tensor`
  * `aten.flip.default`
  * `aten.fmod.{Scalar/Tensor}`
  * `aten.ge.{Scalar/Tensor}`
  * `aten.gt.Tensor`
  * `aten.index_select.default`
  * `aten.le.{Scalar/Tensor}`
  * `aten.log{10/1p/2}.default`
  * `aten.{max/maximum/mean/min/minimum}.default`
  * `aten.min.dim`
  * `aten.ne.Tensor`
  * `aten.pow.{Scalar/Tensor_Tensor}`
  * `aten.reciprocal.default`
  * `aten.rsub.Tensor`
  * `aten.scatter.value`
  * `aten.select_scatter.default`
  * `aten.sign.default`
  * `aten.sqrt.default`
  * `aten.sum.default`
  * `aten.unfold.default`
  * `aten.var.correction`
  * *trigonometric ops*
- *Enabled layer tests for added ops for torch.export case*

### Tickets:
- *ticket-id*
1 parent 889d02a commit 8b79292
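
Below is a minimal sketch (not part of this commit) of the torch.export flow these changes target; the model, op choice, and conversion call are illustrative and assume an OpenVINO build where `ov.convert_model` accepts an `ExportedProgram`:

import torch
import openvino as ov

class Model(torch.nn.Module):
    def forward(self, x):
        # torch.flip lowers to aten.flip.default, one of the newly supported ops
        return torch.flip(x, dims=[0])

# torch.export produces an ExportedProgram whose FX graph is walked by
# the TorchFXPythonDecoder changed below
exported = torch.export.export(Model(), (torch.randn(3, 4),))
ov_model = ov.convert_model(exported)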

54 files changed (+590 -256 lines)

.github/workflows/job_python_unit_tests.yml

+1 -1

@@ -231,7 +231,7 @@ jobs:
       - name: PyTorch torch.export Layer Tests
         if: ${{ fromJSON(inputs.affected-components).PyTorch_FE.test && runner.arch != 'ARM64' }} # Ticket: 126287
         run: |
-          python3 -m pytest ${LAYER_TESTS_INSTALL_DIR}/pytorch_tests -m precommit_torch_export --junitxml=${INSTALL_TEST_DIR}/TEST-pytorch.xml
+          python3 -m pytest ${LAYER_TESTS_INSTALL_DIR}/pytorch_tests -n logical -m precommit_torch_export --junitxml=${INSTALL_TEST_DIR}/TEST-pytorch.xml
         env:
           TEST_DEVICE: CPU
           TEST_PRECISION: FP32

src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py

+41 -27

@@ -7,7 +7,7 @@
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
 from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
-from openvino.frontend.pytorch.utils import maybe_convert_max_int, make_constant, fetch_attr, pt_to_ov_type_map
+from openvino.frontend.pytorch.utils import maybe_convert_max_int, make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const

 import torch

@@ -21,11 +21,11 @@ def __init__(self, pt_module, fx_gm, nodes=None, mark_node_callback=None, input_
         self.m_decoders = []
         self.pt_module = pt_module
         self.fx_gm = fx_gm
-        self.input_types = input_types
+        self.input_types = [OVAny(pt_to_ov_type_map[str(t)])
+                            for t in input_types]
         self.input_shapes = input_shapes

         self._input_signature = []
-        self._output_names = []

         if issubclass(type(pt_module), torch.fx.graph_module.GraphModule):

@@ -39,41 +39,36 @@ def __init__(self, pt_module, fx_gm, nodes=None, mark_node_callback=None, input_
                     self._input_signature.append(self._nodes[i].name)
                 elif self._nodes[i].op == 'output':
                     # Instead of putting output index, refer to its target
-                    args = self._nodes[i].args
-                    if isinstance(args[0], tuple):
-                        args = args[0]
-                    if isinstance(args[0], dict):
-                        for name, output in args[0].items():
-                            self._outputs.append(self._nodes.index(output))
-                            self._output_names.append(name)
-                    else:
-                        for output in args:
-                            self._outputs.append(self._nodes.index(output))
+                    uargs = self.unpack_containers(self._nodes[i].args)
+                    self._outputs = [(arg[0], self._nodes.index(arg[1])) for arg in uargs if arg[1] is not None]

         elif issubclass(type(pt_module), torch.fx.Node):

             self._nodes = nodes  # passed from outer context

             # FIXME: Quadratic complexity nodes*nodes considering the outer loop over all nodes
-            for i in range(len(self._nodes)):
-                if self._nodes[i] == pt_module:
-                    self._outputs = [i]
+            self._outputs = [("", self._nodes.index(pt_module))]

             # None in inputs mean the input is inlined or None (also considered inlined)
             self._inputs = [self._nodes.index(
                 arg) if arg in self._nodes else (arg,) for arg in pt_module.args]

             # FIXME: Find a better way to pass nested tuples to OV frontend. This is a temporary solution to flatten arguments.
             new_inputs = []
+            self.input_types = []
             for i in range(len(pt_module.args)):
-                if isinstance(pt_module.args[i], list) and any([isinstance(a, torch.fx.Node) for a in pt_module.args[i]]):
+                if isinstance(pt_module.args[i], (list, tuple)) and any([isinstance(a, torch.fx.Node) for a in pt_module.args[i]]):
                     for arg in pt_module.args[i]:
                         if arg in self._nodes:
                             new_inputs.append(self._nodes.index(arg))
                         else:
                             new_inputs.append((arg,))
+                        self.input_types.append(OVAny(DecoderType.List(
+                            TorchFXPythonDecoder.get_type_for_value(arg))))
                 else:
                     new_inputs.append(self._inputs[i])
+                    self.input_types.append(
+                        TorchFXPythonDecoder.get_type_for_value(self._inputs[i]))
             self._inputs = new_inputs

     def inputs(self):
@@ -83,6 +78,24 @@ def inputs(self):
     def is_input_inlined(self, index):
         return isinstance(self._inputs[index], tuple)

+    @staticmethod
+    def unpack_containers(arg):
+        if isinstance(arg, (tuple, list)):
+            res = []
+            for e in arg:
+                res.extend(TorchFXPythonDecoder.unpack_containers(e))
+            return res
+        elif isinstance(arg, dict):
+            res = []
+            for k, e in arg.items():
+                unpacked = TorchFXPythonDecoder.unpack_containers(e)
+                if len(unpacked) == 1:
+                    unpacked[0] = (k, unpacked[0][1])
+                res.extend(unpacked)
+            return res
+        else:
+            return [("", arg)]
+
     @staticmethod
     def arg_to_constant(arg):
         if isinstance(arg, list):
@@ -91,7 +104,7 @@ def arg_to_constant(arg):
                 arg[0]).__name__], Shape([len(arg)]), arg)
             else:
                 # TODO: which type should we use if list is empty? Need a signaling value here
-                return make_constant(int, Shape([0]), [])
+                return make_constant(OVType.i32, Shape([0]), [])
         elif isinstance(arg, bool):
             return make_constant(OVType.boolean, Shape([]), [arg])
         elif isinstance(arg, int):
@@ -103,10 +116,10 @@
             []), [arg])  # TODO: f32? why not f64?
         return None

-
     def inlined_input(self, index):
         assert index < len(self._inputs), "Requested input doesn't exist"
-        assert isinstance(self._inputs[index], tuple), "Requested input which is not inlined"
+        assert isinstance(
+            self._inputs[index], tuple), "Requested input which is not inlined"
         assert self._inputs[index][0] is not None, "Requested None inlined input"
         constant = None
         arg = self._inputs[index][0]
@@ -144,13 +157,13 @@ def get_input_strides(self, index: int) -> list:

     def get_input_type(self, index):
         if index < len(self.input_types):
-            return OVAny(pt_to_ov_type_map[str(self.input_types[index])])
+            return self.input_types[index]
         input = self._raw_input(index)
         return self.get_type_for_value(input)

     def get_output_debug_name(self, index):
-        if self._output_names is not None and index < len(self._output_names):
-            return self._output_names[index]
+        if self._outputs is not None and index < len(self._outputs) and self._outputs[index][0]:
+            return self._outputs[index][0]
         name = getattr(self.pt_module, "name", "output")
         return name + ":" + str(index)

@@ -168,7 +181,8 @@ def get_shape_for_value(self, value):
             return PartialShape(len(value.meta['tensor_meta'].shape) * [-1])
         return PartialShape.dynamic()

-    def get_type_for_value(self, value):
+    @staticmethod
+    def get_type_for_value(value):
         if issubclass(type(value), torch.fx.Node):
             if ('tensor_meta' in value.meta.keys()):
                 if value.meta['tensor_meta'] and isinstance(value.meta['tensor_meta'], torch.Tensor):
@@ -259,10 +273,10 @@ def get_schema(self):
         return self.pt_module.schema()

     def outputs(self):
-        return self._outputs
+        return [o[1] for o in self._outputs]

     def _raw_outputs(self):
-        return [self._nodes[x] for x in self._outputs]
+        return [self._nodes[x[1]] for x in self._outputs]

     def _raw_output(self, index):
         return self._raw_outputs()[index]
@@ -293,7 +307,7 @@ def as_constant(self):
         if self.pt_module.op == 'get_attr':
             # Extract Constant from FX module field
             ret = fetch_attr(self.fx_gm, self.pt_module.target)
-            ov_const = op.Constant(ret.numpy(force=True), shared_memory=True)
+            ov_const = torch_tensor_to_ov_const(ret, shared_memory=True)
             return ov_const.outputs()

         if not self.get_op_type() == 'prim::Constant':

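The new `unpack_containers` helper flattens arbitrarily nested tuples, lists, and dicts of output nodes into `(name, node)` pairs, keeping a dict key as the output name only when it maps to exactly one node. A hedged illustration of the expected behavior, with plain strings standing in for `torch.fx.Node` objects:

nested = ({"logits": "n1", "aux": ("n2", "n3")},)
TorchFXPythonDecoder.unpack_containers(nested)
# -> [("logits", "n1"), ("", "n2"), ("", "n3")]
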
src/frontends/pytorch/src/op/arange.cpp

+7 -47

@@ -85,72 +85,32 @@ OutputVector translate_arange(const NodeContext& context) {
 OutputVector translate_arange_fx(const NodeContext& context) {
     auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
     auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
-    int dtype_port = -1;
     auto dtype = element::f32;
-    bool dtype_applied = false;
     auto num_inputs = context.get_input_size();
     ov::Output<Node> end;
-    ov::Output<Node> out_tensor;
     ov::Output<Node> start = zero;
     ov::Output<Node> step = one;

     if (num_inputs == 1) {
-        // aten::arange(Scalar end, tensor out)
+        // arange = torch.ops.aten.arange.default(_local_scalar_dense, dtype = torch.int8, device = device(type='cpu'),
+        // pin_memory = False);
         end = context.get_input(0);
-        out_tensor = end;  // context.input_is_none(1) ? end : context.get_input(1);
     } else if (num_inputs == 2) {
-        // aten::arange(Scalar end, tensor out)
         start = context.get_input(0);
         end = context.get_input(1);
-        out_tensor = end;  // context.input_is_none(1) ? end : context.get_input(1);
     } else if (num_inputs == 3) {
-        // aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
         start = context.get_input(0);
         end = context.get_input(1);
         step = context.get_input(2);
-        out_tensor = end;  // context.input_is_none(3) ? end : context.get_input(3);
-    } else if (num_inputs == 5) {
-        // aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
-        end = context.get_input(0);
-        out_tensor = end;
-        dtype_port = 1;
-    } else if (num_inputs == 6) {
-        // aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
-        start = context.get_input(0);
-        end = context.get_input(1);
-        out_tensor = end;
-        dtype_port = 2;
-        dtype_applied = true;
-    } else if (num_inputs == 7) {
-        // aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
-        start = context.get_input(0);
-        end = context.get_input(1);
-        step = context.get_input(2);
-        out_tensor = end;
-        dtype_port = 3;
-        dtype_applied = true;
     } else {
         PYTORCH_OP_CONVERSION_CHECK(false, "Not expected number of inputs for ", context.get_op_type());
     }
-    if (dtype_port >= 0 && !context.input_is_none(dtype_port)) {
-        if (std::dynamic_pointer_cast<v0::Constant>(
-                context.get_input_from_visible_context(dtype_port).get_node_shared_ptr())) {
-            dtype = convert_dtype(context.const_input<int64_t>(dtype_port));
-            dtype_applied = true;
-        } else if (const auto& fw_node =
-                       cast_fw_node(context.get_input(dtype_port).get_node_shared_ptr(), "prim::dtype")) {
-            out_tensor = fw_node->input_value(0);
-            dtype_applied = false;
-        } else {
-            PYTORCH_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
-        }
+    if (context.has_attribute("dtype")) {
+        dtype = context.get_attribute<element::Type>("dtype");
     }
-    auto r_end = context.mark_node(std::make_shared<v0::Convert>(end, dtype));
-    auto r_start = context.mark_node(std::make_shared<v0::Convert>(start, dtype));
-    auto r_step = context.mark_node(std::make_shared<v0::Convert>(step, dtype));
-    auto range = context.mark_node(std::make_shared<v4::Range>(r_start, r_end, r_step, dtype));
-    if (!dtype_applied) {
-        range = context.mark_node(std::make_shared<v1::ConvertLike>(range, out_tensor));
+    auto range = context.mark_node(std::make_shared<v4::Range>(start, end, step, dtype));
+    if (!context.has_attribute("dtype")) {
+        range = context.mark_node(std::make_shared<v1::ConvertLike>(range, context.get_input(0)));
     }
     return {range};
 };

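The rewritten `translate_arange_fx` assumes kwargs such as `dtype` arrive as node attributes rather than extra positional inputs. A hedged sketch of producing such a graph; the model is illustrative, and whether the data-dependent `.item()` exports cleanly depends on the PyTorch version:

import torch

class ArangeModel(torch.nn.Module):
    def forward(self, n):
        # lowers to aten.arange.default(_local_scalar_dense, dtype = torch.int8, ...)
        return torch.arange(n.item(), dtype=torch.int8)

ep = torch.export.export(ArangeModel(), (torch.tensor(10),))
print(ep.graph)  # dtype shows up as a kwarg, surfaced to the frontend as an attribute
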
src/frontends/pytorch/src/op/argmax_argmin.cpp

+1 -1

@@ -21,7 +21,7 @@ using namespace ov::op;

 namespace {
 OutputVector create_argmax_argmin_op(const NodeContext& context, TopKMode mode) {
-    num_inputs_check(context, 2, 3);
+    num_inputs_check(context, 1, 3);
     auto input = context.get_input(0);
     bool keep_dims = false;
     auto k = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));

src/frontends/pytorch/src/op/avg_poolnd.cpp

+1 -1

@@ -19,7 +19,7 @@ namespace op {
 using namespace ov::op;

 OutputVector translate_avg_poolnd(const NodeContext& context) {
-    num_inputs_check(context, 3, 7);
+    num_inputs_check(context, 2, 7);
     auto input = context.get_input(0);
     auto kernel = context.const_input<Shape>(1);
     Strides strides;

src/frontends/pytorch/src/op/cat.cpp

+8 -2

@@ -74,12 +74,18 @@ OutputVector translate_cat(const NodeContext& context) {

 OutputVector translate_cat_fx(const NodeContext& context) {
     // This translator is only needed to get axis as constant from external scope
-    num_inputs_check(context, 2, context.get_input_size());
+    num_inputs_check(context, 1, context.get_input_size());
     std::deque<Output<Node>> list_elems;
     for (size_t i = 0; i < context.get_input_size() - 1; i++) {
         list_elems.push_back(context.get_input(static_cast<int>(i)));
     }
-    auto axis = context.const_input<int64_t>(context.get_input_size() - 1);
+    int64_t axis = 0;
+    if (!context.get_input_type(context.get_input_size() - 1).is<type::List>()) {
+        // axis can be not present and that means that last input will have List type
+        axis = context.const_input<int64_t>(context.get_input_size() - 1);
+    } else {
+        list_elems.push_back(context.get_input(static_cast<int>(context.get_input_size() - 1)));
+    }
     return translate_cat_common(context, list_elems, axis, true);
 };

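For the default-axis case handled above, a hedged sketch (the model is illustrative): calling torch.cat without `dim` lowers to `aten.cat.default` whose only input is the tensor list, so the translator now falls back to axis 0:

import torch

class CatModel(torch.nn.Module):
    def forward(self, x, y):
        # no dim argument -> aten.cat.default([x, y]); the translator defaults axis to 0
        return torch.cat([x, y])

ep = torch.export.export(CatModel(), (torch.randn(2, 3), torch.randn(2, 3)))
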
src/frontends/pytorch/src/op/cumsum.cpp

+14

@@ -3,6 +3,7 @@
 //

 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/cum_sum.hpp"
 #include "utils.hpp"

@@ -28,6 +29,19 @@ OutputVector translate_cumsum(const NodeContext& context) {
     return {result};
 };

+OutputVector translate_cumsum_fx(const NodeContext& context) {
+    // cumsum = torch.ops.aten.cumsum.default(arg0_1, 0, dtype = torch.float64)
+    num_inputs_check(context, 2, 2);
+    auto x = context.get_input(0);
+    auto dim = context.get_input(1);
+    if (context.has_attribute("dtype")) {
+        auto dtype = context.get_attribute<element::Type>("dtype");
+        x = context.mark_node(std::make_shared<v0::Convert>(x, dtype));
+    }
+    auto result = context.mark_node(std::make_shared<v0::CumSum>(x, dim));
+    return {result};
+};
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend

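A rough Python equivalent of the new `translate_cumsum_fx` logic, written against the OpenVINO Python opset purely for illustration (the function name and structure are assumptions, not frontend code):

import openvino.runtime.opset11 as ops

def cumsum_with_optional_dtype(x, dim, dtype=None):
    # mirrors the has_attribute("dtype") / get_attribute branch in the C++ translator
    if dtype is not None:
        x = ops.convert(x, dtype)  # convert the input before accumulating
    return ops.cum_sum(x, dim)
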
src/frontends/pytorch/src/op/mean.cpp

+13 -4

@@ -3,6 +3,7 @@
 //

 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/reduce_mean.hpp"
 #include "utils.hpp"

@@ -11,6 +12,8 @@ namespace frontend {
 namespace pytorch {
 namespace op {

+using namespace ov::op;
+
 OutputVector translate_mean(const NodeContext& context) {
     num_inputs_check(context, 2, 5);
     auto x = context.get_input(0);
@@ -35,20 +38,26 @@ OutputVector translate_mean(const NodeContext& context) {
             x = apply_dtype(context, 3, x);
         }
     }
-    auto mean = context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(x, axes, keep_dims));
+    auto mean = context.mark_node(std::make_shared<v1::ReduceMean>(x, axes, keep_dims));
     if (num_inputs == 5 && !context.input_is_none(4)) {
         context.mutate_input(4, mean);
     }
     return {mean};
 };

 OutputVector translate_mean_fx(const NodeContext& context) {
-    num_inputs_check(context, 2, 5);
+    num_inputs_check(context, 1, 5);
     auto x = context.get_input(0);
     auto num_inputs = context.get_input_size();
     bool keep_dims = false;
+    if (context.has_attribute("dtype")) {
+        auto dtype = context.get_attribute<element::Type>("dtype");
+        x = context.mark_node(std::make_shared<v0::Convert>(x, dtype));
+    }
     Output<Node> axes;
-    if (num_inputs == 2) {
+    if (num_inputs == 1) {
+        axes = get_node_axes_range(context, x);
+    } else if (num_inputs == 2) {
         axes = context.get_input(1);
     } else {
         axes = context.get_input(1);
@@ -59,7 +68,7 @@ OutputVector translate_mean_fx(const NodeContext& context) {
             x = apply_dtype(context, 3, x);
         }
     }
-    auto mean = context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(x, axes, keep_dims));
+    auto mean = context.mark_node(std::make_shared<v1::ReduceMean>(x, axes, keep_dims));
     if (num_inputs == 5 && !context.input_is_none(4)) {
         context.mutate_input(4, mean);
     }

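Exercising the new single-input path: `torch.mean` with no `dim` lowers to a one-input `aten.mean.default`, which the translator now reduces over every axis via `get_node_axes_range`. A hedged sketch with an illustrative model:

import torch

class MeanModel(torch.nn.Module):
    def forward(self, x):
        # full reduction -> single-input aten.mean.default
        return torch.mean(x)

ep = torch.export.export(MeanModel(), (torch.randn(2, 3),))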