Skip to content

Commit 78517dc

Browse files
Fix lowering of associated types in generic interfaces (shader-slang#6600)
* Fix lowering of associated types in generic interfaces. * Update diff-assoctype-generic-interface.slang * Fix-up lowering of differentiable witnesses for implicit ops * Update slang-ir-autodiff-transcriber-base.cpp * Fix issue with differentiating type-packs
1 parent c8c9e42 commit 78517dc

11 files changed

+234
-69
lines changed

source/slang/slang-ast-dump.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ struct ASTDumpContext
647647

648648
void dump(SourceLanguage language) { m_writer->emit((int)language); }
649649

650-
void dump(KeyValuePair<DeclRefBase*, SubtypeWitness*> pair)
650+
void dump(KeyValuePair<Type*, SubtypeWitness*> pair)
651651
{
652652
m_writer->emit("(");
653653
dump(pair.key);

source/slang/slang-ast-modifier.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Slang
77
{
8-
const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& DifferentiableAttribute::
8+
const OrderedDictionary<Type*, SubtypeWitness*>& DifferentiableAttribute::
99
getMapTypeToIDifferentiableWitness()
1010
{
1111
for (Index i = m_mapToIDifferentiableWitness.getCount();

source/slang/slang-ast-modifier.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -1391,25 +1391,25 @@ class DifferentiableAttribute : public Attribute
13911391
{
13921392
SLANG_AST_CLASS(DifferentiableAttribute)
13931393

1394-
List<KeyValuePair<DeclRefBase*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings;
1394+
List<KeyValuePair<Type*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings;
13951395

1396-
void addType(DeclRefBase* declRef, SubtypeWitness* witness)
1396+
void addType(Type* declRef, SubtypeWitness* witness)
13971397
{
13981398
getMapTypeToIDifferentiableWitness();
13991399
if (m_mapToIDifferentiableWitness.addIfNotExists(declRef, witness))
14001400
{
14011401
m_typeToIDifferentiableWitnessMappings.add(
1402-
KeyValuePair<DeclRefBase*, SubtypeWitness*>(declRef, witness));
1402+
KeyValuePair<Type*, SubtypeWitness*>(declRef, witness));
14031403
}
14041404
}
14051405

14061406
/// Mapping from types to subtype witnesses for conformance to IDifferentiable.
1407-
const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& getMapTypeToIDifferentiableWitness();
1407+
const OrderedDictionary<Type*, SubtypeWitness*>& getMapTypeToIDifferentiableWitness();
14081408

14091409
SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet;
14101410

14111411
private:
1412-
OrderedDictionary<DeclRefBase*, SubtypeWitness*> m_mapToIDifferentiableWitness;
1412+
OrderedDictionary<Type*, SubtypeWitness*> m_mapToIDifferentiableWitness;
14131413
};
14141414

14151415
class DllImportAttribute : public Attribute

source/slang/slang-check-expr.cpp

+21-7
Original file line numberDiff line numberDiff line change
@@ -1405,14 +1405,12 @@ Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, Sou
14051405
return result;
14061406
}
14071407

1408-
void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(
1409-
DeclRefType* type,
1410-
SubtypeWitness* witness)
1408+
void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(Type* type, SubtypeWitness* witness)
14111409
{
14121410
SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
14131411
if (witness)
14141412
{
1415-
m_parentDifferentiableAttr->addType(type->getDeclRef(), witness);
1413+
m_parentDifferentiableAttr->addType(type, witness);
14161414
}
14171415
}
14181416

@@ -1468,14 +1466,14 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
14681466
type,
14691467
getASTBuilder()->getDifferentiableInterfaceType())))
14701468
{
1471-
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
1469+
addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
14721470
}
14731471

14741472
if (auto subtypeWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(
14751473
type,
14761474
getASTBuilder()->getDifferentiableRefInterfaceType())))
14771475
{
1478-
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
1476+
addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
14791477
}
14801478

