@@ -1617,6 +1617,51 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
1617
1617
}
1618
1618
}
1619
1619
1620
+ bool isLocalPointer (IRInst* ptrInst)
1621
+ {
1622
+ // If it's not a local var or a function parameter, then it's probably
1623
+ // referencing something outside the function scope.
1624
+ //
1625
+ auto addr = getRootAddr (ptrInst);
1626
+ return as<IRVar>(addr) || as<IRParam>(addr);
1627
+ }
1628
+
1629
+ void lowerSwizzledStores (IRModule* module, IRFunc* func)
1630
+ {
1631
+ List<IRInst*> instsToRemove;
1632
+
1633
+ IRBuilder builder (module);
1634
+ for (auto block : func->getBlocks ())
1635
+ {
1636
+ for (auto inst : block->getChildren ())
1637
+ {
1638
+ if (auto swizzledStore = as<IRSwizzledStore>(inst))
1639
+ {
1640
+ if (!isLocalPointer (swizzledStore->getDest ()))
1641
+ continue ;
1642
+
1643
+ builder.setInsertBefore (inst);
1644
+ for (UIndex ii = 0 ; ii < swizzledStore->getElementCount (); ii++)
1645
+ {
1646
+ auto indexVal = swizzledStore->getElementIndex (ii);
1647
+ auto indexedPtr = builder.emitElementAddress (swizzledStore->getDest (), indexVal);
1648
+ builder.emitStore (
1649
+ indexedPtr,
1650
+ builder.emitElementExtract (
1651
+ swizzledStore->getSource (),
1652
+ builder.getIntValue (builder.getIntType (), ii)));
1653
+ }
1654
+ instsToRemove.add (inst);
1655
+ }
1656
+ }
1657
+ }
1658
+
1659
+ for (auto inst : instsToRemove)
1660
+ {
1661
+ inst->removeAndDeallocate ();
1662
+ }
1663
+ }
1664
+
1620
1665
SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff (IRFunc* func)
1621
1666
{
1622
1667
insertTempVarForMutableParams (autoDiffSharedContext->moduleInst ->getModule (), func);
@@ -1626,6 +1671,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
1626
1671
1627
1672
initializeLocalVariables (autoDiffSharedContext->moduleInst ->getModule (), func);
1628
1673
1674
+ lowerSwizzledStores (autoDiffSharedContext->moduleInst ->getModule (), func);
1675
+
1629
1676
auto result = eliminateAddressInsts (func, sink);
1630
1677
1631
1678
if (SLANG_SUCCEEDED (result))
@@ -1846,6 +1893,17 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
1846
1893
case kIROp_undefined :
1847
1894
return transcribeUndefined (builder, origInst);
1848
1895
1896
+ // Differentiable insts that should have been lowered in a previous pass.
1897
+ case kIROp_SwizzledStore :
1898
+ {
1899
+ // If we have a non-null dest ptr, then we error out because something went wrong
1900
+ // when lowering swizzle-stores to regular stores
1901
+ //
1902
+ auto swizzledStore = as<IRSwizzledStore>(origInst);
1903
+ SLANG_RELEASE_ASSERT (lookupDiffInst (swizzledStore->getDest (), nullptr ) == nullptr );
1904
+ return transcribeNonDiffInst (builder, swizzledStore);
1905
+ }
1906
+
1849
1907
// Known non-differentiable insts.
1850
1908
case kIROp_Not :
1851
1909
case kIROp_BitAnd :
@@ -1875,13 +1933,13 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
1875
1933
case kIROp_DetachDerivative :
1876
1934
case kIROp_GetSequentialID :
1877
1935
case kIROp_GetStringHash :
1878
- return trascribeNonDiffInst (builder, origInst);
1936
+ return transcribeNonDiffInst (builder, origInst);
1879
1937
1880
1938
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
1881
1939
// so we treat this inst as non differentiable.
1882
1940
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
1883
1941
case kIROp_CreateExistentialObject :
1884
- return trascribeNonDiffInst (builder, origInst);
1942
+ return transcribeNonDiffInst (builder, origInst);
1885
1943
1886
1944
case kIROp_StructKey :
1887
1945
return InstPair (origInst, nullptr );
0 commit comments