
Executorch initial support #28425

Open

cavusmustafa wants to merge 10 commits into master
Conversation


@cavusmustafa (Contributor) commented Jan 14, 2025

Details:

  • OV side changes for initial ExecuTorch OV backend

Tickets:

@github-actions bot added the category: Python API (OpenVINO Python bindings) and category: PyTorch FE (OpenVINO PyTorch Frontend) labels Jan 14, 2025
@cavusmustafa marked this pull request as ready for review January 14, 2025 02:54
@cavusmustafa requested review from a team as code owners January 14, 2025 02:54
@rkazants (Contributor) left a comment


I have several questions about this feature:

  1. Will it support an ExecuTorch binary compiled for the OV ExecuTorch backend?
  2. What other ExecuTorch binaries will be supported: default, xnnpack?
  3. We won't support ExecuTorch binary inference using the native OpenVINO workflow, right?
  4. Do you have a PR/branch with the ExecuTorch OV backend implementation?

Best regards,
Roman

@@ -83,3 +83,10 @@ def _is_testing(options) -> Optional[Any]:
        return True
    return False

+def _executorch(options) -> Optional[Any]:
Contributor

If executorch is added in options, it will also be available to torch.compile users and may break that flow if the option is set to True, since ExecuTorchPythonDecoder will be used. Instead, can we add it as a kwarg to openvino_compile and set it to False by default? It can then be set to True in the ExecuTorch OpenVINO backend, so it is not visible to the user. Please let me know your thoughts.

Contributor Author

I see. I was trying not to increase the number of parameters for openvino_compile at first, but it makes sense not to expose this feature to users. I think it is a better idea to pass this as a kwarg. I will update accordingly.
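For reference, a minimal sketch of the agreed direction, based on the openvino_compile signature change visible in this PR's diff; the decoder-selection body is an illustration, not the actual implementation:

from torch.fx import GraphModule
from openvino.frontend.pytorch.fx_decoder import (
    TorchFXPythonDecoder, ExecuTorchPythonDecoder)

def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None,
                     options=None, executorch=False):
    # torch.compile users never pass executorch, so they keep the existing
    # TorchFXPythonDecoder path; the ExecuTorch OpenVINO backend would call
    # openvino_compile(gm, *args, executorch=True) internally to opt in.
    decoder_cls = ExecuTorchPythonDecoder if executorch else TorchFXPythonDecoder
    decoder = decoder_cls(gm)
    ...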

@cavusmustafa (Contributor Author)

> I have several questions about this feature:
>
>   1. Will it support an ExecuTorch binary compiled for the OV ExecuTorch backend?
>   2. What other ExecuTorch binaries will be supported: default, xnnpack?
>   3. We won't support ExecuTorch binary inference using the native OpenVINO workflow, right?
>   4. Do you have a PR/branch with the ExecuTorch OV backend implementation?
>
> Best regards, Roman

  1. Yes, the changes included in this PR are to support the ExecuTorch OpenVINO backend.
  2. The POC will support the OpenVINO backend only, but we are also working on enabling fallback to the default backend for unsupported ops.
  3. ExecuTorch binary inference will use the OpenVINO workflow, similar to inference execution with torch.compile. The partition compiled for OpenVINO execution (which is stored in the .pte file) is executed in the ExecuTorch OpenVINO backend, which uses the OpenVINO C++ API. A lowering sketch is included below.
  4. Branch: https://github.com/ynimmaga/executorch/tree/openvino_backend

Please let me know if any additional questions or clarifications are needed.
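For illustration only, a hypothetical sketch of that lowering flow in Python. The OpenvinoPartitioner name and its import path are assumptions based on the linked branch and are not confirmed by this PR; the export/to_edge/to_executorch calls follow the standard ExecuTorch flow:

import torch
from executorch.exir import to_edge

# Hypothetical import: partitioner name/path assumed from the linked
# openvino_backend branch, not confirmed by this PR.
from executorch.backends.openvino.partitioner import OpenvinoPartitioner

class SmallModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) + 1

model = SmallModel().eval()
example_inputs = (torch.randn(1, 8),)

# Standard ExecuTorch lowering: export, convert to edge dialect, delegate
# supported subgraphs to the OpenVINO backend, then serialize to .pte.
exported = torch.export.export(model, example_inputs)
edge = to_edge(exported)
edge = edge.to_backend(OpenvinoPartitioner())
program = edge.to_executorch()

with open("model.pte", "wb") as f:
    f.write(program.buffer)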

self._inputs.append(i)
self._input_signature.append(value.name)
if hasattr(value, "meta") and ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']:
    found_shapes.append(value.meta['tensor_meta'].shape)
Contributor

Can you please add a comment explaining what type of input requires this special branch? The else branch needs a comment as well; a possible commented version is sketched below.
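A possible commented version, assuming (as the surrounding diff suggests) that value is a torch.fx placeholder node, and that the else branch appends None so the shape can be treated as unknown later:

self._inputs.append(i)
self._input_signature.append(value.name)
# Placeholder nodes produced by torch.export carry a TensorMetadata entry in
# value.meta; when present, record its static shape for this input.
if hasattr(value, "meta") and ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']:
    found_shapes.append(value.meta['tensor_meta'].shape)
else:
    # No tensor metadata available (e.g. non-tensor inputs); record None so
    # the shape is treated as fully dynamic downstream.
    found_shapes.append(None)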

self._input_signature = []
self._example_input = None

if issubclass(type(pt_module), torch.fx.graph_module.GraphModule):
Contributor

Suggested change
-if issubclass(type(pt_module), torch.fx.graph_module.GraphModule):
+if isinstance(pt_module, torch.fx.graph_module.GraphModule):

Can we do it? isinstance already covers subclasses, so it is the more idiomatic check here.

Contributor

Please create a separate internal function in this class to handle the torch.fx.graph_module.GraphModule case, and another to handle torch.fx.Node, because the constructor is now quite long. See the sketch below.
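A minimal sketch of the suggested split; the method names are illustrative, not from the PR:

import torch.fx

class ExecuTorchPythonDecoder:
    def __init__(self, pt_module, **kwargs):
        # Dispatch to a dedicated initializer per input type to keep the
        # constructor short.
        if isinstance(pt_module, torch.fx.graph_module.GraphModule):
            self._init_from_graph_module(pt_module)
        elif isinstance(pt_module, torch.fx.Node):
            self._init_from_node(pt_module)

    def _init_from_graph_module(self, gm):
        # Whole-graph setup: inputs, input signature, shape extraction.
        ...

    def _init_from_node(self, node):
        # Per-node decoding state.
        ...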

Comment on lines +500 to +508
for idx, shape in enumerate(found_shapes):
    if shape is not None:
        new_shape = []
        for dim in shape:
            if (dynamic_shapes or type(dim).__name__ == "SymInt"):
                new_shape.append(-1)
            else:
                new_shape.append(dim)
        found_shapes[idx] = torch.Size(new_shape)
Contributor

I believe we should have a helper function for this. Can you please check whether one exists, or create it? Also, why can't you normalize found_shapes[idx] right away when adding the elements above? Now you have to iterate through it again. A possible helper is sketched below.
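A minimal sketch of such a helper, reusing the SymInt/dynamic-shape convention from the snippet above; the function name is illustrative:

import torch

def _normalize_shape(shape, dynamic_shapes):
    # Replace dynamic or symbolic dimensions with -1 so they are treated as
    # dynamic downstream; keep static dimensions unchanged.
    return torch.Size(
        [-1 if (dynamic_shapes or type(dim).__name__ == "SymInt") else dim
         for dim in shape])

This would also let the input-collection loop normalize in a single pass, e.g. found_shapes.append(_normalize_shape(meta.shape, dynamic_shapes)), instead of re-iterating afterwards.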

@@ -789,6 +789,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.argmin.default", op::translate_argmin},
{"aten.as_strided.default", op::translate_as_strided},
{"aten.as_strided_.default", op::translate_as_strided},
{"aten.as_strided_copy.default", op::translate_as_strided},
Contributor

Can you please add layer tests for the newly added translators? A sketch of one such test follows.
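A minimal sketch of what such a test could look like for aten.as_strided_copy, following the pattern of the existing PyTorch layer tests; the class name and parameter values are illustrative:

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest

class TestAsStridedCopy(PytorchLayerTest):
    def _prepare_input(self):
        # One 8x8 float input is enough for the strides tested below.
        return (np.random.randn(8, 8).astype(np.float32),)

    def create_model(self, size, stride):
        class aten_as_strided_copy(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.size = size
                self.stride = stride

            def forward(self, x):
                return torch.as_strided_copy(x, self.size, self.stride)

        return aten_as_strided_copy(), None, "aten::as_strided_copy"

    @pytest.mark.parametrize("size,stride", [([2, 2], [1, 2]), ([4, 4], [1, 1])])
    def test_as_strided_copy(self, size, stride, ie_device, precision, ir_version):
        self._test(*self.create_model(size, stride), ie_device, precision, ir_version)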

-from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
-from openvino import Core, Type, PartialShape, serialize
+from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder, ExecuTorchPythonDecoder
+from openvino.runtime import Core, Type, PartialShape, serialize
Contributor

Suggested change
-from openvino.runtime import Core, Type, PartialShape, serialize
+from openvino import Core, Type, PartialShape, serialize

openvino.runtime is deprecated as of this release.

@@ -78,7 +78,7 @@ def openvino_compile_cached_model(cached_model_path, options, *example_inputs):

    return compiled_model

-def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, options=None):
+def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, options=None, executorch=False):
    core = Core()

    device = _get_device(options)
Contributor

For ExecuTorch, will you cache ov::Model or ov::CompiledModel?

@@ -13,8 +13,8 @@
from torch.fx import GraphModule
Contributor

Don't we need a separate file for the executorch implementation, just to avoid mixing it with torch.compile?

Labels
category: Python API (OpenVINO Python bindings), category: PyTorch FE (OpenVINO PyTorch Frontend), Code Freeze
3 participants