Skip to content

Commit 216dfba

Browse files
authored
Separate primal computations from unzipped function into an explicit function. (shader-slang#2569)
Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 36220da commit 216dfba

24 files changed

+841
-50
lines changed

build/visual-studio/slang/slang.vcxproj

+1
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
523523
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp" />
524524
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp" />
525525
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp" />
526+
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-unzip.cpp" />
526527
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff.cpp" />
527528
<ClCompile Include="..\..\..\source\slang\slang-ir-bind-existentials.cpp" />
528529
<ClCompile Include="..\..\..\source\slang\slang-ir-byte-address-legalize.cpp" />

build/visual-studio/slang/slang.vcxproj.filters

+3
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,9 @@
671671
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp">
672672
<Filter>Source Files</Filter>
673673
</ClCompile>
674+
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-unzip.cpp">
675+
<Filter>Source Files</Filter>
676+
</ClCompile>
674677
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff.cpp">
675678
<Filter>Source Files</Filter>
676679
</ClCompile>

source/slang/slang-ir-autodiff-propagate.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
namespace Slang
1111
{
1212

13-
bool isDifferentialInst(IRInst* inst)
13+
// Returns true when `inst` carries an IRDifferentialInstDecoration.
// Defined `inline` because this header-level definition may be included
// from more than one translation unit.
inline bool isDifferentialInst(IRInst* inst)
{
    auto* diffDecoration = inst->findDecoration<IRDifferentialInstDecoration>();
    return diffDecoration != nullptr;
}
1717

18-
bool isMixedDifferentialInst(IRInst* inst)
18+
// Returns true when `inst` carries an IRMixedDifferentialInstDecoration.
// Defined `inline` because this header-level definition may be included
// from more than one translation unit.
inline bool isMixedDifferentialInst(IRInst* inst)
{
    auto* mixedDecoration = inst->findDecoration<IRMixedDifferentialInstDecoration>();
    return mixedDecoration != nullptr;
}
@@ -104,4 +104,4 @@ struct DiffPropagationPass : InstPassBase
104104
}
105105
};
106106

107-
}
107+
}

source/slang/slang-ir-autodiff-rev.cpp

+101-4
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,98 @@ struct BackwardDiffTranscriber
491491
builder.emitBranch(firstBlock);
492492
}
493493

494+
// Prunes fields of the intermediate struct that the reverse-mode function never touches.
//
//   func             - the reverse-mode function. A field whose key has any user nested
//                      inside `func` is always kept.
//   primalFunc       - the extracted primal function. Field-address insts inside it that
//                      only feed a single store (or nothing at all) are deleted for
//                      fields that `func` does not use.
//   intermediateType - either the intermediate IRStructType itself, or an IRGeneric whose
//                      return value is that struct type.
//
// The work is done in two phases so that the struct's child list is never modified while
// it is being iterated:
//   1. Walk the struct fields, strip the primal-side writes of unused fields, and record
//      every key that is left with only "trivial" users (the struct field itself and
//      primal-value struct-key decorations).
//   2. Remove those recorded keys together with all of their remaining users.
void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
{
    IRStructType* structType = as<IRStructType>(intermediateType);
    if (!structType)
    {
        // Not a plain struct: it must be a generic wrapping the struct type.
        auto genType = as<IRGeneric>(intermediateType);
        SLANG_RELEASE_ASSERT(genType);
        structType = as<IRStructType>(findGenericReturnVal(genType));
        SLANG_RELEASE_ASSERT(structType);
    }

    // Phase 1: collect the keys of fields that are never fetched by the reverse func.
    List<IRStructKey*> fieldsToCleanup;
    for (auto child : structType->getChildren())
    {
        auto field = as<IRStructField>(child);
        if (!field)
            continue;

        auto structKey = field->getKey();

        bool usedByRevFunc = false;
        for (auto use = structKey->firstUse; use; use = use->nextUse)
        {
            if (isChildInstOf(use->getUser(), func))
            {
                usedByRevFunc = true;
                break;
            }
        }
        if (usedByRevFunc)
            continue;

        // Snapshot the users first: the loop below deletes insts, which edits the
        // key's use list.
        List<IRInst*> users;
        for (auto use = structKey->firstUse; use; use = use->nextUse)
            users.add(use->getUser());

        for (auto user : users)
        {
            if (!isChildInstOf(user, primalFunc))
                continue;

            auto addr = as<IRFieldAddress>(user);
            if (!addr)
                continue;
            if (addr->hasMoreThanOneUse())
                continue;

            if (auto onlyUse = addr->firstUse)
            {
                // The address may only be deleted if its single user is a store that
                // we delete with it. Any other kind of user would be left with a
                // dangling operand, so keep the address (and hence the field) alive.
                auto addrUser = onlyUse->getUser();
                if (addrUser->getOp() != kIROp_Store)
                    continue;
                addrUser->removeAndDeallocate();
            }

            // At this point the address has no users left (either it never had any,
            // or its only store was just removed), so it is safe to delete.
            addr->removeAndDeallocate();
        }

        // The field can go only if nothing but the struct field itself and
        // primal-value key decorations still refer to the key.
        bool hasNonTrivialUse = false;
        for (auto use = structKey->firstUse; use && !hasNonTrivialUse; use = use->nextUse)
        {
            switch (use->getUser()->getOp())
            {
            case kIROp_PrimalValueStructKeyDecoration:
            case kIROp_StructField:
                break;
            default:
                hasNonTrivialUse = true;
                break;
            }
        }
        if (!hasNonTrivialUse)
            fieldsToCleanup.add(structKey);
    }

    // Phase 2: actually remove the fields from the struct. The struct field inst is
    // itself one of the key's users, so removing all users of the key removes the
    // field; iterating the recorded keys (rather than the struct's children) avoids
    // deleting children out from under a live child-list iterator.
    for (auto key : fieldsToCleanup)
    {
        List<IRInst*> keyUsers;
        for (auto use = key->firstUse; use; use = use->nextUse)
            keyUsers.add(use->getUser());
        for (auto keyUser : keyUsers)
            keyUser->removeAndDeallocate();
        key->removeAndDeallocate();
    }
}
585+
494586
// Transcribe a function definition.
495587
InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
496588
{
@@ -520,12 +612,9 @@ struct BackwardDiffTranscriber
520612
// second block of the unzipped function.
521613
//
522614
IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
523-
615+
524616
// Clone the primal blocks from unzippedFwdDiffFunc
525617
// to the reverse-mode function.
526-
// TODO: This is the spot where we can make a decision to split
527-
// the primal and differential into two different funcitons
528-
// instead of two blocks in the same function.
529618
//
530619
// Special care needs to be taken for the first block since it holds the parameters
531620

@@ -547,6 +636,11 @@ struct BackwardDiffTranscriber
547636
block->insertAtEnd(diffFunc);
548637
}
549638

639+
// Extract the primal computations into their own func, and replace the primal insts
640+
// with the intermediate results computed from the extracted func.
641+
IRInst* intermediateType = nullptr;
642+
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType);
643+
550644
// Transpose the first block (parameter block)
551645
transcribeParameterBlock(builder, diffFunc);
552646

@@ -563,6 +657,9 @@ struct BackwardDiffTranscriber
563657
unzippedFwdDiffFunc->removeAndDeallocate();
564658
fwdDiffFunc->removeAndDeallocate();
565659

660+
eliminateDeadCode(diffFunc);
661+
cleanUpUnusedPrimalIntermediate(diffFunc, extractedPrimalFunc, intermediateType);
662+
566663
return InstPair(primalFunc, diffFunc);
567664
}
568665

0 commit comments

Comments
 (0)