Skip to content

Commit

Permalink
Get original_traced_args as example_inputs. (#5511)
Browse files Browse the repository at this point in the history
Change due to changing name in pytorch/pytorch#107978
  • Loading branch information
qihqi authored Aug 28, 2023
1 parent 808015d commit 33f1cd2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
11 changes: 11 additions & 0 deletions test/stablehlo/test_stablehlo_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ def test_save_load2(self):
result = program2(*inputs).detach().cpu()
self.assertTrue(torch.allclose(model(*inputs), result))

def test_save_load3(self):
model = ElementwiseAdd()
inputs = model.get_random_inputs()
exported = torch._export.export(model, inputs)
with tempfile.TemporaryDirectory() as tempdir:
# Shouldnt need specify options because exported has example_input inside
save_as_stablehlo(exported, tempdir)
program2 = StableHLOGraphModule.load(tempdir)
result = program2(*inputs).detach().cpu()
self.assertTrue(torch.allclose(model(*inputs), result))


if __name__ == '__main__':
test = unittest.main()
Expand Down
8 changes: 7 additions & 1 deletion torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ def _exported_program_to_stablehlo_bundle(exported_model,
if options.override_tracing_arguments is not None:
args = options.override_tracing_arguments
else:
args = getattr(exported_model, 'original_traced_arguments', None)
if hasattr(exported_model, 'example_inputs'):
args, _ = getattr(exported_model, 'example_inputs', None)
elif hasattr(exported_model, 'original_traced_arguments'):
args = getattr(exported_model, 'original_traced_arguments', None)

if args is None:
raise ValueError(
'No argument is provided, please set tracing argument in options.override_tracing_arguments'
Expand Down Expand Up @@ -463,6 +467,8 @@ def exported_program_to_stablehlo(
so it might specialize on the shapes of the sample input.
"""
if options is None:
options = StableHLOExportOptions()
bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
return StableHLOGraphModule(bundle)

Expand Down

0 comments on commit 33f1cd2

Please sign in to comment.