Skip to content

Commit 113a257

Browse files
Add auto-diff support for IRSwizzledStore (shader-slang#3102)
* Add auto-diff support for `IRSwizzledStore`
  - Lower `IRSwizzledStore` to multiple stores during AD preprocess.
  - Fix typo in `transcribeNonDiffInst`
* Remove unnecessary file & add more robust check for 'local' addresses
* Fix.
* Update slang-ir-autodiff-fwd.cpp
---------
Co-authored-by: Yong He <yonghe@outlook.com>
1 parent b05b126 commit 113a257

6 files changed

+107
-5
lines changed

source/slang/slang-ir-autodiff-fwd.cpp

+60-2
Original file line numberDiff line numberDiff line change
@@ -1617,6 +1617,51 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
16171617
}
16181618
}
16191619

1620+
// Returns true when the root address of `ptrInst` (as reported by
// `getRootAddr`) is either a local variable (`IRVar`) or a function
// parameter (`IRParam`).
//
// Any other root is treated as non-local: it is probably referencing
// something outside the function scope.
//
bool isLocalPointer(IRInst* ptrInst)
{
    auto rootAddr = getRootAddr(ptrInst);

    if (as<IRVar>(rootAddr))
        return true;

    return as<IRParam>(rootAddr) != nullptr;
}
1628+
1629+
// Rewrites every `IRSwizzledStore` in `func` whose destination is a local
// pointer (see `isLocalPointer`) into a sequence of per-element stores.
//
// For the i-th element of the swizzle, the replacement code takes the address
// of the destination element selected by `getElementIndex(i)` and stores the
// i-th element of the source value into it. Swizzled stores whose destination
// is not local are left untouched.
//
void lowerSwizzledStores(IRModule* module, IRFunc* func)
{
    IRBuilder builder(module);

    // Lowered instructions are only recorded during the walk and are removed
    // afterwards (presumably so that nothing is deleted from a block while
    // its children are still being iterated).
    List<IRInst*> loweredStores;

    for (auto block : func->getBlocks())
    {
        for (auto child : block->getChildren())
        {
            auto store = as<IRSwizzledStore>(child);
            if (!store)
                continue;

            if (!isLocalPointer(store->getDest()))
                continue;

            // The replacement stores go right before the original instruction.
            builder.setInsertBefore(child);

            for (UIndex component = 0; component < store->getElementCount(); ++component)
            {
                // Address of the destination element picked by the swizzle.
                auto destIndex = store->getElementIndex(component);
                auto elementPtr = builder.emitElementAddress(store->getDest(), destIndex);

                // Matching element of the source value, selected by position.
                auto sourceIndex = builder.getIntValue(builder.getIntType(), component);
                auto sourceElement = builder.emitElementExtract(store->getSource(), sourceIndex);

                builder.emitStore(elementPtr, sourceElement);
            }

            loweredStores.add(child);
        }
    }

    for (auto lowered : loweredStores)
        lowered->removeAndDeallocate();
}
1664+
16201665
SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
16211666
{
16221667
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
@@ -1626,6 +1671,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
16261671

16271672
initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);
16281673

1674+
lowerSwizzledStores(autoDiffSharedContext->moduleInst->getModule(), func);
1675+
16291676
auto result = eliminateAddressInsts(func, sink);
16301677

16311678
if (SLANG_SUCCEEDED(result))
@@ -1846,6 +1893,17 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
18461893
case kIROp_undefined:
18471894
return transcribeUndefined(builder, origInst);
18481895

