diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index d8500a694b..0302d9ce74 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1777,7 +1777,7 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func) for (auto param : params) { - auto ptrType = as(param->getDataType()); + auto ptrType = asRelevantPtrType(param->getDataType()); auto tempVar = builder.emitVar(ptrType->getValueType()); param->replaceUsesWith(tempVar); mapParamToTempVar[param] = tempVar; @@ -2245,7 +2245,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam( builder->emitDifferentialPairGetPrimal(diffPairParam), builder->emitDifferentialPairGetDifferential(diffType, diffPairParam)); } - else if (auto pairPtrType = as(diffPairType)) + else if (auto pairPtrType = asRelevantPtrType(diffPairType)) { auto ptrInnerPairType = as(pairPtrType->getValueType()); // Make a local copy of the parameter for primal and diff parts. diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 06e3f409d6..b5ac784ced 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1174,7 +1174,7 @@ IRVar* emitIndexedLocalVar( SourceLoc location) { // Cannot store pointers. Case should have been handled by now. - SLANG_RELEASE_ASSERT(!as(baseType)); + SLANG_RELEASE_ASSERT(!asRelevantPtrType(baseType)); // Cannot store types. Case should have been handled by now. SLANG_RELEASE_ASSERT(!as(baseType)); @@ -1326,7 +1326,11 @@ static int getInstRegionNestLevel( struct UseChain { + // The chain of uses from the base use to the relevant use. + // However, this is stored in reverse order (so that the last use is the 'base use') + // List chain; + static List from( IRUse* baseUse, Func isRelevantUse, @@ -1366,41 +1370,20 @@ struct UseChain return result; } - void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst) + // This function only replaces the inner links, not the base use. + void replaceInnerLinks(IROutOfOrderCloneContext* ctx, IRBuilder* builder) { SLANG_ASSERT(chain.getCount() > 0); - // Simple case: if there is only one use, then we can just replace it. - if (chain.getCount() == 1) - { - builder->replaceOperand(chain.getLast(), inst); - chain.clear(); - return; - } - - // Pop the last use, which is the base use that needs to be replaced. - auto baseUse = chain.getLast(); - chain.removeLast(); + const UIndex count = chain.getCount(); - // Ensure that replacement inst is set as mapping for the baseUse. - ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst; - - IRBuilder chainBuilder(builder->getModule()); - setInsertAfterOrdinaryInst(&chainBuilder, inst); - - chain.reverse(); - chain.removeLast(); - - // Clone the rest of the chain. - for (auto& use : chain) + // Process the chain in reverse order (excluding the first and last elements). + // That is, iterate from count - 2 down to 1 (inclusive). + for (int i = ((int)count) - 2; i >= 1; i--) { - ctx->cloneInstOutOfOrder(&chainBuilder, use->get()); + IRUse* use = chain[i]; + ctx->cloneInstOutOfOrder(builder, use->get()); } - - // We won't actually replace the final use, because if there are multiple chains - // it can cause problems. The parent UseGraph will handle that. - - chain.clear(); } IRInst* getUser() const @@ -1417,6 +1400,14 @@ struct UseGraph // OrderedDictionary> chainSets; + // Create a UseGraph from a base inst. + // + // `isRelevantUse` is a predicate that determines if a use is relevant. Traversal will stop at + // this use, and all chains to this use will be grouped together. + // + // `passthroughInst` is a predicate that determines if an inst should be looked through + // for uses. + // static UseGraph from( IRInst* baseInst, Func isRelevantUse, @@ -1445,36 +1436,33 @@ struct UseGraph return result; } - void replace(IRBuilder* builder, IRUse* use, IRInst* inst) + void replace(IRBuilder* builder, IRUse* relevantUse, IRInst* inst) { // Since we may have common nodes, we will use an out-of-order cloning context // that can retroactively correct the uses as needed. // IROutOfOrderCloneContext ctx; - List chains = chainSets[use]; - for (auto chain : chains) - { - chain.replace(&ctx, builder, inst); - } + List chains = chainSets[relevantUse]; - if (!isTrivial()) + // Link the first use of each chain to inst. + for (auto& chain : chains) + ctx.cloneEnv.mapOldValToNew[chain.chain.getLast()->get()] = inst; + + // Process the inner links of each chain using the replacement. + for (auto& chain : chains) { - builder->setInsertBefore(use->getUser()); - auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get()); + IRBuilder chainBuilder(builder->getModule()); + setInsertAfterOrdinaryInst(&chainBuilder, inst); - // Replace the base use. - builder->replaceOperand(use, lastInstInChain); + chain.replaceInnerLinks(&ctx, builder); } - } - bool isTrivial() - { - // We're trivial if there's only one chain, and it has only one use. - if (chainSets.getCount() != 1) - return false; + // Finally, replace the relevant use (i.e, "final use") with the new replacement inst. + builder->setInsertBefore(relevantUse->getUser()); + auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, relevantUse->get()); - auto& chain = chainSets.getFirst().value; - return chain.getCount() == 1; + // Replace the base use. + builder->replaceOperand(relevantUse, lastInstInChain); } List getUniqueUses() const @@ -1668,7 +1656,7 @@ RefPtr ensurePrimalAvailability( return true; } else if ( - as(instToStore->getDataType()) && + asRelevantPtrType(instToStore->getDataType()) && !isDifferentialOrRecomputeBlock(defBlock)) { return true; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 3237ba3b26..519f796b4f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -370,7 +370,7 @@ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc( auto diffPairType = tryGetDiffPairType(builder, paramType); if (diffPairType) { - if (!as(diffPairType) && !as(diffPairType)) + if (!asRelevantPtrType(diffPairType) && !as(diffPairType)) return builder->getInOutType(diffPairType); return diffPairType; } @@ -514,7 +514,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF { // As long as the primal parameter is not an out or constref type, // we need to fetch the primal value from the parameter. - if (as(propagateParamType)) + if (asRelevantPtrType(propagateParamType)) { primalArg = builder.emitLoad(param); } @@ -544,7 +544,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF } else { - auto primalPtrType = as(primalParamType); + auto primalPtrType = asRelevantPtrType(primalParamType); SLANG_RELEASE_ASSERT(primalPtrType); auto primalValueType = primalPtrType->getValueType(); auto var = builder.emitVar(primalValueType); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 38a7a18bbd..8356e5f815 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -291,7 +291,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy if (isNoDiffType(origType)) return nullptr; - if (auto ptrType = as(origType)) + if (auto ptrType = asRelevantPtrType(origType)) return builder->getPtrType( origType->getOp(), differentiateType(builder, ptrType->getValueType())); @@ -556,7 +556,7 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* if (isNoDiffType(originalType)) return nullptr; - if (auto origPtrType = as(originalType)) + if (auto origPtrType = asRelevantPtrType(originalType)) { if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) return builder->getPtrType(originalType->getOp(), diffPairValueType); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 5e96c4e0f3..282cc96853 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -619,7 +619,7 @@ struct DiffTransposePass if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst)) { if (auto ptrPrimalType = - as(tryGetPrimalTypeFromDiffInst(varInst))) + asRelevantPtrType(tryGetPrimalTypeFromDiffInst(varInst))) { varInst->insertAtEnd(firstRevDiffBlock); @@ -1119,7 +1119,7 @@ struct DiffTransposePass auto getDiffPairType = [](IRType* type) { - if (auto ptrType = as(type)) + if (auto ptrType = asRelevantPtrType(type)) type = ptrType->getValueType(); return as(type); }; @@ -1168,7 +1168,7 @@ struct DiffTransposePass argRequiresLoad.add(false); writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar}); } - else if (!as(arg->getDataType()) && getDiffPairType(arg->getDataType())) + else if (!asRelevantPtrType(arg->getDataType()) && getDiffPairType(arg->getDataType())) { // Normal differentiable input parameter will become an inout DiffPair parameter // in the propagate func. The split logic has already prepared the initial value @@ -1241,7 +1241,6 @@ struct DiffTransposePass argRequiresLoad.add(false); } - auto revFnType = this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType( builder, diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 6bc428ad61..4d5903ab11 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -332,8 +332,8 @@ bool isIntermediateContextType(IRInst* type) case kIROp_Specialize: return isIntermediateContextType(as(type)->getBase()); default: - if (as(type)) - return isIntermediateContextType(as(type)->getValueType()); + if (auto ptrType = asRelevantPtrType(type)) + return isIntermediateContextType(ptrType->getValueType()); return false; } } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 556fb58a8a..ec435ee873 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -75,16 +75,15 @@ struct DiffUnzipPass primalParam = primalParam->getNextParam()) { auto type = primalParam->getFullType(); - if (auto ptrType = as(type)) + if (auto ptrType = asRelevantPtrType(type)) { type = ptrType->getValueType(); } if (auto pairType = as(type)) { IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType); - if (as(primalParam->getFullType())) - diffType = - builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType); + if (auto ptrType = asRelevantPtrType(primalParam->getFullType())) + diffType = builder->getPtrType(ptrType->getOp(), (IRType*)diffType); auto primalRef = builder->emitPrimalParamRef(primalParam); auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam); builder->markInstAsDifferential(diffRef, pairType->getValueType()); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 40dcb1b514..df657476a8 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -135,7 +135,7 @@ bool isNoDiffType(IRType* paramType) paramType = attrType->getBaseType(); } - else if (auto ptrType = as(paramType)) + else if (auto ptrType = asRelevantPtrType(paramType)) { paramType = ptrType->getValueType(); } @@ -184,7 +184,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( IRStructKey* key) { IRInst* pairType = nullptr; - if (auto basePtrType = as(baseInst->getDataType())) + if (auto basePtrType = asRelevantPtrType(baseInst->getDataType())) { auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType()); @@ -203,7 +203,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( baseInst, key)); } - else if (auto ptrType = as(pairType)) + else if (auto ptrType = asRelevantPtrType(pairType)) { if (auto ptrInnerSpecializedType = as(ptrType->getValueType())) { @@ -240,7 +240,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( baseInst, key)); } - else if (auto genericPtrType = as(genericType)) + else if (auto genericPtrType = asRelevantPtrType(genericType)) { if (auto genericPairStructType = as(genericPtrType->getValueType())) { @@ -1646,7 +1646,7 @@ IRType* DifferentiableTypeConformanceContext::differentiateType( IRBuilder* builder, IRInst* primalType) { - if (auto ptrType = as(primalType)) + if (auto ptrType = asRelevantPtrType(primalType)) return builder->getPtrType( primalType->getOp(), differentiateType(builder, ptrType->getValueType())); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 433b6093fd..4698408e3d 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -604,7 +604,7 @@ inline bool isRelevantDifferentialPair(IRType* type) { return true; } - else if (auto argPtrType = as(type)) + else if (auto argPtrType = asRelevantPtrType(type)) { if (as(argPtrType->getValueType())) { diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index dbd6ac099d..bf5b25d9c1 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1528,6 +1528,16 @@ bool isOne(IRInst* inst) } } +IRPtrTypeBase* asRelevantPtrType(IRInst* inst) +{ + if (auto ptrType = as(inst)) + { + if (ptrType->getAddressSpace() != AddressSpace::UserPointer) + return ptrType; + } + return nullptr; +} + IRPtrTypeBase* isMutablePointerType(IRInst* inst) { switch (inst->getOp()) @@ -1535,7 +1545,7 @@ IRPtrTypeBase* isMutablePointerType(IRInst* inst) case kIROp_ConstRefType: return nullptr; default: - return as(inst); + return asRelevantPtrType(inst); } } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 6105247540..aed63da47c 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -271,6 +271,10 @@ bool isZero(IRInst* inst); bool isOne(IRInst* inst); +// Casts inst to IRPtrTypeBase, excluding UserPointer address space. +IRPtrTypeBase* asRelevantPtrType(IRInst* inst); + +// Returns the pointer type if it is pointer type that is not a const ref or a user pointer. IRPtrTypeBase* isMutablePointerType(IRInst* inst); void initializeScratchData(IRInst* inst); diff --git a/tests/autodiff/dynamic-dispatch-ptr.slang b/tests/autodiff/dynamic-dispatch-ptr.slang new file mode 100644 index 0000000000..3f2269f785 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-ptr.slang @@ -0,0 +1,43 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +//CHECK: 1.0 + +//TEST_INPUT: type_conformance Sensor:ISensor = 1; + +[anyValueSize(16)] +interface ISensor +{ + [Differentiable] + float4 splat(float4 point); +} + +struct Sensor : ISensor +{ + [Differentiable] + float4 splat(float4 point) + { + return point; + } +} + +[Differentiable] +float4 splat(ISensor* obj, float4 point) +{ + return obj->splat(point); +} + +//TEST_INPUT: set s = ubuffer(data=[0 0 1 0 0 0 0 0]) +uniform ISensor *s; + +//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer outBuffer; + +[shader("compute"), numthreads(1, 1, 1)] +void computeMain( + uint3 id : SV_DispatchThreadID +) +{ + DifferentialPair dp; + bwd_diff(splat)(s, dp, float4(1.0f)); + outBuffer[id.x] = dp.d; +} \ No newline at end of file