Skip to content

Commit 46a4d98

Browse files
authored
Full address insts elimination for backward autodiff. (shader-slang#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 263ca18 commit 46a4d98

40 files changed

+1230
-171
lines changed

build/visual-studio/slang/slang.vcxproj

+5
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
340340
<ClInclude Include="..\..\..\source\slang\slang-hlsl-intrinsic-set.h" />
341341
<ClInclude Include="..\..\..\source\slang\slang-image-format-defs.h" />
342342
<ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h" />
343+
<ClInclude Include="..\..\..\source\slang\slang-ir-address-analysis.h" />
343344
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h" />
344345
<ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h" />
345346
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
@@ -404,6 +405,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
404405
<ClInclude Include="..\..\..\source\slang\slang-ir-missing-return.h" />
405406
<ClInclude Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.h" />
406407
<ClInclude Include="..\..\..\source\slang\slang-ir-peephole.h" />
408+
<ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h" />
407409
<ClInclude Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.h" />
408410
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure-scoping.h" />
409411
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure.h" />
@@ -521,8 +523,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
521523
<ClCompile Include="..\..\..\source\slang\slang-glsl-extension-tracker.cpp" />
522524
<ClCompile Include="..\..\..\source\slang\slang-hlsl-intrinsic-set.cpp" />
523525
<ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp" />
526+
<ClCompile Include="..\..\..\source\slang\slang-ir-address-analysis.cpp" />
524527
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp" />
525528
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp" />
529+
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-addr-inst-elimination.cpp" />
526530
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp" />
527531
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp" />
528532
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp" />
@@ -582,6 +586,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
582586
<ClCompile Include="..\..\..\source\slang\slang-ir-missing-return.cpp" />
583587
<ClCompile Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.cpp" />
584588
<ClCompile Include="..\..\..\source\slang\slang-ir-peephole.cpp" />
589+
<ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp" />
585590
<ClCompile Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.cpp" />
586591
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure-scoping.cpp" />
587592
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure.cpp" />

build/visual-studio/slang/slang.vcxproj.filters

+15
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@
126126
<ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h">
127127
<Filter>Header Files</Filter>
128128
</ClInclude>
129+
<ClInclude Include="..\..\..\source\slang\slang-ir-address-analysis.h">
130+
<Filter>Header Files</Filter>
131+
</ClInclude>
129132
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h">
130133
<Filter>Header Files</Filter>
131134
</ClInclude>
@@ -318,6 +321,9 @@
318321
<ClInclude Include="..\..\..\source\slang\slang-ir-peephole.h">
319322
<Filter>Header Files</Filter>
320323
</ClInclude>
324+
<ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h">
325+
<Filter>Header Files</Filter>
326+
</ClInclude>
321327
<ClInclude Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.h">
322328
<Filter>Header Files</Filter>
323329
</ClInclude>
@@ -665,12 +671,18 @@
665671
<ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp">
666672
<Filter>Source Files</Filter>
667673
</ClCompile>
674+
<ClCompile Include="..\..\..\source\slang\slang-ir-address-analysis.cpp">
675+
<Filter>Source Files</Filter>
676+
</ClCompile>
668677
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp">
669678
<Filter>Source Files</Filter>
670679
</ClCompile>
671680
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp">
672681
<Filter>Source Files</Filter>
673682
</ClCompile>
683+
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-addr-inst-elimination.cpp">
684+
<Filter>Source Files</Filter>
685+
</ClCompile>
674686
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp">
675687
<Filter>Source Files</Filter>
676688
</ClCompile>
@@ -848,6 +860,9 @@
848860
<ClCompile Include="..\..\..\source\slang\slang-ir-peephole.cpp">
849861
<Filter>Source Files</Filter>
850862
</ClCompile>
863+
<ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp">
864+
<Filter>Source Files</Filter>
865+
</ClCompile>
851866
<ClCompile Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.cpp">
852867
<Filter>Source Files</Filter>
853868
</ClCompile>

source/slang/slang-diagnostic-defs.h

+2
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,8 @@ DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable fun
581581
DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.")
582582
DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal")
583583

584+
DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.")
585+
DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.")
584586
//
585587
// 5xxxx - Target code generation.
586588
//
+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#include "slang-ir-address-analysis.h"
2+
#include "slang-ir-insts.h"
3+
#include "slang-ir-util.h"
4+
5+
namespace Slang
6+
{
7+
/// Hoist `inst` to the earliest position in `func` at which everything it
/// depends on (its operands and its type) is still available.
///
/// The destination block is the first reachable block, in `func`'s block order,
/// that is dominated by the parent block of every dependency that lives in a
/// reachable block. Inside that block the instruction is placed right after the
/// last dependency found there, or in front of the first ordinary instruction
/// when no dependency lives in the block (or the last one is a block parameter).
///
/// Nothing happens when `inst` is null, is not a direct child of a block, or
/// sits in a block the dominator tree reports as unreachable.
///
/// @param domTree Dominator tree used for the reachability and dominance queries.
/// @param func    Function whose blocks are searched for the destination.
/// @param inst    Instruction to move; may be null.
void moveInstToEarliestPoint(IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRInst* inst)
{
    // Callers pass `someInst->getFullType()` here, which may legitimately be null.
    if (!inst)
        return;

    auto currentBlock = as<IRBlock>(inst->getParent());
    if (!currentBlock)
        return;
    if (domTree->isUnreachable(currentBlock))
        return;

    // `blocks` collects the reachable parent blocks of all dependencies;
    // `operandInsts` collects the dependencies themselves so that they can be
    // located inside the destination block later on.
    List<IRBlock*> blocks;
    HashSet<IRInst*> operandInsts;
    auto addDependency = [&](IRInst* dependency)
    {
        // An instruction without a type (or with a missing operand) simply
        // contributes no placement constraint.
        if (!dependency)
            return;
        operandInsts.Add(dependency);
        auto parentBlock = as<IRBlock>(dependency->getParent());
        if (parentBlock && !domTree->isUnreachable(parentBlock))
            blocks.add(parentBlock);
    };

    for (UInt i = 0; i < inst->getOperandCount(); i++)
        addDependency(inst->getOperand(i));
    addDependency(inst->getFullType());

    // Find earliest block that is dominated by all operand blocks.
    // If no candidate qualifies, the instruction stays in its current block.
    IRBlock* earliestBlock = currentBlock;
    for (auto block : func->getBlocks())
    {
        // Never move a reachable instruction into an unreachable block.
        if (domTree->isUnreachable(block))
            continue;

        bool dominated = true;
        for (auto opBlock : blocks)
        {
            if (!domTree->dominates(opBlock, block))
            {
                dominated = false;
                break;
            }
        }
        if (dominated)
        {
            earliestBlock = block;
            break;
        }
    }

    // Locate the last dependency that lives in the destination block.
    IRInst* latestOperand = nullptr;
    for (auto childInst : earliestBlock->getChildren())
    {
        if (operandInsts.Contains(childInst))
            latestOperand = childInst;
    }

    if (!latestOperand || as<IRParam>(latestOperand))
    {
        // Ordinary instructions must stay behind the block's parameters, so
        // anchor on the first ordinary instruction instead of the parameter.
        auto anchor = earliestBlock->getFirstOrdinaryInst();

        // If `inst` already is the first ordinary instruction it is in place;
        // skip the move rather than inserting an instruction before itself.
        if (anchor != inst)
            inst->insertBefore(anchor);
    }
    else if (latestOperand->getNextInst() != inst)
    {
        // Skipped when `inst` already directly follows its last dependency.
        inst->insertAfter(latestOperand);
    }
}
72+
73+
/// Gather information about every address used by `func`.
///
/// The function makes two passes:
///  1. It walks all instructions of all blocks. `Var` instructions and
///     pointer-typed `Param` instructions are recorded as root addresses.
///     `GetElementPtr` / `FieldAddress` instructions (and their types) are first
///     moved to the earliest legal point and then deduplicated; duplicates are
///     replaced by the surviving instruction and deleted, survivors are recorded.
///  2. It links every recorded address whose first operand is itself a recorded
///     address to that parent, then marks an address as non-constant whenever
///     any of its ancestors is non-constant.
///
/// An address starts out as constant when it is a `Var`, a pointer parameter of
/// the function's first block, a `FieldAddress`, or a `GetElementPtr` whose
/// index operand is an `IRConstant`.
///
/// Note that this function mutates `func`: instructions are moved and
/// duplicate address computations are removed.
///
/// @param dom  Dominator tree of `func`.
/// @param func Function to analyze.
/// @return The recorded address infos, keyed by address instruction.
AddressAccessInfo analyzeAddressUse(IRDominatorTree* dom, IRGlobalValueWithCode* func)
{
    DeduplicateContext deduplicateContext;

    AddressAccessInfo info;

    // Creates a fresh, parent-less entry for `addrInst` and stores it in `info`.
    auto registerAddress = [&](IRInst* addrInst, bool isConstant)
    {
        RefPtr<AddressInfo> addrInfo = new AddressInfo();
        addrInfo->addrInst = addrInst;
        addrInfo->isConstant = isConstant;
        addrInfo->parentAddress = nullptr;
        info.addressInfos[addrInst] = addrInfo;
    };

    // Deduplicate and move known address insts.
    for (auto block : func->getBlocks())
    {
        for (auto inst = block->getFirstChild(); inst;)
        {
            // Fetch the successor up front: `inst` may be moved or deleted below.
            auto next = inst->getNextInst();
            switch (inst->getOp())
            {
            case kIROp_Var:
                registerAddress(inst, true);
                break;
            case kIROp_Param:
                // Only pointer-typed parameters are addresses. Parameters of
                // blocks other than the first one are treated as non-constant.
                if (as<IRPtrTypeBase>(inst->getFullType()))
                    registerAddress(inst, block == func->getFirstBlock());
                break;
            case kIROp_GetElementPtr:
            case kIROp_FieldAddress:
                {
                    // Hoist the type first so that it is available at the
                    // position the address computation itself is hoisted to.
                    moveInstToEarliestPoint(dom, func, inst->getFullType());
                    moveInstToEarliestPoint(dom, func, inst);

                    // Only address computations that belong to `func` take
                    // part in deduplication.
                    auto deduplicated = deduplicateContext.deduplicate(inst, [func](IRInst* candidate)
                        {
                            if (!candidate->getParent())
                                return false;
                            if (candidate->getParent()->getParent() != func)
                                return false;
                            switch (candidate->getOp())
                            {
                            case kIROp_GetElementPtr:
                            case kIROp_FieldAddress:
                                return true;
                            default:
                                return false;
                            }
                        });

                    if (deduplicated != inst)
                    {
                        // The surviving instruction must be available at
                        // every use of the duplicate that is being removed.
                        SLANG_RELEASE_ASSERT(dom->dominates(
                            as<IRBlock>(deduplicated->getParent()),
                            as<IRBlock>(inst->getParent())));

                        inst->replaceUsesWith(deduplicated);
                        inst->removeAndDeallocate();
                    }
                    else
                    {
                        // A field address is always a constant offset; an
                        // element address only when its index is a constant.
                        bool isConstant = true;
                        if (inst->getOp() != kIROp_FieldAddress)
                            isConstant = as<IRConstant>(inst->getOperand(1)) ? true : false;
                        registerAddress(inst, isConstant);
                    }
                }
                break;
            }
            inst = next;
        }
    }

    // Construct address info tree.
    for (auto& addr : info.addressInfos)
    {
        RefPtr<AddressInfo> parentInfo;
        if (addr.Value->addrInst->getOperandCount() > 1 &&
            info.addressInfos.TryGetValue(addr.Value->addrInst->getOperand(0), parentInfo))
        {
            addr.Value->parentAddress = parentInfo;
            parentInfo->children.add(addr.Value);
        }
    }

    // Propagate non-constness down the tree. This is done only after all links
    // exist and walks the complete ancestor chain, so the result does not
    // depend on the order in which the entries were recorded.
    for (auto& addr : info.addressInfos)
    {
        AddressInfo* node = addr.Value;
        for (auto ancestor = node->parentAddress;
             ancestor && node->isConstant;
             ancestor = ancestor->parentAddress)
        {
            if (!ancestor->isConstant)
                node->isConstant = false;
        }
    }
    return info;
}
173+
}
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// slang-ir-address-analysis.h
2+
#pragma once
3+
4+
#include "slang-ir.h"
5+
#include "slang-ir-dominators.h"
6+
7+
namespace Slang
8+
{
9+
struct AddressInfo : public RefObject
10+
{
11+
IRInst* addrInst = nullptr;
12+
AddressInfo* parentAddress = nullptr;
13+
bool isConstant = false;
14+
List<AddressInfo*> children;
15+
};
16+
17+
struct AddressAccessInfo
18+
{
19+
OrderedDictionary<IRInst*, RefPtr<AddressInfo>> addressInfos;
20+
};
21+
22+
// Gather info on all addresses used by `func`.
23+
AddressAccessInfo analyzeAddressUse(IRDominatorTree* domTree, IRGlobalValueWithCode* func);
24+
}

0 commit comments

Comments
 (0)