Skip to content

Commit 825cf55

Browse files
committed Mar 17, 2025
benchdnn: inputs: graph: add sdpa cases with f32 intermediate type
1 parent 2bfce18 commit 825cf55

File tree

3 files changed

+356
-0
lines changed

3 files changed

+356
-0
lines changed
 

‎tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
1616
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
1717
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
18+
# f16 inputs + f32 intermediates + f16 outputs
19+
--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
20+
# bf16 inputs + f32 intermediates + bf16 outputs
21+
--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
1822

1923
# int8 graphs
2024
--reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json

‎tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
1414
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
1515
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
16+
# f16 inputs + f32 intermediates + f16 outputs
17+
--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
18+
# bf16 inputs + f32 intermediates + bf16 outputs
19+
--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
20+
1621

1722
# int8 graphs
1823
--reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json
tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-simplified-f16-f32.json

+347
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
{
2+
"version": "3.8.0",
3+
"engine_kind": "cpu",
4+
"fpmath_mode": "strict",
5+
"fpmath_mode_apply_to_int": "false",
6+
"input_ports": [
7+
1,
8+
2,
9+
4,
10+
5,
11+
3
12+
],
13+
"output_ports": [
14+
6
15+
],
16+
"graph": [
17+
{
18+
"id": 0,
19+
"name": "matmul_qk",
20+
"kind": "MatMul",
21+
"attrs": {
22+
"transpose_a": {
23+
"type": "bool",
24+
"value": 0
25+
},
26+
"transpose_b": {
27+
"type": "bool",
28+
"value": 1
29+
}
30+
},
31+
"inputs": [
32+
{
33+
"id": 1,
34+
"dtype": "f16",
35+
"shape": [
36+
1,
37+
16,
38+
384,
39+
64
40+
],
41+
"stride": [
42+
393216,
43+
24576,
44+
64,
45+
1
46+
],
47+
"layout_type": "strided",
48+
"property_type": "undef"
49+
},
50+
{
51+
"id": 2,
52+
"dtype": "f16",
53+
"shape": [
54+
1,
55+
16,
56+
384,
57+
64
58+
],
59+
"stride": [
60+
393216,
61+
24576,
62+
64,
63+
1
64+
],
65+
"layout_type": "strided",
66+
"property_type": "undef"
67+
}
68+
],
69+
"outputs": [
70+
{
71+
"id": 101,
72+
"dtype": "f32",
73+
"shape": [
74+
1,
75+
16,
76+
384,
77+
384
78+
],
79+
"stride": [
80+
2359296,
81+
147456,
82+
384,
83+
1
84+
],
85+
"layout_type": "strided",
86+
"property_type": "undef"
87+
}
88+
]
89+
},
90+
{
91+
"id": 1,
92+
"name": "scale_div",
93+
"kind": "Divide",
94+
"attrs": {
95+
"auto_broadcast": {
96+
"type": "string",
97+
"value": "numpy"
98+
}
99+
},
100+
"inputs": [
101+
{
102+
"id": 101,
103+
"dtype": "f32",
104+
"shape": [
105+
1,
106+
16,
107+
384,
108+
384
109+
],
110+
"stride": [
111+
2359296,
112+
147456,
113+
384,
114+
1
115+
],
116+
"layout_type": "strided",
117+
"property_type": "undef"
118+
},
119+
{
120+
"id": 4,
121+
"dtype": "f16",
122+
"shape": [
123+
1
124+
],
125+
"stride": [
126+
1
127+
],
128+
"layout_type": "strided",
129+
"property_type": "constant"
130+
}
131+
],
132+
"outputs": [
133+
{
134+
"id": 102,
135+
"dtype": "f32",
136+
"shape": [
137+
1,
138+
16,
139+
384,
140+
384
141+
],
142+
"stride": [
143+
2359296,
144+
147456,
145+
384,
146+
1
147+
],
148+
"layout_type": "strided",
149+
"property_type": "undef"
150+
}
151+
]
152+
},
153+
{
154+
"id": 2,
155+
"name": "mask_add",
156+
"kind": "Add",
157+
"attrs": {
158+
"auto_broadcast": {
159+
"type": "string",
160+
"value": "numpy"
161+
}
162+
},
163+
"inputs": [
164+
{
165+
"id": 102,
166+
"dtype": "f32",
167+
"shape": [
168+
1,
169+
16,
170+
384,
171+
384
172+
],
173+
"stride": [
174+
2359296,
175+
147456,
176+
384,
177+
1
178+
],
179+
"layout_type": "strided",
180+
"property_type": "undef"
181+
},
182+
{
183+
"id": 5,
184+
"dtype": "f16",
185+
"shape": [
186+
1,
187+
1,
188+
384,
189+
384
190+
],
191+
"stride": [
192+
147456,
193+
147456,
194+
384,
195+
1
196+
],
197+
"layout_type": "strided",
198+
"property_type": "undef"
199+
}
200+
],
201+
"outputs": [
202+
{
203+
"id": 103,
204+
"dtype": "f32",
205+
"shape": [
206+
1,
207+
16,
208+
384,
209+
384
210+
],
211+
"stride": [
212+
2359296,
213+
147456,
214+
384,
215+
1
216+
],
217+
"layout_type": "strided",
218+
"property_type": "undef"
219+
}
220+
]
221+
},
222+
{
223+
"id": 3,
224+
"name": "softmax",
225+
"kind": "SoftMax",
226+
"attrs": {
227+
"axis": {
228+
"type": "s64",
229+
"value": -1
230+
}
231+
},
232+
"inputs": [
233+
{
234+
"id": 103,
235+
"dtype": "f32",
236+
"shape": [
237+
1,
238+
16,
239+
384,
240+
384
241+
],
242+
"stride": [
243+
2359296,
244+
147456,
245+
384,
246+
1
247+
],
248+
"layout_type": "strided",
249+
"property_type": "undef"
250+
}
251+
],
252+
"outputs": [
253+
{
254+
"id": 104,
255+
"dtype": "f16",
256+
"shape": [
257+
1,
258+
16,
259+
384,
260+
384
261+
],
262+
"stride": [
263+
2359296,
264+
147456,
265+
384,
266+
1
267+
],
268+
"layout_type": "strided",
269+
"property_type": "undef"
270+
}
271+
]
272+
},
273+
{
274+
"id": 4,
275+
"name": "matmul_v",
276+
"kind": "MatMul",
277+
"attrs": {
278+
"transpose_a": {
279+
"type": "bool",
280+
"value": 0
281+
},
282+
"transpose_b": {
283+
"type": "bool",
284+
"value": 0
285+
}
286+
},
287+
"inputs": [
288+
{
289+
"id": 104,
290+
"dtype": "f16",
291+
"shape": [
292+
1,
293+
16,
294+
384,
295+
384
296+
],
297+
"stride": [
298+
2359296,
299+
147456,
300+
384,
301+
1
302+
],
303+
"layout_type": "strided",
304+
"property_type": "undef"
305+
},
306+
{
307+
"id": 3,
308+
"dtype": "f16",
309+
"shape": [
310+
1,
311+
16,
312+
384,
313+
64
314+
],
315+
"stride": [
316+
393216,
317+
24576,
318+
64,
319+
1
320+
],
321+
"layout_type": "strided",
322+
"property_type": "undef"
323+
}
324+
],
325+
"outputs": [
326+
{
327+
"id": 6,
328+
"dtype": "f16",
329+
"shape": [
330+
1,
331+
16,
332+
384,
333+
64
334+
],
335+
"stride": [
336+
393216,
337+
24576,
338+
64,
339+
1
340+
],
341+
"layout_type": "strided",
342+
"property_type": "undef"
343+
}
344+
]
345+
}
346+
]
347+
}

0 commit comments

Comments
 (0)