14811479
if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
@@ -1515,6 +1513,15 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
15151513
maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i));
15161514
return;
15171515
}
1516+
1517+
// General check for types that may not be decl-ref-type, but still have some conformance to
1518+
// IDifferentiable/IDifferentiablePtrType
1519+
if (auto subtypeWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(
1520+
type,
1521+
getASTBuilder()->getDifferentiableInterfaceType())))
1522+
{
1523+
addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
1524+
}
15181525
}
15191526

15201527

@@ -4846,7 +4853,14 @@ Expr* SemanticsVisitor::checkBaseForMemberExpr(
48464853
auto baseExpr = inBaseExpr;
48474854
baseExpr = CheckTerm(baseExpr);
48484855

4849-
return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
4856+
auto resultBaseExpr =
4857+
maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
4858+
4859+
// We might want to register differentiability on any implicit ops that we add in.
4860+
if (this->m_parentFunc && this->m_parentFunc->findModifier<DifferentiableAttribute>())
4861+
maybeRegisterDifferentiableType(getASTBuilder(), resultBaseExpr->type.type);
4862+
4863+
return resultBaseExpr;
48504864
}
48514865

48524866
Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)

source/slang/slang-check-impl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1512,7 +1512,7 @@ struct SemanticsVisitor : public SemanticsContext
15121512
/// Registers a type as conforming to IDifferentiable, along with a witness
15131513
/// describing the relationship.
15141514
///
1515-
void addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness);
1515+
void addDifferentiableTypeToDiffTypeRegistry(Type* type, SubtypeWitness* witness);
15161516
void maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type);
15171517

15181518
// Construct the differential for 'type', if it exists.

source/slang/slang-ir-autodiff-transcriber-base.cpp

+4-32
Original file line numberDiff line numberDiff line change
@@ -720,9 +720,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
720720

721721
if (auto diffType = differentiateType(builder, originalType))
722722
{
723-
IRInst* diffWitnessTable = nullptr;
724-
IRType* diffOuterType = nullptr;
725-
if (isExistentialType(diffType))
723+
if (isExistentialType(diffType) && !as<IRLookupWitnessMethod>(diffType))
726724
{
727725
// Emit null differential & pack it into an IDifferentiable existential.
728726

@@ -789,42 +787,16 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
789787
return result;
790788
}
791789

792-
// Since primalType has a corresponding differential type, we can lookup the
793-
// definition for zero().
794-
IRInst* zeroMethod = nullptr;
795-
if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
796-
{
797-
// if the differential type itself comes from a witness lookup, we can just lookup the
798-
// zero method from the same witness table.
799-
auto wt = lookupInterface->getWitnessTable();
800-
zeroMethod = builder->emitLookupInterfaceMethodInst(
801-
builder->getFuncType(List<IRType*>(), diffType),
802-
wt,
803-
autoDiffSharedContext->zeroMethodStructKey);
804-
builder->markInstAsPrimal(zeroMethod);
805-
}
806-
else
807-
{
808-
zeroMethod =
809-
differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType);
810-
}
790+
auto zeroMethod =
791+
differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType);
811792
SLANG_RELEASE_ASSERT(zeroMethod);
812793

813794
auto emptyArgList = List<IRInst*>();
814795

815796
auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
816797
builder->markInstAsDifferential(callInst, primalType);
817798

818-
if (diffOuterType && isExistentialType(diffOuterType))
819-
{
820-
// Need to wrap the result back into an existential.
821-
auto existentialZero =
822-
builder->emitMakeExistential(diffOuterType, callInst, diffWitnessTable);
823-
builder->markInstAsDifferential(existentialZero, primalType);
824-
return existentialZero;
825-
}
826-
else
827-
return callInst;
799+
return callInst;
828800
}
829801
else
830802
{

source/slang/slang-ir-autodiff.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -1362,9 +1362,10 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(
13621362
IRBuilder* builder,
13631363
IRType* origType,
13641364
IRStructKey* key,
1365-
IRType* resultType)
1365+
IRType* resultType,
1366+
DiffConformanceKind kind)
13661367
{
1367-
if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any))
1368+
if (auto conformance = tryGetDifferentiableWitness(builder, origType, kind))
13681369
return _lookupWitness(builder, conformance, key, resultType);
13691370
return nullptr;
13701371
}
@@ -2097,8 +2098,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
20972098
IRWitnessTable* table = nullptr;
20982099
if (target == DiffConformanceKind::Value)
20992100
{
2100-
SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType));
2101-
21022101
auto addMethod = builder->createFunc();
21032102
auto zeroMethod = builder->createFunc();
21042103

