
Commit 46a6ccd

List decompositions for torch.export (#26878)
### Details:
- *item1*
- *...*

### Tickets:
- *ticket-id*
1 parent 9027e1d commit 46a6ccd

8 files changed: +231 -39 lines changed


src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py (+4 -4)

@@ -4,14 +4,14 @@
 # flake8: noqa
 # mypy: ignore-errors
 
+import logging
+import torch
+
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
-from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
+from openvino.runtime import PartialShape, Type as OVType, OVAny, Shape
 from openvino.frontend.pytorch.utils import make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const
 
-import torch
-
-import logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)

src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py (+194 -11)

@@ -46,7 +46,9 @@ def convolution_backward(
 
     return grad_input, grad_weight, grad_bias
 
+
 if len(get_decompositions([aten._scaled_dot_product_flash_attention.default])) == 0:
+
     @register_decomposition(aten._scaled_dot_product_flash_attention.default)
     def scaled_dot_product_flash_attention(
         query,
@@ -101,16 +103,197 @@ def scaled_dot_product_flash_attention(
 
 
 def get_aot_decomposition_list():
-    return ([torch.ops.aten._scaled_dot_product_flash_attention.default,
-             torch.ops.aten._softmax.default,
-             torch.ops.aten._softmax_backward_data.default,
-             torch.ops.aten.convolution_backward.default,
-             torch.ops.aten.gelu_backward.default,
-             torch.ops.aten.native_group_norm.default,
-             torch.ops.aten.native_group_norm_backward.default,
-             torch.ops.aten.native_layer_norm.default,
-             torch.ops.aten.native_layer_norm_backward.default,
-             torch.ops.aten.slice_backward.default])
+    return [
+        torch.ops.aten._scaled_dot_product_flash_attention.default,
+        torch.ops.aten._softmax.default,
+        torch.ops.aten._softmax_backward_data.default,
+        torch.ops.aten.convolution_backward.default,
+        torch.ops.aten.gelu_backward.default,
+        torch.ops.aten.native_group_norm.default,
+        torch.ops.aten.native_group_norm_backward.default,
+        torch.ops.aten.native_layer_norm.default,
+        torch.ops.aten.native_layer_norm_backward.default,
+        torch.ops.aten.slice_backward.default,
+    ]
+
 
 def get_inf_decomposition_list():
-    return ([torch.ops.aten.nll_loss_forward.default])
+    return [torch.ops.aten.nll_loss_forward.default]
+
+
+def get_export_decomposition_list():
+    # List of decompositions from torch._decomp.core_aten_decompositions
+    # removed _backward ops and ops supported without decomposition
+    decomp = [
+        torch.ops.aten.addcdiv,
+        torch.ops.aten.addcdiv_,
+        torch.ops.aten.addcmul,
+        torch.ops.aten.addcmul_,
+        torch.ops.aten.addr,
+        torch.ops.aten.affine_grid_generator,
+        torch.ops.aten.all,
+        torch.ops.aten.aminmax,
+        torch.ops.aten.arange.default,
+        torch.ops.aten.arange.start,
+        torch.ops.aten.baddbmm,
+        torch.ops.aten.binary_cross_entropy,
+        torch.ops.aten.binary_cross_entropy_with_logits,
+        torch.ops.aten.block_diag,
+        torch.ops.aten.celu,
+        torch.ops.aten.celu_,
+        torch.ops.aten.clamp_max,
+        torch.ops.aten.clamp_min,
+        torch.ops.aten.count_nonzero,
+        torch.ops.aten.linalg_cross,
+        torch.ops.aten.cudnn_batch_norm,
+        torch.ops.aten.deg2rad,
+        torch.ops.aten.deg2rad_,
+        torch.ops.aten.detach,
+        torch.ops.aten.diag_embed,
+        torch.ops.aten.dot,
+        torch.ops.aten.vdot,
+        torch.ops.aten.elu,
+        torch.ops.aten.elu_,
+        torch.ops.aten._embedding_bag,
+        torch.ops.aten.empty_like,
+        torch.ops.aten._euclidean_dist.default,
+        torch.ops.aten.expand_as,
+        torch.ops.aten.eye,
+        torch.ops.aten.fill,
+        torch.ops.aten.fill_,
+        torch.ops.aten.floor_divide,
+        torch.ops.aten.frac,
+        torch.ops.aten.frac_,
+        torch.ops.aten._fused_moving_avg_obs_fq_helper,
+        torch.ops.aten.gelu_,
+        torch.ops.aten.glu,
+        torch.ops.aten.hardshrink,
+        torch.ops.aten.hardsigmoid,
+        torch.ops.aten.hardsigmoid_,
+        torch.ops.aten.hardswish,
+        torch.ops.aten.hardswish_,
+        torch.ops.aten.hardtanh_,
+        torch.ops.aten.heaviside,
+        torch.ops.aten.heaviside_,
+        torch.ops.aten.huber_loss,
+        torch.ops.aten.im2col,
+        torch.ops.aten.index_add,
+        torch.ops.aten.index_add_,
+        torch.ops.aten.index_copy,
+        torch.ops.aten.index_copy_,
+        torch.ops.aten.index_fill,
+        torch.ops.aten.index_fill_,
+        torch.ops.aten.isin,
+        torch.ops.aten.isneginf,
+        torch.ops.aten.isposinf,
+        torch.ops.aten.l1_loss,
+        torch.ops.aten.leaky_relu_,
+        torch.ops.aten.lerp,
+        torch.ops.aten.lerp_,
+        torch.ops.aten.linspace,
+        torch.ops.aten.logaddexp,
+        torch.ops.aten.logaddexp2,
+        torch.ops.aten.logit,
+        torch.ops.aten.logit_,
+        torch.ops.aten.log_sigmoid_forward,
+        torch.ops.aten.logspace,
+        torch.ops.aten.logsumexp.default,
+        torch.ops.aten.masked_fill,
+        torch.ops.aten.masked_fill_,
+        torch.ops.aten.mish,
+        torch.ops.aten.mish_,
+        torch.ops.aten.mse_loss,
+        torch.ops.aten.multi_margin_loss,
+        torch.ops.aten.multilabel_margin_loss_forward,
+        torch.ops.aten.mv,
+        torch.ops.aten.mvlgamma,
+        torch.ops.aten.mvlgamma_,
+        torch.ops.aten.nansum,
+        torch.ops.aten.nan_to_num,
+        torch.ops.aten.nan_to_num_,
+        torch.ops.aten.narrow,
+        torch.ops.aten.new_empty,
+        torch.ops.aten.new_full,
+        torch.ops.aten.new_ones,
+        torch.ops.aten.new_zeros,
+        torch.ops.aten.nll_loss_forward,
+        torch.ops.aten.norm,
+        torch.ops.aten.ones,
+        torch.ops.aten.ones_like,
+        torch.ops.aten._prelu_kernel,
+        torch.ops.aten._reshape_alias,
+        torch.ops.aten.rad2deg,
+        torch.ops.aten.rad2deg_,
+        torch.ops.aten.reflection_pad1d,
+        torch.ops.aten.reflection_pad2d,
+        torch.ops.aten.reflection_pad3d,
+        torch.ops.aten.replication_pad1d,
+        torch.ops.aten.replication_pad2d,
+        torch.ops.aten.replication_pad3d,
+        torch.ops.aten.renorm,
+        torch.ops.aten.renorm_,
+        torch.ops.aten.resize_as,
+        torch.ops.aten.roll,
+        torch.ops.aten.rot90,
+        torch.ops.aten.rrelu_with_noise,
+        torch.ops.aten.rrelu_with_noise_,
+        torch.ops.aten.rsub,
+        torch.ops.aten.select_scatter,
+        torch.ops.aten.sgn,
+        torch.ops.aten.sgn_,
+        torch.ops.aten.silu,
+        torch.ops.aten.silu_,
+        torch.ops.aten.sinc,
+        torch.ops.aten.sinc_,
+        torch.ops.aten.smooth_l1_loss,
+        torch.ops.aten.soft_margin_loss,
+        torch.ops.aten.softplus,
+        torch.ops.aten.softshrink,
+        torch.ops.aten.special_entr,
+        torch.ops.aten.special_log_ndtr,
+        torch.ops.aten.special_xlog1py,
+        torch.ops.aten.split.Tensor,
+        torch.ops.aten.split_with_sizes_copy,
+        torch.ops.aten.squeeze.default,
+        torch.ops.aten.squeeze.dim,
+        torch.ops.aten.std,
+        torch.ops.aten.std_mean,
+        torch.ops.aten.stack,
+        torch.ops.aten.sum.default,
+        torch.ops.aten.sum.out,
+        torch.ops.aten.t,
+        torch.ops.aten.take,
+        torch.ops.aten.threshold,
+        torch.ops.aten.threshold_,
+        torch.ops.aten.trace,
+        torch.ops.aten.transpose.int,
+        torch.ops.aten.tril,
+        torch.ops.aten.tril_,
+        torch.ops.aten.triu,
+        torch.ops.aten.triu_,
+        torch.ops.aten.unbind,
+        torch.ops.aten.unfold_copy,
+        torch.ops.aten._unsafe_index,
+        torch.ops.aten.unsafe_split.Tensor,
+        torch.ops.aten.unsafe_split_with_sizes,
+        torch.ops.aten._unsafe_view,
+        torch.ops.aten.view_as_complex,
+        torch.ops.aten.xlogy,
+        torch.ops.aten.xlogy_,
+        torch.ops.aten.zero,
+        torch.ops.aten.zero_,
+        torch.ops.aten.zeros,
+        torch.ops.aten.zeros_like,
+        torch.ops.aten._weight_norm_interface,
+    ]
+    try:
+        from packaging import version
+        if version.parse(torch.__version__) >= version.parse("2.3"):
+            decomp += [
+                torch.ops.aten._lazy_clone,
+                torch.ops.aten._test_parallel_materialize,
+                torch.ops.aten._chunk_cat,
+            ]
+    except ImportError:
+        pass
+    return decomp

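The new `get_export_decomposition_list()` is meant to be fed through `torch._decomp.get_decompositions` and applied with `ExportedProgram.run_decompositions`, exactly as the `torch_utils.py` and `pytorch_frontend_utils.py` changes below do. A minimal sketch of that flow, assuming a recent PyTorch (2.2+); the toy model is illustrative and not part of this commit:

```python
import torch
from torch.export import export
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list

class Toy(torch.nn.Module):  # illustrative model, not from the commit
    def forward(self, x):
        return torch.nn.functional.silu(x)  # aten.silu is in the export list

ep = export(Toy(), (torch.randn(2, 3),))
# Build a decomposition table restricted to the curated list, then decompose
# only those ops; everything else is left intact for the OpenVINO FX frontend.
decomp = get_decompositions(get_export_decomposition_list())
ep = ep.run_decompositions(decomp_table=decomp)
print(ep.module().code)
```

Passing an explicit `decomp_table` decomposes only the listed ops, leaving `*_backward` ops and ops the FX frontend already supports untouched, rather than applying the full core ATen decomposition set.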
src/frontends/pytorch/src/op_table.cpp (+1)

@@ -787,6 +787,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten.clamp_min.default", op::translate_1to1_match_2_inputs_align_types<opset10::Maximum>},
     {"aten.clamp_min.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Maximum>},
     {"aten.clone.default", op::skip_node},  // ignore clone operators that are inserted by PyTorch autograd
+    {"aten.col2im.default", op::translate_col2im},
     {"aten.constant_pad_nd.default", op::translate_constant_pad_nd_fx},
     {"aten.convolution.default", op::translate_convolution},
     {"aten.copy.default", op::translate_copy_fx},

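With `aten.col2im.default` now routed to `translate_col2im`, exported graphs containing `torch.nn.functional.fold` (which lowers to `aten.col2im`) become convertible through the FX frontend. A rough sketch; the model, shapes, and direct `ov.convert_model` call are illustrative, not taken from this commit:

```python
import torch
from torch.export import export
import openvino as ov

class FoldModel(torch.nn.Module):  # illustrative only
    def forward(self, x):
        # fold == col2im: reassemble 2x2 sliding blocks into a 4x5 output map
        return torch.nn.functional.fold(x, output_size=(4, 5), kernel_size=(2, 2))

x = torch.randn(1, 12, 12)  # (N, C * kernel_h * kernel_w, num_blocks)
ep = export(FoldModel(), (x,))
ov_model = ov.convert_model(ep, example_input=(x,))
```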
tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py (+8 -11)

@@ -5,17 +5,18 @@
 import warnings
 from copy import deepcopy
 import os
-
+import torch
+import pytest
+import logging
 import numpy as np
+
 from common.constants import test_device, test_precision
 from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
-
 from openvino.frontend import FrontEndManager
 from openvino.runtime import Core, Type, PartialShape
 import openvino.properties.hint as hints
-import torch
-from packaging import version
-import pytest
+
+logging.basicConfig(level=logging.DEBUG)
 
 
 def skip_check(param):
@@ -124,13 +125,9 @@ def numpy_to_torch_recursively(x):
             from torch.export import export
 
             em = export(model, tuple(torch_inputs))
-            if version.parse(torch.__version__) >= version.parse("2.3"):
-                em = em.run_decompositions()
-            gm = em.module()
-            print(gm.code)
 
             converted_model = convert_model(
-                em, example_input=torch_inputs)
+                em, example_input=torch_inputs, verbose=True)
             self._resolve_input_shape_dtype(
                 converted_model, ov_inputs, dynamic_shapes)
             smodel = model
@@ -242,7 +239,7 @@ def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_i
         if not dynamic_shapes:
             input_shapes = [inp.shape for inp in ov_inputs]
             kwargs["input"] = input_shapes
-        om = convert_model(decoder, **kwargs)
+        om = convert_model(decoder, verbose=True, **kwargs)
         self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
         return smodel, om
 

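For context, the torch.export path of the test harness now reduces to export-then-convert; the decomposition table is applied inside `convert_model` (see `pytorch_frontend_utils.py` below) rather than in the test itself. A self-contained sketch with a stand-in model:

```python
import torch
from torch.export import export
from openvino import convert_model

model = torch.nn.Linear(4, 2)          # stand-in for the test's model
torch_inputs = [torch.randn(1, 4)]     # stand-in for the test's inputs
em = export(model, tuple(torch_inputs))
# verbose=True turns on detailed conversion output from the frontend
converted_model = convert_model(em, example_input=torch_inputs, verbose=True)
```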
tests/layer_tests/pytorch_tests/test_col2im.py (+1)

@@ -40,6 +40,7 @@ def forward(self, x):
 
     @pytest.mark.nightly
     @pytest.mark.precommit
+    @pytest.mark.precommit_torch_export
     @pytest.mark.parametrize("output_size,kernel_size", [([4, 5], [2, 2])])
     @pytest.mark.parametrize("dilation", [1, 2, [1, 2]])
     @pytest.mark.parametrize("padding", [0, 5, [2, 3]])

tests/layer_tests/pytorch_tests/test_eye.py (+11 -9)

@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from packaging import version
 
 from pytorch_layer_test_class import PytorchLayerTest
 
@@ -14,7 +15,6 @@ def _prepare_input(self, m, n=None):
             return (np.array(m, dtype="int32"), )
         return (np.array(m, dtype="int32"), np.array(n, dtype="int32"))
 
-
     def create_model(self, num_inputs, dtype):
        import torch
        dtype_map = {
@@ -45,29 +45,31 @@ def __init__(self, dtype):
             def forward(self, x, y):
                 return torch.eye(x, y, dtype=self.dtype)
 
-
-        ref_net = None
-
-        return aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype), ref_net, ("aten::eye", "aten::IntImplicit")
+        model = aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype)
+        return model, None, ["aten::eye", "aten::IntImplicit"]
 
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.precommit_torch_export
     @pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"])
     @pytest.mark.parametrize("m", [2, 3, 4, 5])
-    @pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.")
     def test_eye_square(self, dtype, m, ie_device, precision, ir_version):
+        if PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3"):
+            pytest.skip("Not supported in PyTorch versions earlier than 2.3.")
         if ie_device == "GPU":
             pytest.xfail(reason="eye is not supported on GPU")
-        self._test(*self.create_model(1, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m})
+        self._test(*self.create_model(1, dtype), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={"m": m})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.precommit_torch_export
     @pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"])
     @pytest.mark.parametrize(("m", "n"), [[2, 2], [3, 4], [5, 3]])
-    @pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.")
     def test_eye(self, dtype, m, n, ie_device, precision, ir_version):
+        if (PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3")):
+            pytest.skip("Not supported in PyTorch versions earlier than 2.3.")
         if ie_device == "GPU":
             pytest.xfail(reason="eye is not supported on GPU")
-        self._test(*self.create_model(2, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m, "n": n})
+        self._test(*self.create_model(2, dtype), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={"m": m, "n": n})

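The dropped `skipif` markers compared version strings lexicographically, which mis-orders multi-digit components; the replacement uses `packaging.version` and only skips on the torch.export path via `PytorchLayerTest.use_torch_export()`. A small illustration of the difference:

```python
import torch
from packaging import version

print("2.10.0" < "2.3.0")                                # True  -- lexicographic, wrong
print(version.parse("2.10.0") < version.parse("2.3.0"))  # False -- semantic, correct

# Pattern now used inside the tests:
if version.parse(torch.__version__) < version.parse("2.3"):
    pass  # e.g. pytest.skip(...) for torch.export-specific cases
```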
tests/model_hub_tests/pytorch/torch_utils.py (+4 -1)

@@ -75,7 +75,10 @@ def convert_model_impl(self, model_obj):
         pt_res = model_obj(**self.example)
         graph = export(model_obj, tuple(), self.example)
         if version.parse(torch.__version__) >= version.parse("2.2"):
-            graph = graph.run_decompositions()
+            from torch._decomp import get_decompositions
+            from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
+            decomp = get_decompositions(get_export_decomposition_list())
+            graph = graph.run_decompositions(decomp_table=decomp)
 
         gm = graph.module()
         print(gm.code)

tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py (+8 -3)

@@ -40,15 +40,20 @@ def extract_module_extensions(args):
         except:
             pass
         if not is_good_version:
-            raise RuntimeError(
-                "NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.")
+            raise RuntimeError("NNCF models produced by nncf<2.6 are not "
+                               "supported directly. Please upgrade nncf or "
+                               "export to ONNX first.")
     inputs = prepare_torch_inputs(example_inputs)
     if not isinstance(model, (TorchScriptPythonDecoder, TorchFXPythonDecoder)):
         if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)):
             from packaging import version
             if version.parse(torch.__version__) >= version.parse("2.2"):
-                model = model.run_decompositions()
+                from torch._decomp import get_decompositions
+                from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
+                decomp = get_decompositions(get_export_decomposition_list())
+                model = model.run_decompositions(decomp_table=decomp)
             gm = model.module()
+            log.debug(gm.code)
             decoder = TorchFXPythonDecoder(gm)
         else:
             decoder = TorchScriptPythonDecoder(
