21
21
import os
22
22
from collections import defaultdict
23
23
from scipy .stats import ttest_ind
24
+ import warnings
25
+ import statistics
24
26
25
27
26
28
def compare_two_benchdnn (file1 , file2 , tolerance = 0.05 ):
@@ -38,9 +40,9 @@ def compare_two_benchdnn(file1, file2, tolerance=0.05):
38
40
r2 = [x .split ("," ) for x in r2 if x [0 :8 ] == "--mode=P" ]
39
41
40
42
if (len (r1 ) == 0 ) or (len (r2 ) == 0 ):
41
- raise Exception ("One or both of the test results have zero lines" )
43
+ warnings . warn ("One or both of the test results have zero lines" )
42
44
if len (r1 ) != len (r2 ):
43
- raise Exception ("The number of benchdnn runs do not match" )
45
+ warnings . warn ("The number of benchdnn runs do not match" )
44
46
45
47
r1_samples = defaultdict (list )
46
48
r2_samples = defaultdict (list )
@@ -50,26 +52,33 @@ def compare_two_benchdnn(file1, file2, tolerance=0.05):
50
52
for k , v in r2 :
51
53
r2_samples [k ].append (float (v [:- 1 ]))
52
54
53
- passed = True
54
55
failed_tests = []
56
+ times = {}
55
57
for prb , r1_times in r1_samples .items ():
56
58
if prb not in r2_samples :
57
- raise Exception (f"{ prb } exists in { file1 } but not { file2 } " )
59
+ warnings .warn (f"{ prb } exists in { file1 } but not { file2 } " )
60
+ continue
61
+
58
62
r2_times = r2_samples [prb ]
59
63
60
64
res = ttest_ind (r2_times , r1_times , alternative = 'greater' )
61
-
62
- if res .pvalue < 0.05 :
63
- failed_tests .append (prb )
65
+ r1_med = statistics .median (r1_times )
66
+ r2_med = statistics .median (r2_times )
67
+ times [prb ] = (r1_med , r2_med )
68
+ times_str = f" { times [prb ][0 ]} vs { times [prb ][1 ]} "
69
+
70
+ passed = res .pvalue > 0.05 or \
71
+ ((r2_med - r1_med ) / r1_med < 0.1 and \
72
+ (min (r2_times ) - min (r1_times )) / min (r1_times ) < 0.1 )
73
+ if not passed :
74
+ failed_tests .append (prb + times_str )
64
75
passed = False
65
76
66
- print (prb + (" passed" if passed else " failed" ))
67
-
68
77
if "GITHUB_OUTPUT" in os .environ :
69
78
with open (os .environ ["GITHUB_OUTPUT" ], "a" ) as f :
70
- print (f"pass={ passed } " , file = f )
79
+ print (f"pass={ not failed_tests } " , file = f )
71
80
72
- if passed :
81
+ if not failed_tests :
73
82
print ("Regression tests passed" )
74
83
else :
75
84
message = "\n ----The following regression tests failed:----\n " + \
0 commit comments