@@ -2138,6 +2137,8 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
21382137
&b,
21392138
(IRType*)elementType,
21402139
DiffConformanceKind::Value);
2140+
2141+
SLANG_ASSERT(isDifferentiableValueType((IRType*)elementType));
21412142
IRInst* elementResult = nullptr;
21422143
if (!innerWitness)
21432144
{
@@ -2171,9 +2172,9 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
21712172
{
21722173
// Zero method.
21732174
IRBuilder b = *builder;
2174-
b.setInsertInto(addMethod);
2175-
b.addBackwardDifferentiableDecoration(addMethod);
2176-
addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
2175+
b.setInsertInto(zeroMethod);
2176+
b.addBackwardDifferentiableDecoration(zeroMethod);
2177+
zeroMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
21772178
b.emitBlock();
21782179
List<IRInst*> results;
21792180
for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
@@ -2214,7 +2215,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
22142215
else if (target == DiffConformanceKind::Ptr)
22152216
{
22162217
SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType));
2217-
22182218
table = builder->createWitnessTable(
22192219
sharedContext->differentiablePtrInterfaceType,
22202220
(IRType*)inTupleType);

source/slang/slang-ir-autodiff.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ struct DifferentiableTypeConformanceContext
252252
IRBuilder* builder,
253253
IRType* origType,
254254
IRStructKey* key,
255-
IRType* resultType = nullptr);
255+
IRType* resultType = nullptr,
256+
DiffConformanceKind kind = DiffConformanceKind::Any);
256257

257258
IRType* differentiateType(IRBuilder* builder, IRInst* primalType);
258259

@@ -411,7 +412,8 @@ struct DifferentiableTypeConformanceContext
411412
builder,
412413
origType,
413414
sharedContext->zeroMethodStructKey,
414-
sharedContext->zeroMethodType);
415+
sharedContext->zeroMethodType,
416+
DiffConformanceKind::Value);
415417
return result;
416418
}
417419

@@ -421,7 +423,8 @@ struct DifferentiableTypeConformanceContext
421423
builder,
422424
origType,
423425
sharedContext->addMethodStructKey,
424-
sharedContext->addMethodType);
426+
sharedContext->addMethodType,
427+
DiffConformanceKind::Value);
425428
return result;
426429
}
427430

source/slang/slang-lower-to-ir.cpp

+30-11
Original file line numberDiff line numberDiff line change
@@ -1919,6 +1919,28 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
19191919
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->getValue()));
19201920
}
19211921

