Skip to content

Commit 737caf5

Browse files
mvafin and mlukasze authored
[PT FE] Support prim::fork and aten::wait (#26839)
### Details:
- Support `prim::fork` and `aten::wait`

### Tickets:
- CVS-153613

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
1 parent 4b43150 commit 737caf5

File tree

3 files changed

+59
-9
lines changed

3 files changed

+59
-9
lines changed

src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def __init__(
7373
"https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html."
7474
) from e
7575
self.graph_element = pt_module.inlined_graph
76-
log.debug("Inlined graph:\n%s", pt_module.inlined_graph)
7776
self.alias_db = self.graph_element.alias_db()
7877
else:
7978
self.graph_element = graph_element
@@ -96,6 +95,7 @@ def __init__(
9695
self._transform_tensor_list_constants_to_listconstruct(
9796
self.graph_element)
9897
self._transform_optional_constants(self.graph_element)
98+
log.debug("Inlined graph:\n%s", self.graph_element)
9999

100100
@staticmethod
101101
def _get_preserved_attributes(model) -> list:
@@ -293,22 +293,31 @@ def decoder_type_name(self) -> str:
293293
return "ts"
294294

295295
def get_subgraphs(self) -> list:
296-
if self.graph_element.kind() == "prim::PythonOp":
296+
if self.graph_element.kind() in ["prim::PythonOp", "prim::fork"]:
297297
if "Subgraph" in self.graph_element.attributeNames():
298298
assert isinstance(
299299
self.graph_element, torch.Node), "Graph element must be of type torch.Node."
300-
return [getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph")]
300+
subgraph = getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph")
301+
torch._C._jit_pass_inline(subgraph)
302+
return [subgraph]
301303
else:
302304
# Attribute "Subgraph" is only available if Graph was created using tracing.
303305
# TODO Find way to extract subgraph for scripted Graph.
304306
return []
305307
return list(self.graph_element.blocks())
306308

307309
def get_subgraph_decoder(self, index: int):
308-
decoder = TorchScriptPythonDecoder(
309-
self.pt_module, self.get_subgraphs(
310-
)[index], alias_db=self.alias_db, shared_memory=self._shared_memory, module_extensions=self.module_extensions
311-
)
310+
module = self.pt_module
311+
if self.graph_element.kind() == "prim::fork":
312+
in0 = self.raw_inputs[0]
313+
if in0.node().kind() == "prim::GetAttr":
314+
module, _ = get_value_from_getattr(in0.node(), self.pt_module)
315+
decoder = TorchScriptPythonDecoder(module,
316+
self.get_subgraphs()[index],
317+
alias_db=self.alias_db,
318+
shared_memory=self._shared_memory,
319+
module_extensions=self.module_extensions
320+
)
312321
self.m_decoders.append(decoder)
313322
return decoder
314323

@@ -456,8 +465,8 @@ def as_string(self):
456465

457466
@staticmethod
458467
def _as_constant_list(pt_value: torch.Value):
459-
# For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively
460-
# rewrite them in that part where constant attributes are queried
468+
# For now we treat a list as a 1D tensor; it is required by converters to avoid
469+
# need to massively rewrite them in that part where constant attributes are queried
461470
pt_element_type = str(pt_value.type().getElementType())
462471
ivalue = pt_value.toIValue()
463472
is_known_type = pt_element_type in pt_to_ov_type_map
@@ -467,6 +476,7 @@ def _as_constant_list(pt_value: torch.Value):
467476
ovshape = PartialShape([len(ivalue)])
468477
ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue)
469478
return ov_const.outputs()
479+
return []
470480

471481
def _get_device_string(self) -> str:
472482
assert self.graph_element.kind(

src/frontends/pytorch/src/op_table.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
675675
{"aten::var_mean", op::translate_var_mean},
676676
{"aten::view", op::quantizable_op<op::translate_reshape>},
677677
{"aten::view_as", op::translate_reshape_as},
678+
{"aten::wait", op::skip_node},
678679
{"aten::where", op::translate_where},
679680
{"aten::zero", op::translate_zeros_like},
680681
{"aten::zeros", op::translate_zeros},
@@ -685,6 +686,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
685686
{"prim::Constant", op::translate_constant},
686687
{"prim::device", op::translate_constant},
687688
// prim::DictConstruct - Supported in limited set of patterns
689+
{"prim::fork", op::translate_pythonop},
688690
{"prim::GetAttr", op::translate_get_attr},
689691
{"prim::If", op::translate_if},
690692
{"prim::is_cuda", op::return_false_scalar},
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (C) 2018-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import numpy as np
5+
import pytest
6+
import torch
7+
8+
from pytorch_layer_test_class import PytorchLayerTest
9+
10+
11+
class TestForkWait(PytorchLayerTest):
    """Layer test covering conversion of ``prim::fork`` / ``aten::wait``."""

    def _prepare_input(self):
        # One random float input; the values themselves are irrelevant,
        # only the fork/wait structure of the graph is under test.
        return (np.random.randn(10, 20),)

    def create_model(self):

        class AddMod(torch.nn.Module):
            # Forked callee: returns a pair so wait() yields a tuple.
            def forward(self, a: torch.Tensor, b: int):
                return a + b, a - b

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mod = AddMod()

            def forward(self, input):
                # fork launches the submodule asynchronously; wait joins it.
                pending = torch.jit.fork(self.mod, a=input, b=2)
                return torch.jit.wait(pending)

        return Mod(), None, ["prim::fork", "aten::wait"]

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize(("to_trace"), [True, False])
    def test_fork_wait(self, to_trace, ie_device, precision, ir_version):
        model, ref_net, op_kinds = self.create_model()
        self._test(model, ref_net, op_kinds, ie_device, precision,
                   ir_version, trace_model=to_trace)

0 commit comments

Comments
 (0)