7
7
from openvino .frontend .pytorch .py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
8
8
from openvino .frontend .pytorch .py_pytorch_frontend import _Type as DecoderType
9
9
from openvino .runtime import op , PartialShape , Type as OVType , OVAny , Shape
10
- from openvino .frontend .pytorch .utils import maybe_convert_max_int , make_constant , fetch_attr , pt_to_ov_type_map
10
+ from openvino .frontend .pytorch .utils import maybe_convert_max_int , make_constant , fetch_attr , pt_to_ov_type_map , torch_tensor_to_ov_const
11
11
12
12
import torch
13
13
@@ -21,11 +21,11 @@ def __init__(self, pt_module, fx_gm, nodes=None, mark_node_callback=None, input_
21
21
self .m_decoders = []
22
22
self .pt_module = pt_module
23
23
self .fx_gm = fx_gm
24
- self .input_types = input_types
24
+ self .input_types = [OVAny (pt_to_ov_type_map [str (t )])
25
+ for t in input_types ]
25
26
self .input_shapes = input_shapes
26
27
27
28
self ._input_signature = []
28
- self ._output_names = []
29
29
30
30
if issubclass (type (pt_module ), torch .fx .graph_module .GraphModule ):
31
31
@@ -39,41 +39,36 @@ def __init__(self, pt_module, fx_gm, nodes=None, mark_node_callback=None, input_
39
39
self ._input_signature .append (self ._nodes [i ].name )
40
40
elif self ._nodes [i ].op == 'output' :
41
41
# Instead of putting output index, refer to its target
42
- args = self ._nodes [i ].args
43
- if isinstance (args [0 ], tuple ):
44
- args = args [0 ]
45
- if isinstance (args [0 ], dict ):
46
- for name , output in args [0 ].items ():
47
- self ._outputs .append (self ._nodes .index (output ))
48
- self ._output_names .append (name )
49
- else :
50
- for output in args :
51
- self ._outputs .append (self ._nodes .index (output ))
42
+ uargs = self .unpack_containers (self ._nodes [i ].args )
43
+ self ._outputs = [(arg [0 ], self ._nodes .index (arg [1 ])) for arg in uargs if arg [1 ] is not None ]
52
44
53
45
elif issubclass (type (pt_module ), torch .fx .Node ):
54
46
55
47
self ._nodes = nodes # passed from outer context
56
48
57
49
# FIXME: Quadratic complexity nodes*nodes considering the outer loop over all nodes
58
- for i in range (len (self ._nodes )):
59
- if self ._nodes [i ] == pt_module :
60
- self ._outputs = [i ]
50
+ self ._outputs = [("" , self ._nodes .index (pt_module ))]
61
51
62
52
# None in inputs mean the input is inlined or None (also considered inlined)
63
53
self ._inputs = [self ._nodes .index (
64
54
arg ) if arg in self ._nodes else (arg ,) for arg in pt_module .args ]
65
55
66
56
# FIXME: Find a better way to pass nested tuples to OV frontend. This is a temporary solution to flatten arguments.
67
57
new_inputs = []
58
+ self .input_types = []
68
59
for i in range (len (pt_module .args )):
69
- if isinstance (pt_module .args [i ], list ) and any ([isinstance (a , torch .fx .Node ) for a in pt_module .args [i ]]):
60
+ if isinstance (pt_module .args [i ], ( list , tuple ) ) and any ([isinstance (a , torch .fx .Node ) for a in pt_module .args [i ]]):
70
61
for arg in pt_module .args [i ]:
71
62
if arg in self ._nodes :
72
63
new_inputs .append (self ._nodes .index (arg ))
73
64
else :
74
65
new_inputs .append ((arg ,))
66
+ self .input_types .append (OVAny (DecoderType .List (
67
+ TorchFXPythonDecoder .get_type_for_value (arg ))))
75
68
else :
76
69
new_inputs .append (self ._inputs [i ])
70
+ self .input_types .append (
71
+ TorchFXPythonDecoder .get_type_for_value (self ._inputs [i ]))
77
72
self ._inputs = new_inputs
78
73
79
74
def inputs (self ):
@@ -83,6 +78,24 @@ def inputs(self):
83
78
def is_input_inlined (self , index ):
84
79
return isinstance (self ._inputs [index ], tuple )
85
80
81
+ @staticmethod
82
+ def unpack_containers (arg ):
83
+ if isinstance (arg , (tuple , list )):
84
+ res = []
85
+ for e in arg :
86
+ res .extend (TorchFXPythonDecoder .unpack_containers (e ))
87
+ return res
88
+ elif isinstance (arg , dict ):
89
+ res = []
90
+ for k , e in arg .items ():
91
+ unpacked = TorchFXPythonDecoder .unpack_containers (e )
92
+ if len (unpacked ) == 1 :
93
+ unpacked [0 ] = (k , unpacked [0 ][1 ])
94
+ res .extend (unpacked )
95
+ return res
96
+ else :
97
+ return [("" , arg )]
98
+
86
99
@staticmethod
87
100
def arg_to_constant (arg ):
88
101
if isinstance (arg , list ):
@@ -91,7 +104,7 @@ def arg_to_constant(arg):
91
104
arg [0 ]).__name__ ], Shape ([len (arg )]), arg )
92
105
else :
93
106
# TODO: which type should we use if list is empty? Need a signaling value here
94
- return make_constant (int , Shape ([0 ]), [])
107
+ return make_constant (OVType . i32 , Shape ([0 ]), [])
95
108
elif isinstance (arg , bool ):
96
109
return make_constant (OVType .boolean , Shape ([]), [arg ])
97
110
elif isinstance (arg , int ):
@@ -103,10 +116,10 @@ def arg_to_constant(arg):
103
116
[]), [arg ]) # TODO: f32? why not f64?
104
117
return None
105
118
106
-
107
119
def inlined_input (self , index ):
108
120
assert index < len (self ._inputs ), "Requested input doesn't exist"
109
- assert isinstance (self ._inputs [index ], tuple ), "Requested input which is not inlined"
121
+ assert isinstance (
122
+ self ._inputs [index ], tuple ), "Requested input which is not inlined"
110
123
assert self ._inputs [index ][0 ] is not None , "Requested None inlined input"
111
124
constant = None
112
125
arg = self ._inputs [index ][0 ]
@@ -144,13 +157,13 @@ def get_input_strides(self, index: int) -> list:
144
157
145
158
def get_input_type (self , index ):
146
159
if index < len (self .input_types ):
147
- return OVAny ( pt_to_ov_type_map [ str ( self .input_types [index ])])
160
+ return self .input_types [index ]
148
161
input = self ._raw_input (index )
149
162
return self .get_type_for_value (input )
150
163
151
164
def get_output_debug_name (self , index ):
152
- if self ._output_names is not None and index < len (self ._output_names ) :
153
- return self ._output_names [index ]
165
+ if self ._outputs is not None and index < len (self ._outputs ) and self . _outputs [ index ][ 0 ] :
166
+ return self ._outputs [index ][ 0 ]
154
167
name = getattr (self .pt_module , "name" , "output" )
155
168
return name + ":" + str (index )
156
169
@@ -168,7 +181,8 @@ def get_shape_for_value(self, value):
168
181
return PartialShape (len (value .meta ['tensor_meta' ].shape ) * [- 1 ])
169
182
return PartialShape .dynamic ()
170
183
171
- def get_type_for_value (self , value ):
184
+ @staticmethod
185
+ def get_type_for_value (value ):
172
186
if issubclass (type (value ), torch .fx .Node ):
173
187
if ('tensor_meta' in value .meta .keys ()):
174
188
if value .meta ['tensor_meta' ] and isinstance (value .meta ['tensor_meta' ], torch .Tensor ):
@@ -259,10 +273,10 @@ def get_schema(self):
259
273
return self .pt_module .schema ()
260
274
261
275
def outputs (self ):
262
- return self ._outputs
276
+ return [ o [ 1 ] for o in self ._outputs ]
263
277
264
278
def _raw_outputs (self ):
265
- return [self ._nodes [x ] for x in self ._outputs ]
279
+ return [self ._nodes [x [ 1 ] ] for x in self ._outputs ]
266
280
267
281
def _raw_output (self , index ):
268
282
return self ._raw_outputs ()[index ]
@@ -293,7 +307,7 @@ def as_constant(self):
293
307
if self .pt_module .op == 'get_attr' :
294
308
# Extract Constant from FX module field
295
309
ret = fetch_attr (self .fx_gm , self .pt_module .target )
296
- ov_const = op . Constant (ret . numpy ( force = True ) , shared_memory = True )
310
+ ov_const = torch_tensor_to_ov_const (ret , shared_memory = True )
297
311
return ov_const .outputs ()
298
312
299
313
if not self .get_op_type () == 'prim::Constant' :
0 commit comments