@@ -135,7 +135,7 @@ bool isNoDiffType(IRType* paramType)
135
135
136
136
paramType = attrType->getBaseType ();
137
137
}
138
- else if (auto ptrType = as<IRPtrTypeBase> (paramType))
138
+ else if (auto ptrType = asRelevantPtrType (paramType))
139
139
{
140
140
paramType = ptrType->getValueType ();
141
141
}
@@ -184,7 +184,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
184
184
IRStructKey* key)
185
185
{
186
186
IRInst* pairType = nullptr ;
187
- if (auto basePtrType = as<IRPtrTypeBase> (baseInst->getDataType ()))
187
+ if (auto basePtrType = asRelevantPtrType (baseInst->getDataType ()))
188
188
{
189
189
auto loweredType = lowerDiffPairType (builder, basePtrType->getValueType ());
190
190
@@ -203,7 +203,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
203
203
baseInst,
204
204
key));
205
205
}
206
- else if (auto ptrType = as<IRPtrTypeBase> (pairType))
206
+ else if (auto ptrType = asRelevantPtrType (pairType))
207
207
{
208
208
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType ()))
209
209
{
@@ -240,7 +240,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
240
240
baseInst,
241
241
key));
242
242
}
243
- else if (auto genericPtrType = as<IRPtrTypeBase> (genericType))
243
+ else if (auto genericPtrType = asRelevantPtrType (genericType))
244
244
{
245
245
if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType ()))
246
246
{
@@ -1646,7 +1646,7 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
1646
1646
IRBuilder* builder,
1647
1647
IRInst* primalType)
1648
1648
{
1649
- if (auto ptrType = as<IRPtrTypeBase> (primalType))
1649
+ if (auto ptrType = asRelevantPtrType (primalType))
1650
1650
return builder->getPtrType (
1651
1651
primalType->getOp (),
1652
1652
differentiateType (builder, ptrType->getValueType ()));
0 commit comments