@@ -491,6 +491,98 @@ struct BackwardDiffTranscriber
491
491
builder.emitBranch (firstBlock);
492
492
}
493
493
494
+ void cleanUpUnusedPrimalIntermediate (IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
495
+ {
496
+ IRStructType* structType = as<IRStructType>(intermediateType);
497
+ if (!structType)
498
+ {
499
+ auto genType = as<IRGeneric>(intermediateType);
500
+ structType = as<IRStructType>(findGenericReturnVal (genType));
501
+ SLANG_RELEASE_ASSERT (structType);
502
+ }
503
+
504
+ // Collect fields that are never fetched by reverse func.
505
+ OrderedHashSet<IRStructKey*> fieldsToCleanup;
506
+ for (auto children : structType->getChildren ())
507
+ {
508
+ if (auto field = as<IRStructField>(children))
509
+ {
510
+ auto structKey = field->getKey ();
511
+ bool usedByRevFunc = false ;
512
+ for (auto use = structKey->firstUse ; use; use = use->nextUse )
513
+ {
514
+ if (isChildInstOf (use->getUser (), func))
515
+ {
516
+ usedByRevFunc = true ;
517
+ break ;
518
+ }
519
+ }
520
+ if (!usedByRevFunc)
521
+ {
522
+ List<IRInst*> users;
523
+ for (auto use = structKey->firstUse ; use; use = use->nextUse )
524
+ {
525
+ users.add (use->getUser ());
526
+ }
527
+ for (auto user : users)
528
+ {
529
+ if (!isChildInstOf (user, primalFunc))
530
+ continue ;
531
+ if (auto addr = as<IRFieldAddress>(user))
532
+ {
533
+ if (addr->hasMoreThanOneUse ())
534
+ continue ;
535
+ if (addr->firstUse )
536
+ {
537
+ if (addr->firstUse ->getUser ()->getOp () == kIROp_Store )
538
+ {
539
+ addr->firstUse ->getUser ()->removeAndDeallocate ();
540
+ }
541
+ addr->removeAndDeallocate ();
542
+ }
543
+ }
544
+ }
545
+
546
+ bool hasNonTrivialUse = false ;
547
+ for (auto use = structKey->firstUse ; use; use = use->nextUse )
548
+ {
549
+ switch (use->getUser ()->getOp ())
550
+ {
551
+ case kIROp_PrimalValueStructKeyDecoration :
552
+ case kIROp_StructField :
553
+ continue ;
554
+ default :
555
+ hasNonTrivialUse = true ;
556
+ break ;
557
+ }
558
+ }
559
+ if (!hasNonTrivialUse)
560
+ {
561
+ fieldsToCleanup.Add (structKey);
562
+ }
563
+ }
564
+ }
565
+ }
566
+
567
+ // Actually remove fields from struct.
568
+ for (auto children : structType->getChildren ())
569
+ {
570
+ if (auto field = as<IRStructField>(children))
571
+ {
572
+ if (fieldsToCleanup.Contains (field->getKey ()))
573
+ {
574
+ auto key = field->getKey ();
575
+ List<IRInst*> keyUsers;
576
+ for (auto use = key->firstUse ; use; use = use->nextUse )
577
+ keyUsers.add (use->getUser ());
578
+ for (auto keyUser : keyUsers)
579
+ keyUser->removeAndDeallocate ();
580
+ key->removeAndDeallocate ();
581
+ }
582
+ }
583
+ }
584
+ }
585
+
494
586
// Transcribe a function definition.
495
587
InstPair transcribeFunc (IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
496
588
{
@@ -520,12 +612,9 @@ struct BackwardDiffTranscriber
520
612
// second block of the unzipped function.
521
613
//
522
614
IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts (fwdDiffFunc);
523
-
615
+
524
616
// Clone the primal blocks from unzippedFwdDiffFunc
525
617
// to the reverse-mode function.
526
- // TODO: This is the spot where we can make a decision to split
527
- // the primal and differential into two different funcitons
528
- // instead of two blocks in the same function.
529
618
//
530
619
// Special care needs to be taken for the first block since it holds the parameters
531
620
@@ -547,6 +636,11 @@ struct BackwardDiffTranscriber
547
636
block->insertAtEnd (diffFunc);
548
637
}
549
638
639
+ // Extracts the primal computations into its own func, and replace the primal insts
640
+ // with the intermediate results computed from the extracted func.
641
+ IRInst* intermediateType = nullptr ;
642
+ auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc (diffFunc, unzippedFwdDiffFunc, intermediateType);
643
+
550
644
// Transpose the first block (parameter block)
551
645
transcribeParameterBlock (builder, diffFunc);
552
646
@@ -563,6 +657,9 @@ struct BackwardDiffTranscriber
563
657
unzippedFwdDiffFunc->removeAndDeallocate ();
564
658
fwdDiffFunc->removeAndDeallocate ();
565
659
660
+ eliminateDeadCode (diffFunc);
661
+ cleanUpUnusedPrimalIntermediate (diffFunc, extractedPrimalFunc, intermediateType);
662
+
566
663
return InstPair (primalFunc, diffFunc);
567
664
}
568
665
0 commit comments