Skip to content

Commit 99715bf

Browse files
authored
[PT FE] Add more operators for FX graph (openvinotoolkit#23698)
### Details: - *item1* - *...* ### Tickets: - *ticket-id*
1 parent 0569f5c commit 99715bf

File tree

8 files changed

+21
-24
lines changed

8 files changed

+21
-24
lines changed

src/frontends/pytorch/src/op/full.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ OutputVector translate_full_like_fx(const NodeContext& context) {
107107
if (context.has_attribute("dtype")) {
108108
auto dtype = context.get_attribute<element::Type>("dtype");
109109
filled_tensor = context.mark_node(std::make_shared<v0::Convert>(filled_tensor, dtype));
110+
} else {
111+
filled_tensor = context.mark_node(std::make_shared<v1::ConvertLike>(filled_tensor, input));
110112
}
111113
return {filled_tensor};
112114
};
@@ -200,6 +202,8 @@ OutputVector translate_zeros_like_fx(const NodeContext& context) {
200202
if (context.has_attribute("dtype")) {
201203
auto dtype = context.get_attribute<element::Type>("dtype");
202204
filled_tensor = context.mark_node(std::make_shared<v0::Convert>(filled_tensor, dtype));
205+
} else {
206+
filled_tensor = context.mark_node(std::make_shared<v1::ConvertLike>(filled_tensor, input));
203207
}
204208
return {filled_tensor};
205209
};
@@ -280,6 +284,8 @@ OutputVector translate_ones_like_fx(const NodeContext& context) {
280284
if (context.has_attribute("dtype")) {
281285
auto dtype = context.get_attribute<element::Type>("dtype");
282286
filled_tensor = context.mark_node(std::make_shared<v0::Convert>(filled_tensor, dtype));
287+
} else {
288+
filled_tensor = context.mark_node(std::make_shared<v1::ConvertLike>(filled_tensor, input));
283289
}
284290
return {filled_tensor};
285291
};

src/frontends/pytorch/src/op_table.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
829829
{"aten.hardtanh.default", op::translate_hardtanh},
830830
{"aten.hardtanh_.default", op::inplace_op<op::translate_hardtanh>},
831831
{"aten.index.Tensor", op::translate_index_fx},
832+
// aten.index_put.default - Supported in limited set of patterns
832833
{"aten.index_select.default", op::translate_index_select},
833834
{"aten.isfinite.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::IsFinite>>},
834835
{"aten.isinf.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::IsInf>>},
@@ -844,6 +845,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
844845
{"aten.log10.default", op::translate_log10},
845846
{"aten.log1p.default", op::translate_log1p},
846847
{"aten.log2.default", op::translate_log2},
848+
{"aten.logical_and.default", op::translate_and},
847849
{"aten.logsumexp.default", op::translate_logsumexp},
848850
{"aten.lt.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
849851
{"aten.lt.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
@@ -898,6 +900,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
898900
{"aten.scalar_tensor.default", op::translate_scalar_tensor_fx},
899901
{"aten.scatter.src", op::translate_scatter},
900902
{"aten.scatter.value", op::translate_scatter},
903+
{"aten.scatter_add.default", op::translate_scatter_add},
901904
{"aten.select.int", op::translate_select},
902905
{"aten.select_scatter.default", op::translate_select_scatter_fx},
903906
{"aten.sigmoid.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},

src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
4949
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
5050
auto index_op = cast_fw_node(m.get_match_root(), "aten::index_put_");
5151
if (!index_op) {
52-
return false;
52+
index_op = cast_fw_node(m.get_match_root(), "aten.index_put.default");
53+
if (!index_op) {
54+
return false;
55+
}
5356
}
5457
NodeVector rt_copy_from;
5558
ov::pass::NodeRegistry rg;

tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,12 @@ def numpy_to_torch_recursively(x):
9595
if self.use_torch_export():
9696
from openvino import convert_model
9797
from torch.export import export
98-
from torch.fx.experimental.proxy_tensor import make_fx
9998

10099
em = export(model, tuple(torch_inputs))
101100
if version.parse(torch.__version__) >= version.parse("2.3"):
102101
em = em.run_decompositions()
103-
print(em.graph_module.code)
104-
105-
try:
106-
gm = make_fx(em)(*torch_inputs)
107-
except:
108-
gm = make_fx(em, tracing_mode='symbolic')(*torch_inputs)
102+
gm = em.module()
103+
print(gm.code)
109104

110105
input_shapes = []
111106
input_types = []

tests/layer_tests/pytorch_tests/test_scatter.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66
import torch
7-
from pytorch_layer_test_class import PytorchLayerTest
7+
from pytorch_layer_test_class import PytorchLayerTest, skip_if_export
88

99

1010
class TestScatter(PytorchLayerTest):
@@ -256,6 +256,7 @@ def forward(self, x: torch.Tensor):
256256

257257
@pytest.mark.nightly
258258
@pytest.mark.precommit
259+
@pytest.mark.precommit_torch_export
259260
@pytest.mark.parametrize("dim", [1, -1, 0])
260261
@pytest.mark.parametrize(
261262
"index",
@@ -267,7 +268,7 @@ def forward(self, x: torch.Tensor):
267268
)
268269
@pytest.mark.parametrize("src", [torch.arange(1, 26).reshape(5, 5)])
269270
@pytest.mark.parametrize("dtype", ["int32", "int64", "float32", "float64"])
270-
@pytest.mark.parametrize("inplace", [True, False])
271+
@pytest.mark.parametrize("inplace", [skip_if_export(True), False])
271272
def test_scatter_add(self, dim, index, src, dtype, inplace, ie_device, precision, ir_version):
272273
if isinstance(src, torch.Tensor):
273274
src = src.to(getattr(torch, dtype))

tests/model_hub_tests/pytorch/hf_transformers_models

+1-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ liamcripwell/pgdyn-plan,context-roberta,skip,Load problem
219219
linhdo/graphdoc,graphdoc,skip,Load problem
220220
LinkSoul/Chinese-LLaVA-Baichuan,llava,skip,Load problem
221221
LinkSoul/LLaSM-Cllama2,llaaa,skip,Load problem
222-
lintang/t5-v2-xl-flan,umt5
222+
lintang/pile-t5-base-flan,umt5
223223
liuhaotian/LLaVA-Lightning-MPT-7B-preview,llava_mpt,skip,Load problem
224224
liya0121/my_finetune_0121,progen,skip,Load problem
225225
lucadiliello/BLEURT-20,bleurt,skip,Load problem

tests/model_hub_tests/pytorch/torch_utils.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def prepare_inputs(self, inputs_info):
6464

6565
def convert_model_impl(self, model_obj):
6666
if hasattr(self, "mode") and self.mode == "export":
67-
from torch.fx.experimental.proxy_tensor import make_fx, get_isolated_graphmodule
6867
from torch.export import export
6968
from packaging import version
7069
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
@@ -90,17 +89,7 @@ def convert_model_impl(self, model_obj):
9089
if version.parse(torch.__version__) >= version.parse("2.2"):
9190
graph = graph.run_decompositions()
9291

93-
if isinstance(self.example, dict):
94-
try:
95-
gm = get_isolated_graphmodule(graph, tuple(), self.example)
96-
except:
97-
gm = get_isolated_graphmodule(graph, tuple(), self.example, tracing_mode='symbolic')
98-
else:
99-
try:
100-
gm = make_fx(graph)(*self.example)
101-
except:
102-
gm = make_fx(graph, tracing_mode='symbolic')(*self.example)
103-
92+
gm = graph.module()
10493
print(gm.code)
10594

10695
decoder = TorchFXPythonDecoder(gm, gm, input_shapes=input_shapes, input_types=input_types)

tests/model_hub_tests/pytorch/torchbench_models

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ pyhpc_equation_of_state,None,xfail,Accuracy validation failed
6666
pyhpc_isoneutral_mixing,None,xfail,Accuracy validation failed
6767
pyhpc_turbulent_kinetic_energy,None,xfail,Unsupported op aten::empty_like
6868
pytorch_CycleGAN_and_pix2pix,None
69-
pytorch_stargan,None,xfail,CPU plugin error: Unsupported operation of type: BatchNormInference
69+
pytorch_stargan,None
7070
pytorch_unet,None
7171
#resnet152,None - Already tested by torchvision tests
7272
#resnet18,None - Already tested by torchvision tests

0 commit comments

Comments
 (0)