@@ -73,7 +73,6 @@ def __init__(
73
73
"https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html."
74
74
) from e
75
75
self .graph_element = pt_module .inlined_graph
76
- log .debug ("Inlined graph:\n %s" , pt_module .inlined_graph )
77
76
self .alias_db = self .graph_element .alias_db ()
78
77
else :
79
78
self .graph_element = graph_element
@@ -96,6 +95,7 @@ def __init__(
96
95
self ._transform_tensor_list_constants_to_listconstruct (
97
96
self .graph_element )
98
97
self ._transform_optional_constants (self .graph_element )
98
+ log .debug ("Inlined graph:\n %s" , self .graph_element )
99
99
100
100
@staticmethod
101
101
def _get_preserved_attributes (model ) -> list :
@@ -293,22 +293,31 @@ def decoder_type_name(self) -> str:
293
293
return "ts"
294
294
295
295
def get_subgraphs (self ) -> list :
296
- if self .graph_element .kind () == "prim::PythonOp" :
296
+ if self .graph_element .kind () in [ "prim::PythonOp" , "prim::fork" ] :
297
297
if "Subgraph" in self .graph_element .attributeNames ():
298
298
assert isinstance (
299
299
self .graph_element , torch .Node ), "Graph element must be of type torch.Node."
300
- return [getattr (self .graph_element , self .graph_element .kindOf ("Subgraph" ))("Subgraph" )]
300
+ subgraph = getattr (self .graph_element , self .graph_element .kindOf ("Subgraph" ))("Subgraph" )
301
+ torch ._C ._jit_pass_inline (subgraph )
302
+ return [subgraph ]
301
303
else :
302
304
# Attribute "Subgraph" is only available if Graph was created using tracing.
303
305
# TODO Find way to extract subgraph for scripted Graph.
304
306
return []
305
307
return list (self .graph_element .blocks ())
306
308
307
309
def get_subgraph_decoder (self , index : int ):
308
- decoder = TorchScriptPythonDecoder (
309
- self .pt_module , self .get_subgraphs (
310
- )[index ], alias_db = self .alias_db , shared_memory = self ._shared_memory , module_extensions = self .module_extensions
311
- )
310
+ module = self .pt_module
311
+ if self .graph_element .kind () == "prim::fork" :
312
+ in0 = self .raw_inputs [0 ]
313
+ if in0 .node ().kind () == "prim::GetAttr" :
314
+ module , _ = get_value_from_getattr (in0 .node (), self .pt_module )
315
+ decoder = TorchScriptPythonDecoder (module ,
316
+ self .get_subgraphs ()[index ],
317
+ alias_db = self .alias_db ,
318
+ shared_memory = self ._shared_memory ,
319
+ module_extensions = self .module_extensions
320
+ )
312
321
self .m_decoders .append (decoder )
313
322
return decoder
314
323
@@ -456,8 +465,8 @@ def as_string(self):
456
465
457
466
@staticmethod
458
467
def _as_constant_list (pt_value : torch .Value ):
459
- # For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively
460
- # rewrite them in that part where constant attributes are queried
468
+ # For now we treat a list as a 1D tensor; it is required by converters to avoid
469
+ # need to massively rewrite them in that part where constant attributes are queried
461
470
pt_element_type = str (pt_value .type ().getElementType ())
462
471
ivalue = pt_value .toIValue ()
463
472
is_known_type = pt_element_type in pt_to_ov_type_map
@@ -467,6 +476,7 @@ def _as_constant_list(pt_value: torch.Value):
467
476
ovshape = PartialShape ([len (ivalue )])
468
477
ov_const = op .Constant (ovtype , ovshape .get_shape (), ivalue )
469
478
return ov_const .outputs ()
479
+ return []
470
480
471
481
def _get_device_string (self ) -> str :
472
482
assert self .graph_element .kind (
0 commit comments