Skip to content

Commit ee052a9

Browse files
FIx issue with specializing witness tables (shader-slang#4839)
1 parent 1124407 commit ee052a9

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

source/slang/slang-ir-specialize.cpp

+14-1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,15 @@ struct SpecializationContext
131131
return false;
132132
}
133133

134+
// Check if an inst is a dynamic dispatch witness table.
135+
// These insts may not have any uses yet, and do not have side effects,
136+
// but should be specialized if necessary.
137+
//
138+
bool isWitnessTableType(IRInst* inst)
139+
{
140+
return inst->findDecoration<IRDynamicDispatchWitnessDecoration>();
141+
}
142+
134143
// When an instruction isn't fully specialized, but its operands *are*
135144
// then it is a candidate for specialization itself, so we will have
136145
// a query to check for the "all operands fully specialized" case.
@@ -826,8 +835,12 @@ struct SpecializationContext
826835
// specialization opportunities (generic specialization,
827836
// existential specialization, simplifications, etc.)
828837
//
829-
if (inst->hasUses() || inst->mightHaveSideEffects())
838+
if (inst->hasUses() ||
839+
inst->mightHaveSideEffects() ||
840+
isWitnessTableType(inst))
841+
{
830842
hasSpecialization |= maybeSpecializeInst(inst);
843+
}
831844

832845
// Finally, we need to make our logic recurse through
833846
// the whole IR module, so we want to add the children
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// unit-test-translation-unit-import.cpp
2+
3+
#include "slang.h"
4+
5+
#include <stdio.h>
6+
#include <stdlib.h>
7+
8+
#include "tools/unit-test/slang-unit-test.h"
9+
#include "slang-com-ptr.h"
10+
#include "../../source/core/slang-io.h"
11+
#include "../../source/core/slang-process.h"
12+
13+
using namespace Slang;
14+
15+
// Test that the IModule::findAndCheckEntryPoint API supports discovering
16+
// entrypoints without a [shader] attribute.
17+
18+
SLANG_UNIT_TEST(genericInterfaceConformance)
19+
{
20+
// Source for a module that contains an undecorated entrypoint.
21+
const char* userSourceBody = R"(
22+
public interface ITestInterface<Real : IFloat> {
23+
Real sample();
24+
}
25+
26+
struct TestInterfaceImpl<Real : IFloat> : ITestInterface<Real> {
27+
Real sample() {
28+
return x;
29+
}
30+
Real x;
31+
}
32+
33+
//TEST_INPUT: set data = new StructuredBuffer<ITestInterface<float> >[new TestInterfaceImpl<float>{1.0}];
34+
StructuredBuffer<ITestInterface<float>> data;
35+
36+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4);
37+
RWStructuredBuffer<int> outputBuffer;
38+
39+
//TEST_INPUT: type_conformance TestInterfaceImpl<float>:ITestInterface<float> = 3
40+
41+
[numthreads(1, 1, 1)]
42+
void computeMain()
43+
{
44+
let obj = data[0];
45+
// CHECK: 1
46+
outputBuffer[0] = int(obj.sample());
47+
}
48+
)";
49+
50+
ComPtr<slang::IGlobalSession> globalSession;
51+
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
52+
slang::TargetDesc targetDesc = {};
53+
targetDesc.format = SLANG_HLSL;
54+
55+
slang::SessionDesc sessionDesc = {};
56+
sessionDesc.targetCount = 1;
57+
sessionDesc.targets = &targetDesc;
58+
sessionDesc.allowGLSLSyntax = true;
59+
60+
ComPtr<slang::ISession> session;
61+
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
62+
63+
ComPtr<slang::IBlob> diagnosticBlob;
64+
auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef());
65+
SLANG_CHECK(module != nullptr);
66+
67+
ComPtr<slang::IEntryPoint> entryPoint;
68+
module->findAndCheckEntryPoint("computeMain", SLANG_STAGE_COMPUTE, entryPoint.writeRef(), diagnosticBlob.writeRef());
69+
SLANG_CHECK(entryPoint != nullptr);
70+
71+
ComPtr<slang::IComponentType> compositeProgram;
72+
slang::IComponentType* components[] = { module, entryPoint.get() };
73+
session->createCompositeComponentType(components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef());
74+
SLANG_CHECK(compositeProgram != nullptr);
75+
76+
ComPtr<slang::ITypeConformance> typeConformance;
77+
auto result = session->createTypeConformanceComponentType(
78+
compositeProgram->getLayout()->findTypeByName("TestInterfaceImpl<float>"),
79+
compositeProgram->getLayout()->findTypeByName("ITestInterface<float>"),
80+
typeConformance.writeRef(),
81+
3,
82+
diagnosticBlob.writeRef());
83+
SLANG_CHECK(result == SLANG_OK);
84+
SLANG_CHECK(typeConformance != nullptr);
85+
86+
ComPtr<slang::IComponentType> compositeProgram2;
87+
slang::IComponentType* components2[] = { compositeProgram.get(), typeConformance.get() };
88+
session->createCompositeComponentType(
89+
components2, 2, compositeProgram2.writeRef(), diagnosticBlob.writeRef());
90+
91+
ComPtr<slang::IComponentType> linkedProgram;
92+
compositeProgram2->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
93+
SLANG_CHECK(linkedProgram != nullptr);
94+
95+
ComPtr<slang::IBlob> code;
96+
linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
97+
SLANG_CHECK(code != nullptr);
98+
99+
auto codeSrc = UnownedStringSlice((const char*)code->getBufferPointer());
100+
SLANG_CHECK(codeSrc.indexOf(toSlice("computeMain")) != -1);
101+
}
102+

0 commit comments

Comments
 (0)