Skip to content

Commit 8196dc4

Browse files
csyongheTim Foley
authored and
Tim Foley
committed
specialize witness tables when needed when specializing lookup_witness_table instruction. (shader-slang#376)
1 parent 4044a1d commit 8196dc4

File tree

5 files changed

+64
-1
lines changed

5 files changed

+64
-1
lines changed

source/slang/ir.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -4862,6 +4862,20 @@ namespace Slang
48624862
auto interfaceDeclRef = ((IRDeclRef*)lookupInst->interfaceType.usedValue)->declRef;
48634863
auto mangledName = getMangledNameForConformanceWitness(srcDeclRef, interfaceDeclRef);
48644864
witnessTables.TryGetValue(mangledName, witnessTable);
4865+
4866+
if (!witnessTable)
4867+
{
4868+
// try specialize the witness table
4869+
auto genDeclRef = srcDeclRef;
4870+
genDeclRef.substitutions = createDefaultSubstitutions(module->session, genDeclRef.decl);
4871+
auto genName = getMangledNameForConformanceWitness(genDeclRef, interfaceDeclRef);
4872+
IRWitnessTable* genTable = nullptr;
4873+
if (witnessTables.TryGetValue(genName, genTable))
4874+
{
4875+
witnessTable = specializeWitnessTable(sharedContext, genTable, srcDeclRef, nullptr);
4876+
witnessTables.AddIfNotExists(witnessTable->mangledName, witnessTable);
4877+
}
4878+
}
48654879
if (witnessTable)
48664880
{
48674881
lookupInst->replaceUsesWith(witnessTable);

source/slang/syntax.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,12 @@ void Type::accept(IValVisitor* visitor, void* extra)
856856

857857
Type* ErrorType::CreateCanonicalType()
858858
{
859-
return this;
859+
return this;
860+
}
861+
862+
RefPtr<Val> ErrorType::SubstituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/)
863+
{
864+
return this;
860865
}
861866

862867
int ErrorType::GetHashCode()

source/slang/type-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ RAW(
3636

3737
protected:
3838
virtual bool EqualsImpl(Type * type) override;
39+
virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override;
3940
virtual Type* CreateCanonicalType() override;
4041
virtual int GetHashCode() override;
4142
)

tests/compute/int-generic.slang

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
//TEST_INPUT:type Material<1,2>
4+
RWStructuredBuffer<int> outputBuffer;
5+
6+
interface IBRDF
7+
{
8+
int compute();
9+
};
10+
11+
interface IMaterial
12+
{
13+
associatedtype TBRDF : IBRDF;
14+
TBRDF getBRDF();
15+
}
16+
17+
struct BRDF<let A:int, let B:int> : IBRDF
18+
{
19+
int c;
20+
int compute()
21+
{
22+
return A+B;
23+
}
24+
};
25+
26+
struct Material<let A:int, let B: int> : IMaterial
27+
{
28+
typedef BRDF<A,B> TBRDF;
29+
TBRDF getBRDF() { TBRDF a; a.c = 0; return a; }
30+
};
31+
32+
type_param TMaterial : IMaterial;
33+
34+
TMaterial material;
35+
36+
[numthreads(1, 1, 1)]
37+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
38+
{
39+
TMaterial.TBRDF brdf = material.getBRDF();
40+
int outVal = brdf.compute();
41+
outputBuffer[dispatchThreadID.x] = outVal;
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3

0 commit comments

Comments
 (0)