@@ -1919,6 +1919,28 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
1919
1919
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->getValue()));
1920
1920
}
1921
1921
1922
+ IRType* visitDifferentialPairType(DifferentialPairType* pairType)
1923
+ {
1924
+ IRType* primalType = lowerType(context, pairType->getPrimalType());
1925
+ if (as<IRAssociatedType>(primalType) || as<IRThisType>(primalType))
1926
+ {
1927
+ List<IRInst*> operands;
1928
+ SubstitutionSet(pairType->getDeclRef())
1929
+ .forEachSubstitutionArg(
1930
+ [&](Val* arg)
1931
+ {
1932
+ auto argVal = lowerVal(context, arg).val;
1933
+ SLANG_ASSERT(argVal);
1934
+ operands.add(argVal);
1935
+ });
1936
+
1937
+ auto undefined = getBuilder()->emitUndefined(operands[1]->getFullType());
1938
+ return getBuilder()->getDifferentialPairUserCodeType(primalType, undefined);
1939
+ }
1940
+ else
1941
+ return lowerSimpleIntrinsicType(pairType);
1942
+ }
1943
+
1922
1944
IRFuncType* visitFuncType(FuncType* type)
1923
1945
{
1924
1946
IRType* resultType = lowerType(context, type->getResultType());
@@ -10195,30 +10217,27 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
10195
10217
// If our function is differentiable, register a callback so the derivative
10196
10218
// annotations for types can be lowered.
10197
10219
//
10198
- if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
10220
+ if (decl->findModifier<DifferentiableAttribute>() && !isInterfaceRequirement(decl ))
10199
10221
{
10222
+ auto diffAttr = decl->findModifier<DifferentiableAttribute>();
10223
+
10200
10224
auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness();
10201
- OrderedDictionary<DeclRefBase *, SubtypeWitness*> resolveddiffTypeWitnessMap;
10225
+ OrderedDictionary<Type *, SubtypeWitness*> resolveddiffTypeWitnessMap;
10202
10226
10203
10227
// Go through each entry in the map and resolve the key.
10204
10228
for (auto& entry : diffTypeWitnessMap)
10205
10229
{
10206
- auto resolvedKey = as<DeclRefBase >(entry.key->resolve());
10230
+ auto resolvedKey = as<Type >(entry.key->resolve());
10207
10231
resolveddiffTypeWitnessMap[resolvedKey] =
10208
10232
as<SubtypeWitness>(as<Val>(entry.value)->resolve());
10209
10233
}
10210
10234
10211
10235
subContext->registerTypeCallback(
10212
10236
[=](IRGenContext* context, Type* type, IRType* irType)
10213
10237
{
10214
- if (!as<DeclRefType>(type))
10215
- return irType;
10216
-
10217
- DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase();
10218
- if (resolveddiffTypeWitnessMap.containsKey(declRefBase))
10238
+ if (resolveddiffTypeWitnessMap.containsKey(type))
10219
10239
{
10220
- auto irWitness =
10221
- lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val;
10240
+ auto irWitness = lowerVal(subContext, resolveddiffTypeWitnessMap[type]).val;
10222
10241
if (irWitness)
10223
10242
{
10224
10243
IRInst* args[] = {irType, irWitness};
@@ -11328,7 +11347,7 @@ LoweredValInfo emitDeclRef(IRGenContext* context, Decl* decl, DeclRefBase* subst
11328
11347
// interface definitions.
11329
11348
return emitDeclRef(
11330
11349
context,
11331
- createDefaultSpecializedDeclRef(context, nullptr, decl),
11350
+ decl->getDefaultDeclRef( ),
11332
11351
context->irBuilder->getTypeKind());
11333
11352
}
11334
11353
0 commit comments