@@ -770,13 +770,99 @@ struct SpecializationContext
770
770
}
771
771
}
772
772
773
+ // Finds any `IRTargetDecoration` from `inst`. Recursively chasing `specialize` chains.
774
+ IRTargetIntrinsicDecoration* findTargetIntrinsicDecorationRec (IRInst* inst)
775
+ {
776
+ while (auto specialize = as<IRSpecialize>(inst))
777
+ {
778
+ inst = specialize->getBase ();
779
+ }
780
+ while (auto genericInst = as<IRGeneric>(inst))
781
+ {
782
+ inst = findGenericReturnVal (genericInst);
783
+ }
784
+ if (auto decor = inst->findDecoration <IRTargetIntrinsicDecoration>())
785
+ return decor;
786
+ return nullptr ;
787
+ }
788
+
789
+ // Returns true if the call inst represents a call to
790
+ // StructuredBuffer::operator[]/Load/Consume methods.
791
+ bool isBufferLoadCall (IRCall* inst)
792
+ {
793
+ if (auto targetIntrinsic = findTargetIntrinsicDecorationRec (inst->getCallee ()))
794
+ {
795
+ auto name = targetIntrinsic->getDefinition ();
796
+ if (name == " .operator[]" || name == " .Load" || name == " .Consume" )
797
+ {
798
+ return true ;
799
+ }
800
+ }
801
+ return false ;
802
+ }
803
+
804
+ // / Transform a buffer load intrinsic call.
805
+ // / `bufferLoad(wrapExistential(bufferObj, wrapArgs), loadArgs)` should be transformed into
806
+ // / `wrapExistential(bufferLoad(bufferObj, loadArgs), wragArgs)`.
807
+ // / Returns true if `inst` matches the pattern and the load is transformed, otherwise,
808
+ // / returns false.
809
+ bool maybeSpecializeBufferLoadCall (IRCall* inst)
810
+ {
811
+ if (isBufferLoadCall (inst))
812
+ {
813
+ SLANG_ASSERT (inst->getArgCount () > 0 );
814
+ if (auto wrapExistential = as<IRWrapExistential>(inst->getArg (0 )))
815
+ {
816
+ if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(
817
+ wrapExistential->getWrappedValue ()->getDataType ()))
818
+ {
819
+ // We are seeing the instruction sequence in the form of
820
+ // .operator[](wrapExistential(structuredBuffer), idx).
821
+ // Similar to handling load(wrapExistential(..)) insts,
822
+ // we need to replace it into wrapExistential(.operator[](sb, idx))
823
+ auto resultType = inst->getFullType ();
824
+ auto elementType = sbType->getElementType ();
825
+
826
+ IRBuilder builder;
827
+ builder.sharedBuilder = &sharedBuilderStorage;
828
+ builder.setInsertBefore (inst);
829
+
830
+ List<IRInst*> args;
831
+ args.add (wrapExistential->getWrappedValue ());
832
+ for (UInt i = 1 ; i < inst->getArgCount (); i++)
833
+ args.add (inst->getArg (i));
834
+ List<IRInst*> slotOperands;
835
+ UInt slotOperandCount = wrapExistential->getSlotOperandCount ();
836
+ for (UInt ii = 0 ; ii < slotOperandCount; ++ii)
837
+ {
838
+ slotOperands.add (wrapExistential->getSlotOperand (ii));
839
+ }
840
+ auto newCall = builder.emitCallInst (elementType, inst->getCallee (), args);
841
+ auto newWrapExistential = builder.emitWrapExistential (
842
+ resultType, newCall, slotOperandCount, slotOperands.getBuffer ());
843
+ inst->replaceUsesWith (newWrapExistential);
844
+ inst->removeAndDeallocate ();
845
+ addUsersToWorkList (newWrapExistential);
846
+ return true ;
847
+ }
848
+ }
849
+ }
850
+ return false ;
851
+ }
852
+
773
853
// Given a `call` instruction in the IR, we need to detect the case
774
854
// where the callee has some interface-type parameter(s) and at the
775
855
// call site it is statically clear what concrete type(s) the arguments
776
856
// will have.
777
857
//
778
858
void maybeSpecializeExistentialsForCall (IRCall* inst)
779
859
{
860
+ // Handle a special case of `StructuredBuffer.operator[]/Load/Consume`
861
+ // calls first. These calls on builtin generic types should be handled
862
+ // the same way as a `load` inst.
863
+ if (maybeSpecializeBufferLoadCall (inst))
864
+ return ;
865
+
780
866
// We can only specialize a call when the callee function is known.
781
867
//
782
868
auto calleeFunc = as<IRFunc>(inst->getCallee ());
@@ -1678,21 +1764,26 @@ struct SpecializationContext
1678
1764
type->removeAndDeallocate ();
1679
1765
return ;
1680
1766
}
1681
- else if ( auto basePtrLikeType = as<IRPointerLikeType >(baseType) )
1767
+ else if ( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase >(baseType) )
1682
1768
{
1683
1769
// A `BindExistentials<P<T>, ...>` can be simplified to
1684
1770
// `P<BindExistentials<T, ...>>` when `P` is a pointer-like
1685
1771
// type constructor.
1686
1772
//
1687
- auto baseElementType = basePtrLikeType->getElementType ();
1773
+ IRType* baseElementType = nullptr ;
1774
+ if (auto basePtrLikeType = as<IRPointerLikeType>(baseType))
1775
+ baseElementType = basePtrLikeType->getElementType ();
1776
+ else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType))
1777
+ baseElementType = baseSBType->getElementType ();
1778
+
1688
1779
IRInst* wrappedElementType = builder.getBindExistentialsType (
1689
1780
baseElementType,
1690
1781
slotOperandCount,
1691
1782
type->getExistentialArgs ());
1692
1783
addToWorkList (wrappedElementType);
1693
1784
1694
1785
auto newPtrLikeType = builder.getType (
1695
- basePtrLikeType ->op ,
1786
+ baseType ->op ,
1696
1787
1 ,
1697
1788
&wrappedElementType);
1698
1789
addToWorkList (newPtrLikeType);
0 commit comments