@@ -1016,6 +1016,28 @@ def forward(self, x):
1016
1016
"compress_to_fp16" : False }
1017
1017
1018
1018
1019
+ def create_pytorch_module_with_nested_dict_input (tmp_dir ):
1020
+ class PTModel (torch .nn .Module ):
1021
+ def forward (self , a , b ):
1022
+ return a ["1" ] * a ["2" ] + b
1023
+
1024
+ net = PTModel ()
1025
+ a1 = ov .opset10 .parameter (PartialShape ([- 1 ]), dtype = np .float32 )
1026
+ a2 = ov .opset10 .parameter (PartialShape ([- 1 ]), dtype = np .float32 )
1027
+ b = ov .opset10 .parameter (PartialShape ([- 1 ]), dtype = np .float32 )
1028
+ mul = ov .opset10 .multiply (a1 , a2 )
1029
+ add = ov .opset10 .add (mul , b )
1030
+ ref_model = Model ([add ], [a1 , a2 , b ], "test" )
1031
+ return net , ref_model , {
1032
+ "example_input" : (
1033
+ {
1034
+ "1" : torch .tensor ([1 , 2 ], dtype = torch .float32 ),
1035
+ "2" : torch .tensor ([3 , 4 ], dtype = torch .float32 )
1036
+ },
1037
+ torch .tensor ([5 , 6 ], dtype = torch .float32 )
1038
+ )}
1039
+
1040
+
1019
1041
class TestMoConvertPyTorch (CommonMOConvertTest ):
1020
1042
test_data = [
1021
1043
create_pytorch_nn_module_case1 ,
@@ -1067,7 +1089,8 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
1067
1089
create_pytorch_module_with_nested_inputs5 ,
1068
1090
create_pytorch_module_with_nested_inputs6 ,
1069
1091
create_pytorch_module_with_nested_list_and_single_input ,
1070
- create_pytorch_module_with_single_input_as_list
1092
+ create_pytorch_module_with_single_input_as_list ,
1093
+ create_pytorch_module_with_nested_dict_input
1071
1094
]
1072
1095
1073
1096
@pytest .mark .parametrize ("create_model" , test_data )
0 commit comments