
Commit 20ad7cb

Authored by cavusmustafa, ynimmaga, rkazants, and mvafin

Executorch initial support (#28425)
### Details:
- OV side changes for initial ExecuTorch OV backend

### Tickets:
- [CVS-157257](https://jira.devtools.intel.com/browse/CVS-157257)

---------

Co-authored-by: ynimmaga <yamini.nimmagadda@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
1 parent: 5755945 · commit: 20ad7cb

File tree: 11 files changed, +292 -20 lines

src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py (+35 -20)

@@ -178,7 +178,7 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
         self._input_signature = []
         self._example_input = None
 
-        if issubclass(type(pt_module), torch.fx.graph_module.GraphModule):
+        if isinstance(pt_module, torch.fx.graph_module.GraphModule):
             self._input_is_list = None
             self._nodes = list(pt_module.graph.nodes)
             found_types = []
@@ -187,38 +187,34 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
                 if value.op == 'placeholder':
                     self._inputs.append(i)
                     self._input_signature.append(value.name)
-                    if hasattr(value, "meta") and ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']:
-                        found_shapes.append(value.meta['tensor_meta'].shape)
-                        found_types.append(
-                            OVAny(pt_to_ov_type_map[str(value.meta['tensor_meta'].dtype)]))
-                    else:
-                        found_shapes.append(None)
-                        found_types.append(None)
+
+                    found_shapes.append(self.get_found_shape(value))
+                    found_types.append(self.get_found_dtype(value))
+                    if found_shapes[-1] is not None:
+                        new_shape = []
+                        for dim in found_shapes[-1]:
+                            if (dynamic_shapes or type(dim).__name__ == "SymInt"):
+                                new_shape.append(-1)
+                            else:
+                                new_shape.append(dim)
+                        found_shapes[-1] = torch.Size(new_shape)
+
                 elif value.op == 'output':
                     # Instead of putting output index, refer to its target
                     uargs = self.unpack_containers(value.args)
                     self._outputs = [(arg[0], self._nodes.index(arg[1]))
                                      for arg in uargs if arg[1] is not None]
-            for idx, shape in enumerate(found_shapes):
-                if shape is not None:
-                    new_shape = []
-                    for dim in shape:
-                        if (dynamic_shapes or type(dim).__name__ == "SymInt"):
-                            new_shape.append(-1)
-                        else:
-                            new_shape.append(dim)
-                    found_shapes[idx] = torch.Size(new_shape)
 
             if not input_shapes or len(input_shapes) == 0:
                 self.input_shapes = found_shapes
             if not input_types or len(input_types) == 0:
                 self.input_types = found_types
 
-            if hasattr(pt_module, "forward"):
-                input_params = inspect.signature(pt_module.forward).parameters
+            if hasattr(self.pt_module, "forward"):
+                input_params = inspect.signature(self.pt_module.forward).parameters
                 self._input_signature = list(input_params)
 
-        elif issubclass(type(pt_module), torch.fx.Node):
+        elif isinstance(pt_module, torch.fx.Node):
             self._nodes = nodes  # passed from outer context
 
             # FIXME: Quadratic complexity nodes*nodes considering the outer loop over all nodes
@@ -234,6 +230,23 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
                 self.input_types.append(
                     BaseFXDecoder.get_type_for_value(arg))
 
+    @staticmethod
+    def get_found_shape(value) -> str:
+        # If input is a tensor, read the shape from meta data
+        if hasattr(value, "meta"):
+            if ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']:
+                return value.meta['tensor_meta'].shape
+            if ('val' in value.meta.keys()) and isinstance(value.meta["val"], torch.Tensor):
+                return value.meta['val'].shape
+        return None
+
+    @staticmethod
+    def get_found_dtype(value) -> str:
+        # If input is a tensor, read the data type from meta data
+        if hasattr(value, "meta") and ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']:
+            return OVAny(pt_to_ov_type_map[str(value.meta['tensor_meta'].dtype)])
+        return None
+
     def get_input_signature_name(self, index: int) -> str:
         if self._input_signature is not None and index < len(self._input_signature):
             return self._input_signature[index]
@@ -331,6 +344,8 @@ def get_subgraph_decoder(self, index):
 
     def get_op_type(self):
         if self.pt_module.op == 'call_function':
+            if type(self.pt_module.target).__name__ == "EdgeOpOverload":
+                return self.pt_module.target.__name__
             return str(self.pt_module.target)
         elif self.pt_module.op == 'get_attr':
             return 'get_attr'  # FIXME should be aligned with get_attr from TS implementation
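
Note on the refactor above: the metadata lookup that used to be inlined in __init__ now lives in the get_found_shape/get_found_dtype helpers, which add a fallback to node.meta['val'] for torch.export-style graphs, and get_op_type learns to unwrap ExecuTorch edge-dialect targets (EdgeOpOverload) via their __name__. Below is a minimal standalone sketch, not part of the commit, of where this metadata lives on an FX graph; it assumes only stock PyTorch, and M/gm/tm are illustrative names.

import torch
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

gm = torch.fx.symbolic_trace(M())
# ShapeProp fills node.meta['tensor_meta'] with shape/dtype records;
# torch.export-produced graphs instead carry a FakeTensor in
# node.meta['val'], hence the fallback in get_found_shape.
ShapeProp(gm).propagate(torch.randn(2, 3))

for node in gm.graph.nodes:
    if node.op == 'placeholder':
        tm = node.meta.get('tensor_meta')
        if tm:
            print(node.name, tm.shape, tm.dtype)  # e.g. x torch.Size([2, 3]) torch.float32
        elif isinstance(node.meta.get('val'), torch.Tensor):
            print(node.name, node.meta['val'].shape)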

src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py (+9)

@@ -75,6 +75,7 @@ def __init__(self, options):
             "torch.ops.aten.argmin.default": None,
             "torch.ops.aten.as_strided.default": None,
             "torch.ops.aten.as_strided_.default": None,
+            "torch.ops.aten.as_strided_copy.default": None,
             "torch.ops.aten.asin.default": None,
             "torch.ops.aten.asinh.default": None,
             "torch.ops.aten.asinh.default": None,
@@ -118,6 +119,7 @@ def __init__(self, options):
             "torch.ops.aten.erf.default": None,
             "torch.ops.aten.exp.default": None,
             "torch.ops.aten.expand.default": None,
+            "torch.ops.aten.expand_copy.default": None,
             "torch.ops.aten.fake_quantize_per_channel_affine_cachemask.default": None,
             "torch.ops.aten.fill.Scalar": None,
             "torch.ops.aten.fill_.Scalar": None,
@@ -196,6 +198,7 @@ def __init__(self, options):
             "torch.ops.aten.new_zeros.default": None,
             "torch.ops.aten.ones.default": None,
             "torch.ops.aten.permute.default": None,
+            "torch.ops.aten.permute_copy.default": None,
             "torch.ops.aten.pow.Scalar": None,
             "torch.ops.aten.pow.Tensor_Scalar": None,
             "torch.ops.aten.pow.Tensor_Tensor": None,
@@ -213,6 +216,7 @@ def __init__(self, options):
             "torch.ops.aten.scatter.src": None,
             "torch.ops.aten.scatter.value": None,
             "torch.ops.aten.select.int": None,
+            "torch.ops.aten.select_copy.int": None,
             "torch.ops.aten.select_scatter.default": None,
             "torch.ops.aten.sigmoid.default": None,
             "torch.ops.aten.sigmoid_.default": None,
@@ -222,13 +226,16 @@ def __init__(self, options):
             "torch.ops.aten.sin.default": None,
             "torch.ops.aten.sinh.default": None,
             "torch.ops.aten.slice.Tensor": None,
+            "torch.ops.aten.slice_copy.Tensor": None,
             "torch.ops.aten.slice_scatter.default": None,
             "torch.ops.aten.sort.default": None,
             "torch.ops.aten.split.Tensor": None,
             "torch.ops.aten.split_with_sizes.default": None,
+            "torch.ops.aten.split_with_sizes_copy.default": None,
             "torch.ops.aten.sqrt.default": None,
             "torch.ops.aten.squeeze.dim": None,
             "torch.ops.aten.squeeze.dims": None,
+            "torch.ops.aten.squeeze_copy.dims": None,
             "torch.ops.aten.stack.default": None,
             "torch.ops.aten.std.correction": None,
             "torch.ops.aten.sub.default": None,
@@ -246,10 +253,12 @@ def __init__(self, options):
             "torch.ops.aten.unbind.int": None,
             "torch.ops.aten.unfold.default": None,
             "torch.ops.aten.unsqueeze.default": None,
+            "torch.ops.aten.unsqueeze_copy.default": None,
             "torch.ops.aten.upsample_nearest2d.default": None,
             "torch.ops.aten.var.correction": None,
             "torch.ops.aten.var_mean.correction": None,
             "torch.ops.aten.view.default": None,
+            "torch.ops.aten.view_copy.default": None,
             "torch.ops.aten.where.self": None,
             "torch.ops.aten.zeros.default": None,
             "torch.ops.aten.zeros_like.default": None,

src/frontends/pytorch/src/op_table.cpp (+9)

@@ -804,6 +804,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.argmin.default", op::translate_argmin},
         {"aten.as_strided.default", op::translate_as_strided},
         {"aten.as_strided_.default", op::translate_as_strided},
+        {"aten.as_strided_copy.default", op::translate_as_strided},
         {"aten.asin.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asin>},
         {"aten.asinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>},
         {"aten.atan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>},
@@ -854,6 +855,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.exp.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>},
         {"aten.expm1.default", op::translate_expm1},
         {"aten.expand.default", op::translate_expand},
+        {"aten.expand_copy.default", op::translate_expand},
         {"aten.eye.m", op::translate_eye_fx},
         {"aten.fake_quantize_per_channel_affine_cachemask.default", op::translate_fake_quantize_per_channel_affine_fx},
         {"aten.fill.Scalar", op::translate_fill},
@@ -936,6 +938,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.ones.names", op::translate_ones_fx},
         {"aten.ones_like.default", op::translate_ones_like_fx},
         {"aten.permute.default", op::translate_permute},
+        {"aten.permute_copy.default", op::translate_1to1_match_2_inputs<opset10::Transpose>},
         {"aten.pow.Scalar", op::translate_pow},
         {"aten.pow.Tensor_Scalar", op::translate_pow},
         {"aten.pow.Tensor_Tensor", op::translate_pow},
@@ -958,6 +961,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.scatter.value", op::translate_scatter},
         {"aten.scatter_add.default", op::translate_scatter_add},
         {"aten.select.int", op::translate_select},
+        {"aten.select_copy.int", op::translate_select},
         {"aten.select_scatter.default", op::translate_select_scatter_fx},
         {"aten.sigmoid.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
         {"aten.sigmoid_.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
@@ -967,13 +971,16 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.sin.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sin>},
         {"aten.sinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sinh>},
         {"aten.slice.Tensor", op::translate_slice_fx},
+        {"aten.slice_copy.Tensor", op::translate_slice_fx},
         {"aten.slice_scatter.default", op::translate_slice_scatter_fx},
         {"aten.sort.default", op::translate_sort_fx},
         {"aten.split.Tensor", op::translate_chunk_fx},
         {"aten.split_with_sizes.default", op::translate_split_with_sizes_fx},
+        {"aten.split_with_sizes_copy.default", op::translate_split_with_sizes_fx},
         {"aten.sqrt.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sqrt>},
         {"aten.squeeze.dim", op::translate_squeeze},
         {"aten.squeeze.dims", op::translate_squeeze},
+        {"aten.squeeze_copy.dims", op::translate_squeeze},
         {"aten.stack.default", op::translate_stack_fx},
         {"aten.std.correction", op::translate_std_fx},
         {"aten.sub.default", op::translate_sub_fx},
@@ -991,10 +998,12 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.unbind.int", op::translate_unbind_int_fx},
         {"aten.unfold.default", op::translate_unfold},
         {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
+        {"aten.unsqueeze_copy.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
         {"aten.upsample_nearest2d.default", op::translate_upsample_nearest2d},
         {"aten.var.correction", op::translate_var_fx},
        {"aten.var_mean.correction", op::translate_var_mean_fx},
         {"aten.view.default", op::translate_reshape},
+        {"aten.view_copy.default", op::translate_reshape},
         {"aten.view_as_complex.default", op::translate_view_as_complex},
         {"aten.view_as_real.default", op::translate_view_as_real},
         {"aten.where.self", op::translate_where},

tests/layer_tests/pytorch_tests/test_as_strided.py (+34)

@@ -45,6 +45,40 @@ def forward(self, x):
     def test_as_strided(self, size, stride, offset, ie_device, precision, ir_version):
         self._test(*self.create_model(size, stride, offset), ie_device, precision, ir_version, trace_model=True)
 
+class TestAsStridedCopy(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(8, 8).astype(np.float32),)
+
+    def create_model(self, size, stride, offset):
+        class aten_as_strided_copy(torch.nn.Module):
+            def __init__(self, size, stride, offset):
+                super().__init__()
+                self.size = size
+                self.stride = stride
+                self.offset = offset
+
+            def forward(self, x):
+                return torch.as_strided_copy(x, self.size, self.stride, self.offset)
+
+        ref_net = None
+
+        return aten_as_strided_copy(size, stride, offset), ref_net, "aten::as_strided_copy"
+
+    @pytest.mark.parametrize(
+        "size,stride",
+        [
+            ([1], [1]),
+            ([2, 2], [1, 1]),
+            ([5, 4, 3], [1, 3, 7]),
+            ([5, 5, 5], [5, 0, 5]),
+            ([1, 2, 3, 4], [4, 3, 2, 1]),
+        ],
+    )
+    @pytest.mark.parametrize("offset", [None, 1, 3, 7])
+    @pytest.mark.precommit_fx_backend
+    def test_as_strided_copy(self, size, stride, offset, ie_device, precision, ir_version):
+        self._test(*self.create_model(size, stride, offset), ie_device, precision, ir_version, trace_model=True)
+
 
 class TestAsStridedListConstruct(PytorchLayerTest):
     def _prepare_input(self, size_shape_tensor=[1], stride_shape_tensor=[1]):

tests/layer_tests/pytorch_tests/test_expand.py (+25)

@@ -41,6 +41,31 @@ def forward_broadcast(self, x):
     def test_expand(self, dims, op_type, ie_device, precision, ir_version):
         self._test(*self.create_model(dims, op_type), ie_device, precision, ir_version)
 
+class TestExpandCopy(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 3).astype(np.float32),)
+
+    def create_model(self, dim):
+        import torch
+
+        class aten_expand_copy(torch.nn.Module):
+            def __init__(self, dims):
+                super(aten_expand_copy, self).__init__()
+                self.dims = dims
+
+            def forward(self, x):
+                return torch.expand_copy(x, self.dims)
+
+        ref_net = None
+
+        return aten_expand_copy(dim), ref_net, f"aten::expand_copy"
+
+    @pytest.mark.parametrize("dims", [(4, 3), (-1, -1), (1, 2, 3), (1, 2, 2, 3)])
+    @pytest.mark.precommit_fx_backend
+    def test_expand_copy(self, dims, ie_device, precision, ir_version):
+        self._test(*self.create_model(dims), ie_device, precision, ir_version)
+
 class TestExpandList(PytorchLayerTest):
     def _prepare_input(self, broadcast_shape):
         import numpy as np

tests/layer_tests/pytorch_tests/test_permute.py (+26)

@@ -38,9 +38,35 @@ def forward(self, x):
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.precommit_torch_export
+    @pytest.mark.precommit_fx_backend
     def test_permute(self, order, complex_type, ie_device, precision, ir_version):
         self._test(*self.create_model(order, complex_type), ie_device, precision, ir_version)
 
+class TestPermuteCopy(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
+
+    def create_model(self, order):
+        import torch
+
+        class aten_permute_copy(torch.nn.Module):
+            def __init__(self, order):
+                super(aten_permute_copy, self).__init__()
+                self.order = order
+
+            def forward(self, x):
+                return torch.permute_copy(x, self.order)
+
+        ref_net = None
+
+        return aten_permute_copy(order), ref_net, "aten::permute_copy"
+
+    @pytest.mark.parametrize("order", [[0, 2, 3, 1], [0, 3, 1, 2], [0, -1, 1, -2]])
+    @pytest.mark.precommit_fx_backend
+    def test_permute_copy(self, order, ie_device, precision, ir_version):
+        self._test(*self.create_model(order), ie_device, precision, ir_version)
+
 
 class TestPermuteList(PytorchLayerTest):
     def _prepare_input(self, permute_shape):

tests/layer_tests/pytorch_tests/test_select.py (+28)

@@ -33,6 +33,34 @@ def forward(self, input_tensor):
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.precommit_torch_export
+    @pytest.mark.precommit_fx_backend
     def test_select(self, ie_device, precision, ir_version, input_dim, input_index):
         self._test(*self.create_model(input_dim, input_index),
                    ie_device, precision, ir_version)
+
+
+@pytest.mark.parametrize('input_dim', list(range(-3, 4)))
+@pytest.mark.parametrize('input_index', list(range(-3, 4)))
+class TestSelectCopy(PytorchLayerTest):
+
+    def _prepare_input(self):
+        return (np.random.randn(4, 4, 5, 5).astype(np.float32),)
+
+    def create_model(self, input_dim, input_index):
+        class aten_select_copy(torch.nn.Module):
+
+            def __init__(self, input_dim, input_index) -> None:
+                super().__init__()
+                self.dim = input_dim
+                self.index = input_index
+
+            def forward(self, input_tensor):
+                return torch.select_copy(input_tensor, int(self.dim), int(self.index))
+
+        ref_net = None
+
+        return aten_select_copy(input_dim, input_index), ref_net, "aten::select_copy"
+
+    @pytest.mark.precommit_fx_backend
+    def test_select_copy(self, ie_device, precision, ir_version, input_dim, input_index):
+        self._test(*self.create_model(input_dim, input_index),
+                   ie_device, precision, ir_version)

tests/layer_tests/pytorch_tests/test_split.py (+25)

@@ -99,6 +99,31 @@ def forward(self, x, y):
 
     @pytest.mark.nightly
     @pytest.mark.precommit
+    @pytest.mark.precommit_fx_backend
     def test_split_with_sizes(self, ie_device, precision, ir_version):
         self._test(*self.create_model(),
                    ie_device, precision, ir_version, trace_model=True)
+
+class TestSplitWithSizesCopy(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(20).astype(np.float32),np.random.randn(20).astype(np.float32))
+
+    def create_model(self):
+        import torch
+
+        class aten_split_with_sizes_copy(torch.nn.Module):
+            def __init__(self):
+                super(aten_split_with_sizes_copy, self).__init__()
+
+            def forward(self, x, y):
+                return torch.split_with_sizes_copy(x, [y.shape[0]], dim=0)
+
+        ref_net = None
+
+        return aten_split_with_sizes_copy(), ref_net, ["aten::split_with_sizes", "prim::ListConstruct"]
+
+    @pytest.mark.precommit_fx_backend
+    def test_split_with_sizes_copy(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(),
+                   ie_device, precision, ir_version, trace_model=True)