1922+
IRType* visitDifferentialPairType(DifferentialPairType* pairType)
1923+
{
1924+
IRType* primalType = lowerType(context, pairType->getPrimalType());
1925+
if (as<IRAssociatedType>(primalType) || as<IRThisType>(primalType))
1926+
{
1927+
List<IRInst*> operands;
1928+
SubstitutionSet(pairType->getDeclRef())
1929+
.forEachSubstitutionArg(
1930+
[&](Val* arg)
1931+
{
1932+
auto argVal = lowerVal(context, arg).val;
1933+
SLANG_ASSERT(argVal);
1934+
operands.add(argVal);
1935+
});
1936+
1937+
auto undefined = getBuilder()->emitUndefined(operands[1]->getFullType());
1938+
return getBuilder()->getDifferentialPairUserCodeType(primalType, undefined);
1939+
}
1940+
else
1941+
return lowerSimpleIntrinsicType(pairType);
1942+
}
1943+
19221944
IRFuncType* visitFuncType(FuncType* type)
19231945
{
19241946
IRType* resultType = lowerType(context, type->getResultType());
@@ -10195,30 +10217,27 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
1019510217
// If our function is differentiable, register a callback so the derivative
1019610218
// annotations for types can be lowered.
1019710219
//
10198-
if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
10220+
if (decl->findModifier<DifferentiableAttribute>() && !isInterfaceRequirement(decl))
1019910221
{
10222+
auto diffAttr = decl->findModifier<DifferentiableAttribute>();
10223+
1020010224
auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness();
10201-
OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap;
10225+
OrderedDictionary<Type*, SubtypeWitness*> resolveddiffTypeWitnessMap;
1020210226

1020310227
// Go through each entry in the map and resolve the key.
1020410228
for (auto& entry : diffTypeWitnessMap)
1020510229
{
10206-
auto resolvedKey = as<DeclRefBase>(entry.key->resolve());
10230+
auto resolvedKey = as<Type>(entry.key->resolve());
1020710231
resolveddiffTypeWitnessMap[resolvedKey] =
1020810232
as<SubtypeWitness>(as<Val>(entry.value)->resolve());
1020910233
}
1021010234

1021110235
subContext->registerTypeCallback(
1021210236
[=](IRGenContext* context, Type* type, IRType* irType)
1021310237
{
10214-
if (!as<DeclRefType>(type))
10215-
return irType;
10216-
10217-
DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase();
10218-
if (resolveddiffTypeWitnessMap.containsKey(declRefBase))
10238+
if (resolveddiffTypeWitnessMap.containsKey(type))
1021910239
{
10220-
auto irWitness =
10221-
lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val;
10240+
auto irWitness = lowerVal(subContext, resolveddiffTypeWitnessMap[type]).val;
1022210241
if (irWitness)
1022310242
{
1022410243
IRInst* args[] = {irType, irWitness};
@@ -11328,7 +11347,7 @@ LoweredValInfo emitDeclRef(IRGenContext* context, Decl* decl, DeclRefBase* subst
1132811347
// interface definitions.
1132911348
return emitDeclRef(
1133011349
context,
11331-
createDefaultSpecializedDeclRef(context, nullptr, decl),
11350+
decl->getDefaultDeclRef(),
1133211351
context->irBuilder->getTypeKind());
1133311352
}
1133411353

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
2+
//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none
3+
4+
// CUDA: __device__ void s_primal_ctx_myKernel_0(
5+
// CUDA: printf("%f\n",
6+
// CUDA: __global__ void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
7+
// CUDA: __global__ void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
8+
// CUDA: __global__ void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
9+
10+
[AutoPyBindCUDA]
11+
[Differentiable]
12+
[CudaKernel]
13+
void myKernel(DiffTensorView inValues, DiffTensorView outValues)
14+
{
15+
if (cudaThreadIdx().x > 0)
16+
return;
17+
printf("%f\n", inValues[cudaThreadIdx().x]);
18+
outValues[cudaThreadIdx().x] = sin(inValues[cudaThreadIdx().x]);
19+
}
20+
21+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
22+
// TORCH-NEXT: void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
23+
//
24+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
25+
// TORCH-NEXT: void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
26+
//
27+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
28+
// TORCH-NEXT: void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}})
29+
//
30+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
31+
// TORCH-NEXT: void myKernel(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}})
32+
//
33+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
34+
// TORCH-NEXT: std::tuple<std::tuple<const char*, const char*, const char*, const char*>, std::tuple<const char*, const char*>, const char*, const char*> __funcinfo__myKernel()
35+
//
36+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
37+
// TORCH-NEXT: void myKernel_fwd_diff(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}})
38+
//
39+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
40+
// TORCH-NEXT: void myKernel_bwd_diff(std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<uint32_t, uint32_t, uint32_t> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}}, std::tuple<torch::Tensor, std::tuple<torch::Tensor>> {{[[:alnum:]_]+}})
41+
//
42+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
43+
// TORCH-NEXT: std::tuple<std::tuple<const char*, const char*>, std::tuple<const char*, const char*>> __typeinfo__DiffTensorView()
44+
//
45+
// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
46+
// TORCH-NEXT: std::tuple<std::tuple<const char*>, std::tuple<const char*>> __typeinfo__AtomicAdd()
47+
//

0 commit comments

Comments
 (0)