1896+
// Differentiable insts that should have been lowered in a previous pass.
1897+
case kIROp_SwizzledStore:
1898+
{
1899+
// If the dest ptr has a non-null differential, then we error out because something went wrong
1900+
// when lowering swizzle-stores to regular stores
1901+
//
1902+
auto swizzledStore = as<IRSwizzledStore>(origInst);
1903+
SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr);
1904+
return transcribeNonDiffInst(builder, swizzledStore);
1905+
}
1906+
18491907
// Known non-differentiable insts.
18501908
case kIROp_Not:
18511909
case kIROp_BitAnd:
@@ -1875,13 +1933,13 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
18751933
case kIROp_DetachDerivative:
18761934
case kIROp_GetSequentialID:
18771935
case kIROp_GetStringHash:
1878-
return trascribeNonDiffInst(builder, origInst);
1936+
return transcribeNonDiffInst(builder, origInst);
18791937

18801938
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
18811939
// so we treat this inst as non differentiable.
18821940
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
18831941
case kIROp_CreateExistentialObject:
1884-
return trascribeNonDiffInst(builder, origInst);
1942+
return transcribeNonDiffInst(builder, origInst);
18851943

18861944
case kIROp_StructKey:
18871945
return InstPair(origInst, nullptr);

source/slang/slang-ir-autodiff-rev.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ namespace Slang
218218
case kIROp_WrapExistential:
219219
case kIROp_MakeExistential:
220220
case kIROp_MakeExistentialWithRTTI:
221-
return trascribeNonDiffInst(builder, origInst);
221+
return transcribeNonDiffInst(builder, origInst);
222222

223223
case kIROp_StructKey:
224224
return InstPair(origInst, nullptr);

source/slang/slang-ir-autodiff-transcriber-base.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBloc
872872
return InstPair(diffBlock, diffBlock);
873873
}
874874

875-
InstPair AutoDiffTranscriberBase::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
875+
InstPair AutoDiffTranscriberBase::transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
876876
{
877877
auto primal = cloneInst(&cloneEnv, builder, origInst);
878878
return InstPair(primal, nullptr);

source/slang/slang-ir-autodiff-transcriber-base.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ struct AutoDiffTranscriberBase
114114

115115
IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
116116

117-
InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
117+
InstPair transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
118118

119119
InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn);
120120

tests/autodiff/swizzled-store.slang

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
typedef DifferentialPair<float2> dpfloat2;
8+
typedef DifferentialPair<float3> dpfloat3;
9+
typedef DifferentialPair<float4> dpfloat4;
10+
11+
// Differentiable function under test.
//
// Assigns into `u` through a swizzled l-value (`u.zy = x.yx`), which sets
// u.z = x.y and u.y = x.x, then reads the same two components back.
// The result is therefore (x.y, x.x); `u.x` is never written and `x.z`
// never contributes to the output.
[Differentiable]
12+
float2 f(float3 x)
13+
{
14+
float3 u;
15+
// Swizzled store: the statement form this test is meant to exercise.
u.zy = x.yx;
16+
return u.zy;
17+
}
18+
19+
// Compute entry point: runs `f` through forward- and backward-mode
// differentiation and writes four floats to `outputBuffer`.
[numthreads(1, 1, 1)]
20+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
21+
{
22+
// Forward mode: push the input differential `da` through `f`.
{
23+
float3 a = float3(2.0, 2.0, 2.0);
24+
float3 da = float3(1.0, 0.5, 1.0);
25+
26+
// f returns (x.y, x.x), so the first differential component is da.y; expected 0.5.
outputBuffer[0] = fwd_diff(f)(dpfloat3(a, da)).d.x;
27+
}
28+
29+
// Backward mode: propagate the output gradient (0.5, 1.0) back to `x`.
{
30+
float3 a = float3(2.0, 2.0, 2.0);
31+
var dpa = diffPair(a);
32+
33+
bwd_diff(f)(dpa, float2(0.5, 1.0));
34+
35+
// Output component 0 comes from x.y and component 1 from x.x; x.z is unused by f.
outputBuffer[1] = dpa.d.x; // 1.0
36+
outputBuffer[2] = dpa.d.y; // 0.5
37+
outputBuffer[3] = dpa.d.z; // 0.0
38+
}
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
type: float
2+
0.500000
3+
1.000000
4+
0.500000
5+
0.000000

0 commit comments

Comments
 (0)