6
6
7
7
from common .tf2_layer_test_class import CommonTF2LayerTest
8
8
9
+ def fn_1 (x ):
10
+ return (x [0 ] * x [1 ] + x [2 ])
11
+
12
+ def fn_2 (x ):
13
+ return (x [0 ] + x [1 ] + x [2 ], x [0 ] - x [2 ] + x [1 ], 2 + x [2 ])
14
+
15
+ def fn_3 (x ):
16
+ return (x [0 ] * x [1 ])
17
+
18
+ def fn_4 (x ):
19
+ return (x [0 ] * x [1 ] + 2 * x [2 ])
20
+
21
+ def fn_5 (x ):
22
+ return (x [0 ] * x [1 ], x [0 ] + x [1 ])
23
+
24
+ def fn_6 (x ):
25
+ return (x [0 ] * x [1 ] + x [2 ], x [0 ] + x [2 ] * x [1 ], 2 * x [2 ])
26
+
27
+ def fn_7 (x ):
28
+ return (x [0 ] * x [1 ] + x [2 ])
29
+
30
+ def fn_8 (x ):
31
+ return (x [0 ] + x [1 ] + x [2 ], x [0 ] - x [2 ] + x [1 ], 2 + x [2 ])
32
+
33
+ list_fns = [fn_1 , fn_2 , fn_3 , fn_4 , fn_5 , fn_6 , fn_7 , fn_8 ]
9
34
10
35
class MapFNLayer (tf .keras .layers .Layer ):
11
36
def __init__ (self , fn , input_type , fn_output_signature , back_prop ):
12
37
super (MapFNLayer , self ).__init__ ()
13
- self .fn = fn
38
+ self .fn = list_fns [ fn - 1 ]
14
39
self .input_type = input_type
15
40
self .fn_output_signature = fn_output_signature
16
41
self .back_prop = back_prop
@@ -20,7 +45,6 @@ def call(self, x):
20
45
fn_output_signature = self .fn_output_signature ,
21
46
back_prop = self .back_prop )
22
47
23
-
24
48
class TestMapFN (CommonTF2LayerTest ):
25
49
def create_map_fn_net (self , fn , input_type , fn_output_signature , back_prop ,
26
50
input_names , input_shapes , ir_version ):
@@ -39,10 +63,10 @@ def create_map_fn_net(self, fn, input_type, fn_output_signature, back_prop,
39
63
return tf2_net , ref_net
40
64
41
65
test_basic = [
42
- dict (fn = lambda x : x [ 0 ] * x [ 1 ] + x [ 2 ] , input_type = tf .float32 ,
66
+ dict (fn = 1 , input_type = tf .float32 ,
43
67
fn_output_signature = tf .float32 , back_prop = False ,
44
68
input_names = ["x1" , "x2" , "x3" ], input_shapes = [[2 , 3 , 4 ], [2 , 3 , 4 ], [2 , 3 , 4 ]]),
45
- pytest .param (dict (fn = lambda x : ( x [ 0 ] + x [ 1 ] + x [ 2 ], x [ 0 ] - x [ 2 ] + x [ 1 ], 2 + x [ 2 ]) ,
69
+ pytest .param (dict (fn = 2 ,
46
70
input_type = tf .float32 ,
47
71
fn_output_signature = (tf .float32 , tf .float32 , tf .float32 ), back_prop = True ,
48
72
input_names = ["x1" , "x2" , "x3" ],
@@ -59,10 +83,10 @@ def test_basic(self, params, ie_device, precision, ir_version, temp_dir, use_leg
59
83
** params )
60
84
61
85
test_multiple_inputs = [
62
- dict (fn = lambda x : x [ 0 ] * x [ 1 ] , input_type = tf .float32 ,
86
+ dict (fn = 3 , input_type = tf .float32 ,
63
87
fn_output_signature = tf .float32 , back_prop = True ,
64
88
input_names = ["x1" , "x2" ], input_shapes = [[2 , 4 ], [2 , 4 ]]),
65
- dict (fn = lambda x : x [ 0 ] * x [ 1 ] + 2 * x [ 2 ] , input_type = tf .float32 ,
89
+ dict (fn = 4 , input_type = tf .float32 ,
66
90
fn_output_signature = tf .float32 , back_prop = False ,
67
91
input_names = ["x1" , "x2" , "x3" ], input_shapes = [[2 , 1 , 3 , 4 ],
68
92
[2 , 1 , 3 , 4 ],
@@ -77,11 +101,11 @@ def test_multiple_inputs(self, params, ie_device, precision, ir_version, temp_di
77
101
** params )
78
102
79
103
test_multiple_outputs = [
80
- pytest .param (dict (fn = lambda x : ( x [ 0 ] * x [ 1 ], x [ 0 ] + x [ 1 ]) , input_type = tf .float32 ,
104
+ pytest .param (dict (fn = 5 , input_type = tf .float32 ,
81
105
fn_output_signature = (tf .float32 , tf .float32 ), back_prop = True ,
82
106
input_names = ["x1" , "x2" ], input_shapes = [[2 , 4 ], [2 , 4 ]]),
83
107
marks = pytest .mark .xfail (reason = "61587" )),
84
- pytest .param (dict (fn = lambda x : ( x [ 0 ] * x [ 1 ] + x [ 2 ], x [ 0 ] + x [ 2 ] * x [ 1 ], 2 * x [ 2 ]) ,
108
+ pytest .param (dict (fn = 6 ,
85
109
input_type = tf .float32 ,
86
110
fn_output_signature = (tf .float32 , tf .float32 , tf .float32 ), back_prop = True ,
87
111
input_names = ["x1" , "x2" , "x3" ],
@@ -97,12 +121,12 @@ def test_multiple_outputs(self, params, ie_device, precision, ir_version, temp_d
97
121
** params )
98
122
99
123
test_multiple_inputs_outputs_int32 = [
100
- dict (fn = lambda x : x [ 0 ] * x [ 1 ] + x [ 2 ] ,
124
+ dict (fn = 7 ,
101
125
input_type = tf .int32 ,
102
126
fn_output_signature = tf .int32 , back_prop = True ,
103
127
input_names = ["x1" , "x2" , "x3" ],
104
128
input_shapes = [[2 , 1 , 3 ], [2 , 1 , 3 ], [2 , 1 , 3 ]]),
105
- pytest .param (dict (fn = lambda x : ( x [ 0 ] + x [ 1 ] + x [ 2 ], x [ 0 ] - x [ 2 ] + x [ 1 ], 2 + x [ 2 ]) ,
129
+ pytest .param (dict (fn = 8 ,
106
130
input_type = tf .int32 ,
107
131
fn_output_signature = (tf .int32 , tf .int32 , tf .int32 ), back_prop = True ,
108
132
input_names = ["x1" , "x2" , "x3" ],
0 commit comments