Skip to content

Commit e61a0c0

Browse files
committed
Update all auto-diff locations that handle pointers to treat user pointers as regular values
1 parent 8a04dc8 commit e61a0c0

11 files changed

+38
-26
lines changed

source/slang/slang-ir-autodiff-fwd.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1777,7 +1777,7 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
17771777

17781778
for (auto param : params)
17791779
{
1780-
auto ptrType = as<IRPtrTypeBase>(param->getDataType());
1780+
auto ptrType = asRelevantPtrType(param->getDataType());
17811781
auto tempVar = builder.emitVar(ptrType->getValueType());
17821782
param->replaceUsesWith(tempVar);
17831783
mapParamToTempVar[param] = tempVar;
@@ -2245,7 +2245,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(
22452245
builder->emitDifferentialPairGetPrimal(diffPairParam),
22462246
builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
22472247
}
2248-
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
2248+
else if (auto pairPtrType = asRelevantPtrType(diffPairType))
22492249
{
22502250
auto ptrInnerPairType = as<IRDifferentialPairTypeBase>(pairPtrType->getValueType());
22512251
// Make a local copy of the parameter for primal and diff parts.

source/slang/slang-ir-autodiff-primal-hoist.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ IRVar* emitIndexedLocalVar(
11741174
SourceLoc location)
11751175
{
11761176
// Cannot store pointers. Case should have been handled by now.
1177-
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
1177+
SLANG_RELEASE_ASSERT(!asRelevantPtrType(baseType));
11781178

11791179
// Cannot store types. Case should have been handled by now.
11801180
SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));
@@ -1656,7 +1656,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
16561656
return true;
16571657
}
16581658
else if (
1659-
as<IRPtrTypeBase>(instToStore->getDataType()) &&
1659+
asRelevantPtrType(instToStore->getDataType()) &&
16601660
!isDifferentialOrRecomputeBlock(defBlock))
16611661
{
16621662
return true;

source/slang/slang-ir-autodiff-rev.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(
370370
auto diffPairType = tryGetDiffPairType(builder, paramType);
371371
if (diffPairType)
372372
{
373-
if (!as<IRPtrTypeBase>(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
373+
if (!asRelevantPtrType(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
374374
return builder->getInOutType(diffPairType);
375375
return diffPairType;
376376
}
@@ -514,7 +514,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
514514
{
515515
// As long as the primal parameter is not an out or constref type,
516516
// we need to fetch the primal value from the parameter.
517-
if (as<IRPtrTypeBase>(propagateParamType))
517+
if (asRelevantPtrType(propagateParamType))
518518
{
519519
primalArg = builder.emitLoad(param);
520520
}
@@ -544,7 +544,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
544544
}
545545
else
546546
{
547-
auto primalPtrType = as<IRPtrTypeBase>(primalParamType);
547+
auto primalPtrType = asRelevantPtrType(primalParamType);
548548
SLANG_RELEASE_ASSERT(primalPtrType);
549549
auto primalValueType = primalPtrType->getValueType();
550550
auto var = builder.emitVar(primalValueType);

source/slang/slang-ir-autodiff-transcriber-base.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
291291
if (isNoDiffType(origType))
292292
return nullptr;
293293

294-
if (auto ptrType = as<IRPtrTypeBase>(origType))
294+
if (auto ptrType = asRelevantPtrType(origType))
295295
return builder->getPtrType(
296296
origType->getOp(),
297297
differentiateType(builder, ptrType->getValueType()));
@@ -556,7 +556,7 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
556556
if (isNoDiffType(originalType))
557557
return nullptr;
558558

559-
if (auto origPtrType = as<IRPtrTypeBase>(originalType))
559+
if (auto origPtrType = asRelevantPtrType(originalType))
560560
{
561561
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
562562
return builder->getPtrType(originalType->getOp(), diffPairValueType);

source/slang/slang-ir-autodiff-transpose.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ struct DiffTransposePass
619619
if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst))
620620
{
621621
if (auto ptrPrimalType =
622-
as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst)))
622+
asRelevantPtrType(tryGetPrimalTypeFromDiffInst(varInst)))
623623
{
624624
varInst->insertAtEnd(firstRevDiffBlock);
625625

@@ -1119,7 +1119,7 @@ struct DiffTransposePass
11191119

11201120
auto getDiffPairType = [](IRType* type)
11211121
{
1122-
if (auto ptrType = as<IRPtrTypeBase>(type))
1122+
if (auto ptrType = asRelevantPtrType(type))
11231123
type = ptrType->getValueType();
11241124
return as<IRDifferentialPairType>(type);
11251125
};
@@ -1168,7 +1168,7 @@ struct DiffTransposePass
11681168
argRequiresLoad.add(false);
11691169
writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
11701170
}
1171-
else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
1171+
else if (!asRelevantPtrType(arg->getDataType()) && getDiffPairType(arg->getDataType()))
11721172
{
11731173
// Normal differentiable input parameter will become an inout DiffPair parameter
11741174
// in the propagate func. The split logic has already prepared the initial value
@@ -1241,7 +1241,6 @@ struct DiffTransposePass
12411241
argRequiresLoad.add(false);
12421242
}
12431243

1244-
12451244
auto revFnType =
12461245
this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType(
12471246
builder,

source/slang/slang-ir-autodiff-unzip.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ bool isIntermediateContextType(IRInst* type)
332332
case kIROp_Specialize:
333333
return isIntermediateContextType(as<IRSpecialize>(type)->getBase());
334334
default:
335-
if (as<IRPtrTypeBase>(type))
336-
return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
335+
if (auto ptrType = asRelevantPtrType(type))
336+
return isIntermediateContextType(ptrType->getValueType());
337337
return false;
338338
}
339339
}

source/slang/slang-ir-autodiff-unzip.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,15 @@ struct DiffUnzipPass
7575
primalParam = primalParam->getNextParam())
7676
{
7777
auto type = primalParam->getFullType();
78-
if (auto ptrType = as<IRPtrTypeBase>(type))
78+
if (auto ptrType = asRelevantPtrType(type))
7979
{
8080
type = ptrType->getValueType();
8181
}
8282
if (auto pairType = as<IRDifferentialPairType>(type))
8383
{
8484
IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType);
85-
if (as<IRPtrTypeBase>(primalParam->getFullType()))
86-
diffType =
87-
builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType);
85+
if (auto ptrType = asRelevantPtrType(primalParam->getFullType()))
86+
diffType = builder->getPtrType(ptrType->getOp(), (IRType*)diffType);
8887
auto primalRef = builder->emitPrimalParamRef(primalParam);
8988
auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam);
9089
builder->markInstAsDifferential(diffRef, pairType->getValueType());

source/slang/slang-ir-autodiff.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ bool isNoDiffType(IRType* paramType)
135135

136136
paramType = attrType->getBaseType();
137137
}
138-
else if (auto ptrType = as<IRPtrTypeBase>(paramType))
138+
else if (auto ptrType = asRelevantPtrType(paramType))
139139
{
140140
paramType = ptrType->getValueType();
141141
}
@@ -184,7 +184,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
184184
IRStructKey* key)
185185
{
186186
IRInst* pairType = nullptr;
187-
if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
187+
if (auto basePtrType = asRelevantPtrType(baseInst->getDataType()))
188188
{
189189
auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType());
190190

@@ -203,7 +203,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
203203
baseInst,
204204
key));
205205
}
206-
else if (auto ptrType = as<IRPtrTypeBase>(pairType))
206+
else if (auto ptrType = asRelevantPtrType(pairType))
207207
{
208208
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
209209
{
@@ -240,7 +240,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
240240
baseInst,
241241
key));
242242
}
243-
else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
243+
else if (auto genericPtrType = asRelevantPtrType(genericType))
244244
{
245245
if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
246246
{
@@ -1646,7 +1646,7 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
16461646
IRBuilder* builder,
16471647
IRInst* primalType)
16481648
{
1649-
if (auto ptrType = as<IRPtrTypeBase>(primalType))
1649+
if (auto ptrType = asRelevantPtrType(primalType))
16501650
return builder->getPtrType(
16511651
primalType->getOp(),
16521652
differentiateType(builder, ptrType->getValueType()));

source/slang/slang-ir-autodiff.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ inline bool isRelevantDifferentialPair(IRType* type)
604604
{
605605
return true;
606606
}
607-
else if (auto argPtrType = as<IRPtrTypeBase>(type))
607+
else if (auto argPtrType = asRelevantPtrType(type))
608608
{
609609
if (as<IRDifferentialPairType>(argPtrType->getValueType()))
610610
{

source/slang/slang-ir-util.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -1528,14 +1528,24 @@ bool isOne(IRInst* inst)
15281528
}
15291529
}
15301530

1531+
IRPtrTypeBase* asRelevantPtrType(IRInst* inst)
1532+
{
1533+
if (auto ptrType = as<IRPtrTypeBase>(inst))
1534+
{
1535+
if (ptrType->getAddressSpace() != AddressSpace::UserPointer)
1536+
return ptrType;
1537+
}
1538+
return nullptr;
1539+
}
1540+
15311541
IRPtrTypeBase* isMutablePointerType(IRInst* inst)
15321542
{
15331543
switch (inst->getOp())
15341544
{
15351545
case kIROp_ConstRefType:
15361546
return nullptr;
15371547
default:
1538-
return as<IRPtrTypeBase>(inst);
1548+
return asRelevantPtrType(inst);
15391549
}
15401550
}
15411551

source/slang/slang-ir-util.h

+4
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,10 @@ bool isZero(IRInst* inst);
271271

272272
bool isOne(IRInst* inst);
273273

274+
// Casts inst to IRPtrTypeBase, excluding UserPointer address space.
275+
IRPtrTypeBase* asRelevantPtrType(IRInst* inst);
276+
277+
// Returns the pointer type if it is pointer type that is not a const ref or a user pointer.
274278
IRPtrTypeBase* isMutablePointerType(IRInst* inst);
275279

276280
void initializeScratchData(IRInst* inst);

0 commit comments

Comments
 (0)