@@ -198,37 +198,45 @@ def find_named_entry(name, entries):
198
198
return entry_args
199
199
return None
200
200
201
- def is_ambiguous (r , c ):
201
+ def accept_results (r , c ):
202
+ if r == c :
203
+ return True
204
+
202
205
# TODO: Handle cases with non-unique md tags
203
206
# * multiple size-1 dimensions with the same stride
204
207
# * multiple dimensions with 0 stride
205
- if driver != "matmul" :
206
- return False
207
- # XXX: In matmul cases with runtime dims that resolve to ones, the bias
208
- # memory descriptor will potentially have the wrong mask printed in the
209
- # verbose line. We do not maintain enough information to always print
210
- # the correct mask, but the reference and computed verbose lines will
211
- # match, up to implementation name.
212
- parts = r .split ("," )
213
- mds = parts [8 ].split ()
214
- aux = parts [10 ].split ()
215
- shapes = parts [11 ].split (":" , 1 )
216
- wei , act = list (map (lambda x : list (map (int , x .split ("x" ))), shapes ))
217
- if find_named_entry ("bia" , mds ) is None :
218
- return False
219
- rt_dim_mask = find_named_entry ("runtime_dims_masks" , aux )
220
- if rt_dim_mask is None :
221
- return False
222
- wei_mask , act_mask = list (map (int , rt_dim_mask ))
223
- if wei [- 2 ] == 1 and wei_mask & (1 << (len (wei ) - 2 )):
224
- return without_impl (r ) == without_impl (c )
225
- if act [- 1 ] == 1 and act_mask & (1 << (len (act ) - 1 )):
208
+ if driver == "matmul" :
209
+ # In matmul cases with runtime dims that resolve to ones, the bias
210
+ # memory descriptor will potentially have the wrong mask printed in
211
+ # the verbose line. We do not maintain enough information to always
212
+ # print the correct mask, but the reference and computed verbose
213
+ # lines will match, up to implementation name.
214
+ parts = r .split ("," )
215
+ mds = parts [8 ].split ()
216
+ aux = parts [10 ].split ()
217
+ shapes = parts [11 ].split (":" , 1 )
218
+ wei , act = list (map (lambda x : list (map (int , x .split ("x" ))), shapes ))
219
+ if find_named_entry ("bia" , mds ) is None :
220
+ return False
221
+ rt_dim_mask = find_named_entry ("runtime_dims_masks" , aux )
222
+ if rt_dim_mask is None :
223
+ return False
224
+ wei_mask , act_mask = list (map (int , rt_dim_mask ))
225
+ if wei [- 2 ] == 1 and wei_mask & (1 << (len (wei ) - 2 )):
226
+ return without_impl (r ) == without_impl (c )
227
+ if act [- 1 ] == 1 and act_mask & (1 << (len (act ) - 1 )):
228
+ return without_impl (r ) == without_impl (c )
229
+ elif driver == "sum" :
230
+ # There is no information in a sum verbose line about scales, so if
231
+ # dispatch depends on particular scale values, the implementation
232
+ # may change with default scales. In this case, we check that the
233
+ # rest of the verbose line is the same.
226
234
return without_impl (r ) == without_impl (c )
227
235
return False
228
236
229
237
file_map = {"reference" : ref_v , "computed" : comp_v }
230
238
for r , c in zip (filter_lines (ref_v ), filter_lines (comp_v )):
231
- if r == c or is_ambiguous (r , c ):
239
+ if accept_results (r , c ):
232
240
continue
233
241
for log_type , content in file_map .items ():
234
242
with open (f"{ driver } .{ log_type } .log" , "w" ) as fd :
@@ -251,6 +259,11 @@ def test(path_to_benchdnn, engine, driver, batch):
251
259
com_batch = generate_batch (ref_verbose , driver )
252
260
com_verbose = generate_verbose (path_to_benchdnn , engine , driver , com_batch )
253
261
compare (driver , ref_verbose , com_verbose )
262
+ # XXX: Maybe run an additional loop
263
+ # ref -> ref verbose -> com 1 -> com 1 verbose -> com 2 -> com 2 verbose
264
+ # Comparing com 1 and com 2 verbose instead would address the special cases
265
+ # in accept_results. We can even compare just the cases where ref and com 1
266
+ # don't match.
254
267
255
268
256
269
def main ():
0 commit comments