@@ -89,37 +89,38 @@ def get_pytorch_decoder_for_model_on_disk(argv, args):
89
89
else :
90
90
input_model = argv .input_model
91
91
92
- if isinstance (input_model , (str , pathlib .Path )):
93
- # attempt to load scripted model
94
- try :
95
- inputs = prepare_torch_inputs (example_inputs )
96
- model = torch .jit .load (input_model )
97
- model .eval ()
98
- decoder = TorchScriptPythonDecoder (
99
- model ,
100
- example_input = inputs ,
101
- shared_memory = args .get ("share_weights" , True ),
102
- module_extensions = extract_module_extensions (args ))
92
+ if not isinstance (input_model , (str , pathlib .Path )):
93
+ return False
94
+
95
+ # attempt to load scripted model
96
+ try :
97
+ inputs = prepare_torch_inputs (example_inputs )
98
+ model = torch .jit .load (input_model )
99
+ model .eval ()
100
+ decoder = TorchScriptPythonDecoder (
101
+ model ,
102
+ example_input = inputs ,
103
+ shared_memory = args .get ("share_weights" , True ),
104
+ module_extensions = extract_module_extensions (args ))
105
+ argv .input_model = decoder
106
+ argv .framework = 'pytorch'
107
+ return True
108
+ except :
109
+ pass
110
+ # attempt to load exported model
111
+ try :
112
+ exported_program = torch .export .load (input_model )
113
+ if hasattr (torch , "export" ) and isinstance (exported_program , (torch .export .ExportedProgram )):
114
+ from packaging import version
115
+ if version .parse (torch .__version__ ) >= version .parse ("2.2" ):
116
+ exported_program = exported_program .run_decompositions ()
117
+ gm = exported_program .module ()
118
+ decoder = TorchFXPythonDecoder (gm , dynamic_shapes = True )
103
119
argv .input_model = decoder
104
120
argv .framework = 'pytorch'
105
121
return True
106
- except :
107
- pass
108
- if isinstance (input_model , (str , pathlib .Path )):
109
- # attempt to load exported model
110
- try :
111
- exported_program = torch .export .load (input_model )
112
- if hasattr (torch , "export" ) and isinstance (exported_program , (torch .export .ExportedProgram )):
113
- from packaging import version
114
- if version .parse (torch .__version__ ) >= version .parse ("2.2" ):
115
- exported_program = exported_program .run_decompositions ()
116
- gm = exported_program .module ()
117
- decoder = TorchFXPythonDecoder (gm , dynamic_shapes = True )
118
- argv .input_model = decoder
119
- argv .framework = 'pytorch'
120
- return True
121
- except :
122
- pass
122
+ except :
123
+ pass
123
124
return False
124
125
125
126
0 commit comments