-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathharness_mha_all
118 lines (108 loc) · 11.3 KB
/
harness_mha_all
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# floating-point graphs
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/JAX-MHA-inf-fp32.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/JAX-MQA-inf-fp32.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-LLaMa-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-starcoder-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA_backward-Bert_large-train-fp32-bs4.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA_forward-Bert_large-train-fp32-bs4.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
# int8 graphs
--reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json
--reset --expected-n-partitions=0 --case=complex_fusion/mha/MHA-LLaMa-inf-int8-bs1.json
--reset --case=complex_fusion/mha/MHA-bert_large-inf-int8-bs1.json
--reset --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --expected-n-partitions=0 --case=complex_fusion/mha/MHA-starcoder-inf-int8-bs1.json
--reset --expected-n-partitions=0 --case=complex_fusion/mha/dynamic_quantized_mha-Bert_large-inf-int8-bs1-fake.json
--reset --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json
--reset --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-v-int8-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json
# Re-written graphs: cases listed above, re-run with --in-shapes, --mb and/or --op-attrs overrides
--reset --dt=f32,bf16,f16 --in-shapes=4:4x16x32x256+5:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=3:4x32x32x128+4:4x32x128x33+0:4x32x33x128+1:4x1x32x33 --case=complex_fusion/mha/MHA-LLaMa-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --mb=20 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:10x16x384x64+4:10x1x64x384+0:10x1x384x64+1:10x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=4:56x12x128x64+5:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=0:56x8x1024x80+1:56x8x77x80+2:56x8x77x80 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=5:20x117x48x128+6:20x1x128x117+19:20x1x117x128 --case=complex_fusion/mha/MHA-starcoder-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=2514:32x16x512x64+2518:32x16x512x64+2543:32x1x512x512+2547:32x16x512x512+2525:32x16x512x64 --op-attrs=4837:shape:16384x1024 --case=complex_fusion/mha/MHA_forward-Bert_large-train-fp32-bs4.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=2514:32x16x512x64+2518:32x16x512x64+2591:32x16x512x512+2545:32x16x512x512+2547:32x16x512x512+2525:32x16x512x64+2548:32x16x512x512+5178:16384x1024 --op-attrs=7392:shape:32x512x16x64 --case=complex_fusion/mha/MHA_backward-Bert_large-train-fp32-bs4.json
--reset --dt=f32,bf16,f16 --in-shapes=0:20x16x384x64+1:20x16x384x64+8:20x16x384x64+5:20x1x1x384 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=5:1x1x384x384,5:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:2x16x384x64+1:2x16x384x64+5:2x1x1x384+8:2x16x384x64 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:32x16x128x64+1:32x16x128x64+5:32x16x128x128+8:32x16x128x64 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:acbd+1:acbd+8:acbd --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=3:384,3:384x384,3:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=5:384,5:1x384 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --op-attrs=34107656704:group_shape:1x1x1x32+34107654464:transpose_b:1 --in-shapes=0:1x32x32x128+1:1x32x32x4+2:1x32x32x4 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --op-attrs=34107656704:qtype:per_channel*axis:3 --in-shapes=1:32+2:1 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --op-attrs=34107656704:group_shape:1x1x128x1 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --dt=f32,bf16,f16 --op-attrs=2:axis:-2+3:axis:-1 --in-shapes=0:1x32x128x64+1:1x32x128x64+11:1x32x128x64 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
# Re-written int8 graphs: int8 cases listed above, re-run with --in-shapes and/or --op-attrs overrides
--reset --in-shapes=5:4x16x32x256+4:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json
--reset --expected-n-partitions=0 --in-shapes=4:4x32x32x128+3:4x32x128x33+0:4x32x33x128+1:4x1x32x33 --case=complex_fusion/mha/MHA-LLaMa-inf-int8-bs1.json
--reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-int8-bs1.json
--reset --in-shapes=5:56x12x128x64+4:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --expected-n-partitions=0 --in-shapes=4:20x117x48x128+3:20x1x128x117+0:20x1x117x128 --case=complex_fusion/mha/MHA-starcoder-inf-int8-bs1.json
--reset --expected-n-partitions=0 --in-shapes=4:32x16x384x64+3:32x16x64x384+0:32x16x384x64+1:32x1x1x384 --case=complex_fusion/mha/dynamic_quantized_mha-Bert_large-inf-int8-bs1-fake.json
--reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json
--reset --in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1 --op-attrs=0:group_shape:1x1x96x1+8:group_shape:1x1x1x96 --case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json
# phi3-mini-4k-instruct
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+5:1x1x384x384+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1,\
0:1x32x96x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x96+5:1x1x385x385+6:1x32x385x96+7:1x32x385x1+8:1x32x385x1,\
0:1x32x96x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x96+5:1x1x512x512+6:1x32x512x96+7:1x32x512x1+8:1x32x512x1,\
0:1x32x96x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x96+5:1x1x513x513+6:1x32x513x96+7:1x32x513x1+8:1x32x513x1,\
0:1x32x96x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x96+5:1x1x1024x1024+6:1x32x1024x96+7:1x32x1024x1+8:1x32x1024x1,\
0:1x32x96x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x96+5:1x1x1025x1025+6:1x32x1025x96+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x96x1+99:group_shape:1x1x1x96
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1,\
0:1x32x96x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x96+6:1x32x385x96+7:1x32x385x1+8:1x32x385x1,\
0:1x32x96x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x96+6:1x32x512x96+7:1x32x512x1+8:1x32x512x1,\
0:1x32x96x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x96+6:1x32x513x96+7:1x32x513x1+8:1x32x513x1,\
0:1x32x96x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x96+6:1x32x1024x96+7:1x32x1024x1+8:1x32x1024x1,\
0:1x32x96x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x96+6:1x32x1025x96+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x96x1+8:group_shape:1x1x1x96
--case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json
# llama-2-7b-chat
# Each case starts with --reset and restates --dt, so it does not depend on
# settings left over from whichever case happens to precede it in this file.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x128x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x128+5:1x1x384x384+6:1x32x384x128+7:1x32x384x1+8:1x32x384x1,\
0:1x32x128x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x128+5:1x1x385x385+6:1x32x385x128+7:1x32x385x1+8:1x32x385x1,\
0:1x32x128x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x128+5:1x1x512x512+6:1x32x512x128+7:1x32x512x1+8:1x32x512x1,\
0:1x32x128x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x128+5:1x1x513x513+6:1x32x513x128+7:1x32x513x1+8:1x32x513x1,\
0:1x32x128x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x128+5:1x1x1024x1024+6:1x32x1024x128+7:1x32x1024x1+8:1x32x1024x1,\
0:1x32x128x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x128+5:1x1x1025x1025+6:1x32x1025x128+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x128x1+99:group_shape:1x1x1x128
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x128x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x128+6:1x32x384x128+7:1x32x384x1+8:1x32x384x1,\
0:1x32x128x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x128+6:1x32x385x128+7:1x32x385x1+8:1x32x385x1,\
0:1x32x128x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x128+6:1x32x512x128+7:1x32x512x1+8:1x32x512x1,\
0:1x32x128x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x128+6:1x32x513x128+7:1x32x513x1+8:1x32x513x1,\
0:1x32x128x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x128+6:1x32x1024x128+7:1x32x1024x1+8:1x32x1024x1,\
0:1x32x128x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x128+6:1x32x1025x128+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x128x1+8:group_shape:1x1x1x128
--case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json
# IDs used by --dt below: 0 = key, 2 = key zero points (zps), 6 = value, 8 = value zero points. Override their data type to s8.
--reset --dt=0:s8+2:s8+6:s8+8:s8 --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
# Change group size to 128. It also affects the shapes of the scales and zps.
--reset --dt=0:s8+2:s8+6:s8+8:s8 --op-attrs=0:group_shape:1x1x128x1+99:group_shape:1x1x1x128 --in-shapes=1:1x32x1x32+2:1x32x1x32+7:1x32x32x1+8:1x32x32x1 --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
# d_qk != d_v
--reset --dt=f32,bf16,f16 --in-shapes=8:1x16x384x32,8:1x16x384x64,8:1x16x384x128 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=11:1x16x384x32,11:1x16x384x64,11:1x16x384x128 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
# regression tests
--reset --dt=f32,bf16,f16 --in-shapes=0:1x8x4096x16+1:1x8x5x16+5:1x1x4096x5+8:1x8x5x16 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json