Skip to content

Commit 45ef0ce

Browse files
Fix lowering of associated types and synthesis of dispatch functions. (shader-slang#4568)
* Treat global variables and parameters as non-differentiable when checking derivative data-flow Global parameters are by-default not differentiable (even if they are of a differentiable type), because our auto-diff passes do not touch anything outside of function bodies. The solution is to use wrapper objects with differentiable getter/setter methods (and we should provide a few such objects in the stdlib). Fixes: shader-slang#3289 This is a potentially breaking change: User code that was previously working with global variables of a differentiable type will now throw an error (previously the gradient would be dropped without warning). The solution is to use `detach()` to keep same behavior as before or rewrite the access using differentiable getter/setter methods. * Fix issues with lookup witness lowering * Update slang-ir-lower-witness-lookup.cpp * Add tests * Update slang-ir-lower-witness-lookup.cpp * Cleanup * Update nested-assoc-types.slang --------- Co-authored-by: Yong He <yonghe@outlook.com>
1 parent 16a4781 commit 45ef0ce

5 files changed

+126
-8
lines changed

source/slang/slang-ir-insts.h

-6
Original file line numberDiff line numberDiff line change
@@ -1292,12 +1292,6 @@ struct IRGetSequentialID : IRInst
12921292
IRInst* getRTTIOperand() { return getOperand(0); }
12931293
};
12941294

1295-
struct IRLookupWitnessTable : IRInst
1296-
{
1297-
IRUse sourceType;
1298-
IRUse interfaceType;
1299-
};
1300-
13011295
/// Allocates space from local stack.
13021296
///
13031297
struct IRAlloca : IRInst

source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ struct AssociatedTypeLookupSpecializationContext
158158
builder.setInsertBefore(inst);
159159
auto witnessTableArg = inst->getWitnessTable();
160160
auto callInst = builder.emitCallInst(
161-
builder.getWitnessTableIDType(interfaceType), func, witnessTableArg);
161+
func->getResultType(), func, witnessTableArg);
162162
inst->replaceUsesWith(callInst);
163163
inst->removeAndDeallocate();
164164
}

source/slang/slang-lower-to-ir.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -10288,7 +10288,7 @@ static void _addFlattenedTupleArgs(
1028810288

1028910289
bool isAbstractWitnessTable(IRInst* inst)
1029010290
{
10291-
if (as<IRThisTypeWitness>(inst))
10291+
if (as<IRThisTypeWitness>(inst) || as<IRInterfaceRequirementEntry>(inst))
1029210292
return true;
1029310293
if (auto lookup = as<IRLookupWitnessMethod>(inst))
1029410294
return isAbstractWitnessTable(lookup->getWitnessTable());
+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Test calling differentiable function through dynamic dispatch.
2+
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
4+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
5+
6+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
7+
RWStructuredBuffer<float> outputBuffer;
8+
9+
[anyValueSize(16)]
10+
interface IFoo
11+
{
12+
float foo();
13+
}
14+
15+
[anyValueSize(16)]
16+
interface INestedInterface
17+
{
18+
associatedtype NestedAssocType : IFoo;
19+
}
20+
21+
[anyValueSize(16)]
22+
interface IInterface
23+
{
24+
associatedtype MyAssocType : INestedInterface;
25+
MyAssocType.NestedAssocType calc(float x);
26+
}
27+
28+
// ================================
29+
30+
struct A_Assoc_Assoc : IFoo
31+
{
32+
float a;
33+
34+
float foo()
35+
{
36+
return a;
37+
}
38+
}
39+
40+
struct A_Assoc : INestedInterface
41+
{
42+
typedef A_Assoc_Assoc NestedAssocType;
43+
}
44+
45+
struct A : IInterface
46+
{
47+
typedef A_Assoc MyAssocType
48+
49+
int data1;
50+
51+
__init(int data1) { this.data1 = data1; }
52+
53+
A_Assoc_Assoc calc(float x) { return { x * x * x * data1 }; }
54+
};
55+
56+
// ================================
57+
58+
struct B_Assoc_Assoc : IFoo
59+
{
60+
float b;
61+
62+
float foo()
63+
{
64+
return b;
65+
}
66+
}
67+
68+
struct B_Assoc : INestedInterface
69+
{
70+
typedef B_Assoc_Assoc NestedAssocType;
71+
}
72+
73+
struct B : IInterface
74+
{
75+
typedef B_Assoc MyAssocType
76+
77+
int data1;
78+
int data2;
79+
80+
__init(int data1, int data2) { this.data1 = data1; this.data2 = data2; }
81+
82+
B_Assoc_Assoc calc(float x) { return { x * x * data1 * data2 }; }
83+
};
84+
85+
// ================================
86+
87+
float doThing(IInterface obj, float x)
88+
{
89+
let o = obj.calc(x);
90+
return o.foo();
91+
}
92+
93+
float f(uint id, float x)
94+
{
95+
IInterface obj;
96+
97+
switch (id)
98+
{
99+
case 0:
100+
obj = A(2);
101+
break;
102+
103+
default:
104+
obj = B(2, 3);
105+
}
106+
107+
return doThing(obj, x);
108+
}
109+
110+
//TEST_INPUT: type_conformance A:IInterface = 0
111+
//TEST_INPUT: type_conformance B:IInterface = 1
112+
113+
[numthreads(1, 1, 1)]
114+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
115+
{
116+
outputBuffer[0] = f(dispatchThreadID.x, 1.0); // A.calc, expect 2
117+
outputBuffer[1] = f(dispatchThreadID.x + 1, 1.5); // B.calc, expect 13.5
118+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
type: float
2+
2.000000
3+
13.500000
4+
0.000000
5+
0.000000
6+
0.000000

0 commit comments

Comments
 (0)