Skip to content

Commit 2d00d75

Browse files
authored
[OVC][PT FE] Cover leftovers for torch.export.ExportedProgram support (openvinotoolkit#27042)
**Details:** Cover leftovers for ExportedProgram support **Ticket:** TBD Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent 01ceeb8 commit 2d00d75

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

tools/ovc/openvino/tools/ovc/convert.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ def convert_model(
2727
2828
Framework-agnostic parameters:
2929
:param input_model:
30-
Model object in original framework (PyTorch, Tensorflow) or path to model file.
30+
Model object in original framework (PyTorch, TensorFlow) or path to model file.
3131
3232
Supported formats of input model:
3333
3434
PyTorch
3535
torch.nn.Module
3636
torch.jit.ScriptModule
3737
torch.jit.ScriptFunction
38+
torch.export.ExportedProgram
3839
3940
TF
4041
tf.compat.v1.Graph

tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py

+29-28
Original file line numberDiff line numberDiff line change
@@ -89,37 +89,38 @@ def get_pytorch_decoder_for_model_on_disk(argv, args):
8989
else:
9090
input_model = argv.input_model
9191

92-
if isinstance(input_model, (str, pathlib.Path)):
93-
# attempt to load scripted model
94-
try:
95-
inputs = prepare_torch_inputs(example_inputs)
96-
model = torch.jit.load(input_model)
97-
model.eval()
98-
decoder = TorchScriptPythonDecoder(
99-
model,
100-
example_input=inputs,
101-
shared_memory=args.get("share_weights", True),
102-
module_extensions=extract_module_extensions(args))
92+
if not isinstance(input_model, (str, pathlib.Path)):
93+
return False
94+
95+
# attempt to load scripted model
96+
try:
97+
inputs = prepare_torch_inputs(example_inputs)
98+
model = torch.jit.load(input_model)
99+
model.eval()
100+
decoder = TorchScriptPythonDecoder(
101+
model,
102+
example_input=inputs,
103+
shared_memory=args.get("share_weights", True),
104+
module_extensions=extract_module_extensions(args))
105+
argv.input_model = decoder
106+
argv.framework = 'pytorch'
107+
return True
108+
except:
109+
pass
110+
# attempt to load exported model
111+
try:
112+
exported_program = torch.export.load(input_model)
113+
if hasattr(torch, "export") and isinstance(exported_program, (torch.export.ExportedProgram)):
114+
from packaging import version
115+
if version.parse(torch.__version__) >= version.parse("2.2"):
116+
exported_program = exported_program.run_decompositions()
117+
gm = exported_program.module()
118+
decoder = TorchFXPythonDecoder(gm, dynamic_shapes=True)
103119
argv.input_model = decoder
104120
argv.framework = 'pytorch'
105121
return True
106-
except:
107-
pass
108-
if isinstance(input_model, (str, pathlib.Path)):
109-
# attempt to load exported model
110-
try:
111-
exported_program = torch.export.load(input_model)
112-
if hasattr(torch, "export") and isinstance(exported_program, (torch.export.ExportedProgram)):
113-
from packaging import version
114-
if version.parse(torch.__version__) >= version.parse("2.2"):
115-
exported_program = exported_program.run_decompositions()
116-
gm = exported_program.module()
117-
decoder = TorchFXPythonDecoder(gm, dynamic_shapes=True)
118-
argv.input_model = decoder
119-
argv.framework = 'pytorch'
120-
return True
121-
except:
122-
pass
122+
except:
123+
pass
123124
return False
124125

125126

0 commit comments

Comments
 (0)