Skip to content

Commit 18f12ad

Browse files
Use SPIRV integer vector dot product instructions (#6141)
* Use SPIRV integer vector dot product instructions * fix test --------- Co-authored-by: Yong He <yonghe@outlook.com>
1 parent 14211ec commit 18f12ad

File tree

3 files changed

+62
-16
lines changed

3 files changed

+62
-16
lines changed

source/slang/hlsl.meta.slang

+24
Original file line numberDiff line numberDiff line change
@@ -8292,6 +8292,30 @@ T dot(vector<T, N> x, vector<T, N> y)
82928292
{
82938293
case hlsl: __intrinsic_asm "dot";
82948294
case wgsl: __intrinsic_asm "dot";
8295+
case spirv:
8296+
{
8297+
spirv_asm
8298+
{
8299+
OpCapability DotProduct;
8300+
OpCapability DotProductInputAll;
8301+
OpExtension "SPV_KHR_integer_dot_product";
8302+
};
8303+
8304+
if (__isSignedInt<T>())
8305+
{
8306+
return spirv_asm
8307+
{
8308+
result:$$T = OpSDot $x $y;
8309+
};
8310+
}
8311+
else
8312+
{
8313+
return spirv_asm
8314+
{
8315+
result:$$T = OpUDot $x $y;
8316+
};
8317+
}
8318+
}
82958319
default:
82968320
T result = T(0);
82978321
for(int i = 0; i < N; ++i)
+33-12
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,44 @@
1-
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj
2-
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
3-
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -shaderobj
4-
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
5-
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj
1+
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -shaderobj -output-using-type
3+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
4+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
5+
6+
// No 16-bit and 64-bit integer support on DX11.
7+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -xslang -DDX11
8+
9+
//TEST(compute, vulkan):SIMPLE(filecheck=CHECK_SPV): -stage compute -entry computeMain -target spirv
610

711
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
812
RWStructuredBuffer<int> outputBuffer;
913

1014
[numthreads(1, 1, 1)]
1115
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
1216
{
13-
int idx = int(dispatchThreadID.x);
17+
int index = int(dispatchThreadID.x);
18+
uint outIndex = 0;
1419

15-
float tmp = dot(float3(idx), float3(1));
16-
17-
int3 a = { idx + 1, idx + 2, idx + 3};
20+
// CHECK_SPV: OpSDot
21+
int3 a = { index - 1, index - 2, index - 3};
1822
int3 b = { 1, 2, 3};
23+
outputBuffer[outIndex++] = dot(a, b);
24+
25+
// CHECK_SPV: OpUDot
26+
uint3 c = { index + 1, index + 2, index + 3};
27+
uint3 d = { 2, 4, 6};
28+
outputBuffer[outIndex++] = int(dot(c, d));
1929

20-
int result = dot(a, b);
30+
#if !defined(DX11)
31+
// CHECK_SPV: OpUDot
32+
uint64_t2 e = { index + 1, index + 2};
33+
uint64_t2 f = { 4, 8};
34+
outputBuffer[outIndex++] = int(dot(e, f));
2135

22-
outputBuffer[idx] = result;
23-
}
36+
// CHECK_SPV: OpSDot
37+
int16_t4 g = { int16_t(index + 1), int16_t(index + 2), int16_t(index + 3), int16_t(index + 4)};
38+
int16_t4 h = { -1, 2, 2, -1};
39+
outputBuffer[outIndex++] = int(dot(g, h));
40+
#else
41+
outputBuffer[outIndex++] = 20;
42+
outputBuffer[outIndex++] = 5;
43+
#endif
44+
}
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
E
2-
0
3-
0
4-
0
1+
type: int32_t
2+
-14
3+
28
4+
20
5+
5

0 commit comments

Comments
 (0)