From 4ea736e3bb2db52088da0abf2510a97fbc61e67e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Fri, 30 Aug 2024 13:28:05 -0400 Subject: [PATCH 01/14] initial diff-ref-type interface --- source/slang/core.meta.slang | 37 ++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 4e85296664..c81e108a0b 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -285,6 +285,13 @@ interface IDifferentiable static Differential dmul(T, Differential); }; +__magic_type(DifferentiableRefType) +interface IDifferentiableRefType +{ + __builtin_requirement($((int)BuiltinRequirementKind::DifferentialRefType)) + associatedtype Differential : IDifferentiableRefType; +} + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. @@ -357,6 +364,36 @@ struct DifferentialPair : IDifferentiable } }; +__generic +__magic_type(DifferentialRefPairType) +__intrinsic_type($(kIROp_DifferentialRefPairType)) +struct DifferentialRefPair : IDifferentiableRefType +{ + typedef DifferentialRefPair Differential; + typedef T.Differential DifferentialElementType; + + __intrinsic_op($(kIROp_MakeDifferentialRefPairUserCode)) + __init(T _primal, T.Differential _differential); + + property p : T + { + __intrinsic_op($(kIROp_DifferentialRefPairGetPrimalUserCode)) + get; + } + + property v : T + { + __intrinsic_op($(kIROp_DifferentialRefPairGetPrimalUserCode)) + get; + } + + property d : T.Differential + { + __intrinsic_op($(kIROp_DifferentialRefPairGetDifferentialUserCode)) + get; + } +} + /// A type that uses a floating-point representation [sealed] From ceaacab485ef0973e80c9524faa384a5170ecf9d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Fri, 6 Sep 2024 15:01:54 -0400 Subject: [PATCH 02/14] Initial support for `IDifferentiablePtrType` --- source/slang/core.meta.slang | 30 +- source/slang/slang-ast-builder.cpp | 23 +- source/slang/slang-ast-builder.h | 9 +- source/slang/slang-ast-support-types.h | 3 +- source/slang/slang-ast-type.h | 11 + source/slang/slang-check-conformance.cpp | 3 +- source/slang/slang-check-decl.cpp | 8 +- source/slang/slang-check-expr.cpp | 28 +- source/slang/slang-ir-autodiff-fwd.cpp | 184 +++- source/slang/slang-ir-autodiff-pairs.cpp | 3 + source/slang/slang-ir-autodiff-rev.cpp | 29 +- .../slang-ir-autodiff-transcriber-base.cpp | 166 +++- .../slang-ir-autodiff-transcriber-base.h | 6 +- source/slang/slang-ir-autodiff-transpose.h | 3 +- source/slang/slang-ir-autodiff-unzip.cpp | 5 +- source/slang/slang-ir-autodiff.cpp | 801 ++++++++++++------ source/slang/slang-ir-autodiff.h | 168 +++- .../slang-ir-check-differentiability.cpp | 2 +- source/slang/slang-ir-inst-defs.h | 12 +- source/slang/slang-ir-insts.h | 20 + source/slang/slang-ir.cpp | 53 ++ source/slang/slang-ir.h | 5 + tests/autodiff/diff-ptr-type-smoke.slang | 49 ++ 23 files changed, 1226 insertions(+), 395 deletions(-) create mode 100644 tests/autodiff/diff-ptr-type-smoke.slang diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index c81e108a0b..24ce97ee36 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -285,12 +285,12 @@ interface IDifferentiable static Differential dmul(T, Differential); }; -__magic_type(DifferentiableRefType) -interface IDifferentiableRefType +__magic_type(DifferentiablePtrType) +interface IDifferentiablePtrType { - __builtin_requirement($((int)BuiltinRequirementKind::DifferentialRefType)) - associatedtype Differential : IDifferentiableRefType; -} + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) ) + associatedtype Differential : IDifferentiablePtrType; +}; /// Pair type that serves to wrap the primal and @@ -364,35 +364,35 @@ struct DifferentialPair : IDifferentiable } }; -__generic -__magic_type(DifferentialRefPairType) -__intrinsic_type($(kIROp_DifferentialRefPairType)) -struct DifferentialRefPair : IDifferentiableRefType +__generic +__magic_type(DifferentialPtrPairType) +__intrinsic_type($(kIROp_DifferentialPtrPairType)) +struct DifferentialPtrPair : IDifferentiablePtrType { - typedef DifferentialRefPair Differential; + typedef DifferentialPtrPair Differential; typedef T.Differential DifferentialElementType; - __intrinsic_op($(kIROp_MakeDifferentialRefPairUserCode)) + __intrinsic_op($(kIROp_MakeDifferentialPtrPair)) __init(T _primal, T.Differential _differential); property p : T { - __intrinsic_op($(kIROp_DifferentialRefPairGetPrimalUserCode)) + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) get; } property v : T { - __intrinsic_op($(kIROp_DifferentialRefPairGetPrimalUserCode)) + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) get; } property d : T.Differential { - __intrinsic_op($(kIROp_DifferentialRefPairGetDifferentialUserCode)) + __intrinsic_op($(kIROp_DifferentialPtrPairGetDifferential)) get; } -} +}; /// A type that uses a floating-point representation diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 9879a41872..b66af34fa4 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness) + Witness* diffTypeWitness) { - Val* args[] = { valueType, primalIsDifferentialWitness }; + Val* args[] = { valueType, diffTypeWitness }; return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } +DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness) +{ + Val* args[] = { valueType, diffRefTypeWitness }; + return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType")); +} + DeclRef ASTBuilder::getDifferentiableInterfaceDecl() { DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiableType", nullptr)); return declRef; } +DeclRef ASTBuilder::getDifferentiableRefInterfaceDecl() +{ + DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiablePtrType", nullptr)); + return declRef; +} + bool ASTBuilder::isDifferentiableInterfaceAvailable() { return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr); @@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType() return DeclRefType::create(this, getDifferentiableInterfaceDecl()); } +Type* ASTBuilder::getDifferentiableRefInterfaceType() +{ + return DeclRefType::create(this, getDifferentiableRefInterfaceDecl()); +} + DeclRef ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index b9b1f7ab85..08951513dc 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -489,10 +489,17 @@ class ASTBuilder : public RefObject DifferentialPairType* getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness); + Witness* diffTypeWitness); + + DifferentialPtrPairType* getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness); DeclRef getDifferentiableInterfaceDecl(); + DeclRef getDifferentiableRefInterfaceDecl(); + Type* getDifferentiableInterfaceType(); + Type* getDifferentiableRefInterfaceType(); bool isDifferentiableInterfaceAvailable(); diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 83a4cf3535..56101bb919 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -9,7 +9,7 @@ #include "slang-profile.h" #include "slang-type-system-shared.h" -#include "slang.h" +#include "../../include/slang.h" #include "../core/slang-semantic-version.h" @@ -1606,6 +1606,7 @@ namespace Slang DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement + DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement DZeroFunc, ///< The `IDifferentiable.dzero` function requirement DAddFunc, ///< The `IDifferentiable.dadd` function requirement DMulFunc, ///< The `IDifferentiable.dmul` function requirement diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 401d73e29d..46ea3ea559 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -462,11 +462,22 @@ class DifferentialPairType : public ArithmeticExpressionType Type* getPrimalType(); }; +class DifferentialPtrPairType : public ArithmeticExpressionType +{ + SLANG_AST_CLASS(DifferentialPtrPairType) + Type* getPrimalRefType(); +}; + class DifferentiableType : public BuiltinType { SLANG_AST_CLASS(DifferentiableType) }; +class DifferentiablePtrType : public BuiltinType +{ + SLANG_AST_CLASS(DifferentiablePtrType) +}; + class DefaultInitializableType : public BuiltinType { SLANG_AST_CLASS(DefaultInitializableType); diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index ffa0379962..fb170222d9 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -276,7 +276,8 @@ namespace Slang bool SemanticsVisitor::isTypeDifferentiable(Type* type) { - return isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None); + return isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None) || + isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None); } bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 02e3241a99..6a21f6d53d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -9964,7 +9964,8 @@ namespace Slang bool isDiffParam = (!param->findModifier()); if (isDiffParam) { - if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) + auto diffPair = visitor->getDifferentialPairType(param->getType()); + if (auto pairType = as(diffPair)) { arg->type.type = pairType; arg->type.isLeftValue = true; @@ -9985,6 +9986,11 @@ namespace Slang direction = ParameterDirection::kParameterDirection_InOut; } } + else if (auto refPairType = as(diffPair)) + { + // no need to change direction of ref-pairs. + arg->type.type = refPairType; + } else { isDiffParam = false; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3414c16b5e..557d0345c8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1131,7 +1131,8 @@ namespace Slang { if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier()) { - if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType) + if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType + || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType) { // We are trying to get differential type from a differential type. // The result is itself. @@ -1139,7 +1140,10 @@ namespace Slang } } type = resolveType(type); - if (const auto witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()))) + auto witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())); + if (!witness) + witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType())); + if (witness) { auto diffTypeLookupResult = lookUpMember( getASTBuilder(), @@ -1367,6 +1371,13 @@ namespace Slang { addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } + + if (auto subtypeWitness = as( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType()))) + { + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + } + if (auto aggTypeDeclRef = declRefType->getDeclRef().as()) { foreachDirectOrExtensionMemberOfType(this, aggTypeDeclRef, [&](DeclRef member) @@ -2891,15 +2902,22 @@ namespace Slang return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); } } + // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); + auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType(); - auto conformanceWitness = as(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (conformanceWitness) + if (auto conformanceValWitness = as( + isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None))) + { + return m_astBuilder->getDifferentialPairType(primalType, conformanceValWitness); + } + else if (auto conformancePtrWitness = as( + isSubtype(primalType, differentiableRefInterface, IsSubTypeOptions::None))) { - return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + return m_astBuilder->getDifferentialPtrPairType(primalType, conformancePtrWitness); } else return primalType; diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fe7c77ba06..53d36af461 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -16,6 +16,74 @@ namespace Slang { + +IRInst* emitMakeDifferentialPair(IRBuilder* builder, IRType* pairType, IRInst* primalVal, IRInst* diffVal) +{ + if (as(pairType)) + { + return builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); + } + else if (as(pairType)) + { + return builder->emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } +} + +IRInst* emitDifferentialPairGetDifferential(IRBuilder* builder, IRType* diffType, IRInst* pairVal) +{ + if (as(pairVal->getDataType())) + { + return builder->emitDifferentialPairGetDifferential(diffType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return builder->emitDifferentialPtrPairGetDifferential(diffType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } +} + +IRInst* emitDifferentialPairGetPrimal(IRBuilder* builder, IRInst* pairVal) +{ + if (as(pairVal->getDataType())) + { + return builder->emitDifferentialPairGetPrimal(pairVal); + } + else if (as(pairVal->getDataType())) + { + return builder->emitDifferentialPtrPairGetPrimal(pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } +} + +IRInst* emitDifferentialPairGetPrimal(IRBuilder* builder, IRType* primalType, IRInst* pairVal) +{ + if (as(pairVal->getDataType())) + { + return builder->emitDifferentialPairGetPrimal(primalType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return builder->emitDifferentialPtrPairGetPrimal(primalType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } +} IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) { @@ -336,8 +404,8 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig { auto origPtr = origLoad->getPtr(); auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr); - auto primalPtrType = as(primalPtr->getFullType()); - if (primalPtrType) + + if (auto primalPtrType = as(primalPtr->getFullType())) { if (auto diffPairType = as(primalPtrType->getValueType())) { @@ -355,6 +423,18 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load); return InstPair(primalElement, diffElement); } + else if (auto diffRefPairType = as(primalPtrType->getValueType())) + { + auto load = builder->emitLoad(primalPtr); + builder->markInstAsPrimal(load); + + auto primalElement = builder->emitDifferentialPtrPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPtrPairGetDifferential( + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load); + builder->markInstAsPrimal(primalElement); + builder->markInstAsPrimal(diffElement); + return InstPair(primalElement, diffElement); + } } auto primalLoad = maybeCloneForPrimalInst(builder, origLoad); @@ -387,6 +467,16 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or auto store = builder->emitStore(primalStoreLocation, valToStore); builder->markInstAsMixedDifferential(store, diffPairType); + return InstPair(store, nullptr); + } + else if (auto diffRefPairType = as(primalLocationPtrType->getValueType())) + { + auto valToStore = builder->emitMakeDifferentialPtrPair(diffRefPairType, primalStoreVal, diffStoreVal); + builder->markInstAsPrimal(valToStore); + + auto store = builder->emitStore(primalStoreLocation, valToStore); + builder->markInstAsPrimal(store); + return InstPair(store, nullptr); } } @@ -404,7 +494,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or // Default case, storing the entire type (and not a member) diffStore = as( builder->emitStore(diffStoreLocation, diffStoreVal)); - + markDiffTypeInst(builder, diffStore, primalStoreVal->getDataType()); return InstPair(primalStore, diffStore); } @@ -696,14 +786,30 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) { auto pairPtrType = as(pairType); - auto pairValType = as( + + auto pairValType = as( pairPtrType ? pairPtrType->getValueType() : pairType); + + DiffConformanceKind kind = DiffConformanceKind::Any; + if (as(pairValType)) + { + kind = DiffConformanceKind::Ptr; + } + else if (as(pairValType)) + { + kind = DiffConformanceKind::Value; + } + else + { + SLANG_ASSERT(!"unreachable"); + } + auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType); if (auto ptrParamType = as(diffParamType)) { // Create temp var to pass in/out arguments. auto srcVar = argBuilder.emitVar(pairValType); - argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType()); + markDiffPairTypeInst(&argBuilder, srcVar, pairValType); auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); if (ptrParamType->getOp() == kIROp_InOutType) @@ -716,28 +822,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig else { diffArgVal = argBuilder.emitLoad(diffArg); - argBuilder.markInstAsDifferential(diffArgVal, pairValType->getValueType()); + markDiffTypeInst(&argBuilder, diffArgVal, pairValType->getValueType()); } - auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal); - argBuilder.markInstAsMixedDifferential(initVal, primalType); + auto initVal = emitMakeDifferentialPair(&argBuilder, pairValType, primalVal, diffArgVal); + markDiffPairTypeInst(&argBuilder, initVal, pairValType); auto store = argBuilder.emitStore(srcVar, initVal); - argBuilder.markInstAsMixedDifferential(store, primalType); + markDiffPairTypeInst(&argBuilder, store, pairValType); } if (as(ptrParamType)) { // Read back new value. auto newVal = afterBuilder.emitLoad(srcVar); - afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType()); - auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal); + markDiffPairTypeInst(&afterBuilder, newVal, pairValType); + auto newPrimalVal = emitDifferentialPairGetPrimal(&afterBuilder, pairValType->getValueType(), newVal); afterBuilder.emitStore(primalArg, newPrimalVal); if (diffArg) { - auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal); - afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType()); + auto newDiffVal = emitDifferentialPairGetDifferential(&afterBuilder, (IRType*)diffType, newVal); + markDiffTypeInst(&afterBuilder, newDiffVal, pairValType->getValueType()); auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal); - afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType()); + markDiffTypeInst(&afterBuilder, storeInst, pairValType->getValueType()); } } args.add(srcVar); @@ -752,8 +858,8 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig // If a pair type can be formed, this must be non-null. SLANG_RELEASE_ASSERT(diffArg); - auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg); - argBuilder.markInstAsMixedDifferential(diffPair, pairType); + auto diffPair = emitMakeDifferentialPair(&argBuilder, pairType, primalArg, diffArg); + markDiffPairTypeInst(&argBuilder, diffPair, pairType); args.add(diffPair); continue; @@ -779,16 +885,17 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig diffCallee, args); placeholderCall->removeAndDeallocate(); + argBuilder.markInstAsMixedDifferential(callInst, diffReturnType); argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee); *builder = afterBuilder; - if (diffReturnType->getOp() == kIROp_DifferentialPairType) + if (as(diffReturnType) || as(diffReturnType)) { - IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst); + IRInst* primalResultValue = emitDifferentialPairGetPrimal(&afterBuilder, callInst); auto diffType = differentiateType(&afterBuilder, origCall->getFullType()); - IRInst* diffResultValue = afterBuilder.emitDifferentialPairGetDifferential(diffType, callInst); + IRInst* diffResultValue = emitDifferentialPairGetDifferential(&afterBuilder, diffType, callInst); return InstPair(primalResultValue, diffResultValue); } else @@ -1751,12 +1858,14 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr IRInst* valToStore = nullptr; if (writeBack.value.differential) { + auto pairValType = cast(param->getFullType())->getValueType(); auto diffVal = builder.emitLoad(writeBack.value.differential); - builder.markInstAsDifferential(diffVal, primalVal->getFullType()); + markDiffTypeInst(&builder, diffVal, primalVal->getFullType()); - valToStore = builder.emitMakeDifferentialPair(cast(param->getFullType())->getValueType(), - primalVal, diffVal); - builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType()); + valToStore = emitMakeDifferentialPair( + &builder, pairValType, primalVal, diffVal); + + markDiffPairTypeInst(&builder, valToStore, pairValType); } else { @@ -1767,7 +1876,7 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr if (writeBack.value.differential) { - builder.markInstAsMixedDifferential(storeInst, valToStore->getFullType()); + markDiffPairTypeInst(&builder, storeInst, valToStore->getFullType()); } } } @@ -2043,24 +2152,26 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam SLANG_ASSERT(diffPairParam); - if (auto pairType = as(diffPairType)) + if (as(diffPairType) || as(diffPairType)) { return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, pairType), + emitDifferentialPairGetPrimal(builder, diffPairParam), + emitDifferentialPairGetDifferential( + builder, + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( + builder, + as(diffPairType)), diffPairParam)); } else if (auto pairPtrType = as(diffPairType)) { - auto ptrInnerPairType = as(pairPtrType->getValueType()); + auto ptrInnerPairType = as(pairPtrType->getValueType()); // Make a local copy of the parameter for primal and diff parts. auto primal = builder->emitVar(ptrInnerPairType->getValueType()); auto diffType = differentiateType(builder, cast(origParam->getDataType())->getValueType()); auto diff = builder->emitVar(diffType); - builder->markInstAsDifferential( - diff, builder->getPtrType(ptrInnerPairType->getValueType())); + markDiffTypeInst(builder, diff, builder->getPtrType(ptrInnerPairType->getValueType())); IRInst* primalInitVal = nullptr; IRInst* diffInitVal = nullptr; @@ -2072,17 +2183,18 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam else { auto initVal = builder->emitLoad(diffPairParam); - builder->markInstAsMixedDifferential(initVal, ptrInnerPairType); + markDiffPairTypeInst(builder, initVal, ptrInnerPairType); - primalInitVal = builder->emitDifferentialPairGetPrimal(initVal); - diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal); + primalInitVal = emitDifferentialPairGetPrimal(builder, initVal); + diffInitVal = emitDifferentialPairGetDifferential(builder, diffType, initVal); } - builder->markInstAsDifferential(diffInitVal, ptrInnerPairType->getValueType()); + markDiffTypeInst(builder, diffInitVal, ptrInnerPairType->getValueType()); + builder->emitStore(primal, primalInitVal); auto diffStore = builder->emitStore(diff, diffInitVal); - builder->markInstAsDifferential(diffStore, ptrInnerPairType->getValueType()); + markDiffTypeInst(builder, diffStore, ptrInnerPairType->getValueType()); mapInOutParamToWriteBackValue[diffPairParam] = InstPair(primal, diff); return InstPair(primal, diff); diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index 7fc8ebbe65..3a6d52bead 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -107,10 +107,13 @@ struct DiffPairLoweringPass : InstPassBase case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferentialUserCode: case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetDifferential: + case kIROp_DifferentialPtrPairGetPrimal: lowerPairAccess(builder, inst); break; case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: lowerMakePair(builder, inst); break; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 35a197f29b..02bc8190cd 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -152,6 +152,16 @@ namespace Slang builder->emitBlock(); params = _defineFuncParams(builder, as(existingPrimalFunc)); params.removeLast(); + + // Unwrap any ref pairs. We need this special case for trivial funcs. + for (Int i = 0; i < params.getCount(); i++) + { + if (auto diffPairType = as(params[i]->getDataType())) + { + params[i] = builder->emitDifferentialPtrPairGetPrimal(params[i]); + } + } + IRInst* originalFuncRefFromPrimalFunc = originalFunc; if (originalGeneric) originalFuncRefFromPrimalFunc = maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric); @@ -266,7 +276,20 @@ namespace Slang if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) return primalNoDiffType; - return (IRType*)findOrTranscribePrimalInst(builder, paramType); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType); + + // Differentiable pointer types are treated as primal pairs, since they aren't involved in the transposition + // process. + // + if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType)) + { + auto diffPairType = tryGetDiffPairType(builder, primalType); + SLANG_ASSERT(diffPairType); + + return diffPairType; + } + + return primalType; } IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType) @@ -292,7 +315,7 @@ namespace Slang auto diffPairType = tryGetDiffPairType(builder, paramType); if (diffPairType) { - if (!as(diffPairType)) + if (!as(diffPairType) && !as(diffPairType)) return builder->getInOutType(diffPairType); return diffPairType; } @@ -942,7 +965,7 @@ namespace Slang // Initialize the var with input diff param at start. // Note that we insert the store in the primal block so it won't get transposed. auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam); - nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType); + nextBlockBuilder.markInstAsDifferential(storeInst, primalType); // Since this store inst is specific to propagate function, we track it in a // set so we can remove it when we generate the primal func. result.propagateFuncSpecificPrimalInsts.add(storeInst); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index da69ed8aea..011f7c923b 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -174,45 +174,54 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); -IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) +IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind) { - return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType); + if (kind == DiffConformanceKind::Any) + { + if (auto valueWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Value)) + return valueWitness; + if (auto ptrWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Ptr)) + return ptrWitness; + } + else + { + return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, kind); + } + return nullptr; } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) { - return builder->getDifferentialPairType( - (IRType*)primalType, - witness); + auto conformanceType = differentiableTypeConformanceContext.getConformanceTypeFromWitness(witness); + if (autoDiffSharedContext->isInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiableInterfaceType) + { + return builder->getDifferentialPairType((IRType*)primalType, witness); + } + else if (autoDiffSharedContext->isPtrInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiableRefInterfaceType) + { + return builder->getDifferentialPtrPairType((IRType*)primalType, witness); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType) { - auto primalType = lookupPrimalInst(builder, originalType, nullptr); + auto primalType = lookupPrimalInst(builder, originalType, originalType); SLANG_RELEASE_ASSERT(primalType); IRInst* witness = nullptr; - if (auto lookup = as(primalType)) - { - if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey) - { - witness = builder->emitLookupInterfaceMethodInst( - lookup->getWitnessTable()->getDataType(), - lookup->getWitnessTable(), - autoDiffSharedContext->differentialAssocTypeWitnessStructKey); - } - } - - // Obtain the witness that primalType conforms to IDifferentiable. + + // Obtain the witness that primalType conforms to IDifferentiable/IDifferentiablePtrType if (!witness) - witness = tryGetDifferentiableWitness(builder, originalType); + witness = tryGetDifferentiableWitness(builder, primalType, DiffConformanceKind::Any); SLANG_RELEASE_ASSERT(witness); - auto pairType = builder->getDifferentialPairType( - (IRType*)primalType, - witness); - - return pairType; + return getOrCreateDiffPairType(builder, primalType, witness); } IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) @@ -223,8 +232,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o // Special-case for differentiable existential types. if (as(origType) || as(origType)) { - if (differentiableTypeConformanceContext.lookUpConformanceForType(origType)) + if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Value)) return autoDiffSharedContext->differentiableInterfaceType; + else if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Ptr)) + return autoDiffSharedContext->differentiableRefInterfaceType; else return nullptr; } @@ -278,8 +289,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } case kIROp_DifferentialPairType: + case kIROp_DifferentialPtrPairType: { - auto primalPairType = as(primalType); + auto primalPairType = as(primalType); return getOrCreateDiffPairType( builder, differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), @@ -445,8 +457,17 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* auto interfaceType = as(unwrapAttributedType(origType->getOperand(0)->getDataType())); if (!interfaceType) return nullptr; - List lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath( + List lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath( autoDiffSharedContext->differentiableInterfaceType, interfaceType); + List lookupPathPtrType = differentiableTypeConformanceContext.findInterfaceLookupPath( + autoDiffSharedContext->differentiableRefInterfaceType, interfaceType); + + SLANG_ASSERT(!(lookupPathValueType.getCount() && lookupPathPtrType.getCount())); + + auto lookupKeyPath = lookupPathValueType.getCount() ? lookupPathValueType : lookupPathPtrType; + auto diffStructKey = lookupPathValueType.getCount() ? + autoDiffSharedContext->differentialAssocTypeStructKey : + autoDiffSharedContext->differentialAssocRefTypeStructKey; if (lookupKeyPath.getCount()) { @@ -456,7 +477,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* { outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey()); } - auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey); + auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, diffStructKey); return (IRType*)diffType; } return nullptr; @@ -559,12 +580,33 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui builder->markInstAsPrimal(primalDiffType); builder->markInstAsPrimal(diffWitness); + return InstPair(primal, diffWitness); + } + else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiableRefInterfaceType) + { + auto primalDiffType = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + primal, + autoDiffSharedContext->differentialAssocRefTypeStructKey); + auto diffWitness = builder->emitLookupInterfaceMethodInst( + (IRType*)primalDiffType, + primal, + autoDiffSharedContext->differentialAssocRefTypeWitnessStructKey); + + // Mark both as primal since we're working with types + // (which don't need transposing) + // + builder->markInstAsPrimal(primalDiffType); + builder->markInstAsPrimal(diffWitness); + return InstPair(primal, diffWitness); } } + auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); + if (!decor) { return InstPair(primal, nullptr); @@ -589,6 +631,10 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType( { originalType = (IRType*)unwrapAttributedType(originalType); auto primalType = (IRType*)lookupPrimalInst(builder, originalType); + + // Can't generate zero for differentiable ptr types. Should never hit this case. + SLANG_ASSERT(!differentiableTypeConformanceContext.isDifferentiablePtrType(originalType)); + if (auto diffType = differentiateType(builder, originalType)) { IRInst* diffWitnessTable = nullptr; @@ -985,7 +1031,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst && !as(pair.differential)) { auto primalType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + markDiffTypeInst(builder, pair.differential, primalType); } } else @@ -997,7 +1043,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst if (as(pair.differential)) break; auto mixedType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsMixedDifferential(pair.primal, mixedType); + markDiffPairTypeInst(builder, pair.primal, mixedType); } } @@ -1075,4 +1121,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori return result; } + +void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType) +{ + // Ignore module-level insts. + if (as(diffInst->getParent())) + return; + + // Also ignore generic-container-level insts. + if (as(diffInst->getParent()) && + as(diffInst->getParent()->getParent())) + return; + + // TODO: This logic is a bit of a hack. We need to determine if the type is + // relevant to ptr-type computation or not, or more complex applications + // that use dynamic dispatch + ptr types will fail. + // + if (as(diffInst)) + { + builder->markInstAsDifferential(diffInst, nullptr); + return; + } + + SLANG_ASSERT(diffInst); + SLANG_ASSERT(primalType); + + if (differentiableTypeConformanceContext.isDifferentiableValueType(primalType)) + { + builder->markInstAsDifferential(diffInst, primalType); + } + else if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType)) + { + builder->markInstAsPrimal(diffInst); + } + else + { + // Stop-gap solution to go with differential inst for now. + builder->markInstAsDifferential(diffInst, primalType); + } +} + +void AutoDiffTranscriberBase::markDiffPairTypeInst(IRBuilder* builder, IRInst* diffPairInst, IRType* pairType) +{ + SLANG_ASSERT(diffPairInst); + SLANG_ASSERT(pairType); + SLANG_ASSERT(as(pairType)); + + if (auto diffPairType = as(pairType)) + { + builder->markInstAsMixedDifferential(diffPairInst, pairType); + } + else if (as(pairType)) + { + builder->markInstAsPrimal(diffPairInst); + } + else + { + SLANG_UNEXPECTED("unexpected differentiable type"); + } +} + } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index f7f2dd6f20..9f3cfe56f0 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -91,7 +91,7 @@ struct AutoDiffTranscriberBase void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); - IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); @@ -152,6 +152,10 @@ struct AutoDiffTranscriberBase virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0; virtual IROp getInterfaceRequirementDerivativeDecorationOp() = 0; + + void markDiffTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType); + + void markDiffPairTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType); }; } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index d42462e1ba..e19554ca56 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2100,7 +2100,8 @@ struct DiffTransposePass // If we reach this point, revValue must be a differentiable type. auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness( builder, - primalType); + primalType, + DiffConformanceKind::Value); SLANG_ASSERT(revTypeWitness); auto baseExistential = fwdInst->getOperand(0); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 9b3e3a324a..9cdb28c981 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -141,7 +141,10 @@ struct ExtractPrimalFuncContext } auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); - if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness(&genTypeBuilder, (IRType*)fieldType)) + if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness( + &genTypeBuilder, + (IRType*)fieldType, + DiffConformanceKind::Value)) { genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, witness); } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 0979c097c4..93174b3913 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -25,7 +25,7 @@ bool isBackwardDifferentiableFunc(IRInst* func) return false; } -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey, IRType* resultType = nullptr) { if (auto witnessTable = as(witness)) { @@ -53,15 +53,16 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK } else { + SLANG_ASSERT(resultType); return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), + resultType, witness, requirementKey); } return nullptr; } -static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witness = type->getWitness(); SLANG_RELEASE_ASSERT(witness); @@ -70,16 +71,48 @@ static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRB if (as(type->getValueType()) || as(type->getValueType())) { // The differential type is the IDifferentiable interface type. - return sharedContext->differentiableInterfaceType; + if (as(type) || as(type)) + return sharedContext->differentiableInterfaceType; + else if (as(type)) + return sharedContext->differentiableRefInterfaceType; + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } - return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); + if (as(type) || as(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocTypeStructKey, + builder->getTypeKind()); + else if (as(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocRefTypeStructKey, + builder->getTypeKind()); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); + + if (as(type) || as(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType); + else if (as(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } bool isNoDiffType(IRType* paramType) @@ -320,6 +353,24 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( return result; } +IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst) +{ + for (auto inst : moduleInst->getGlobalInsts()) + { + if (auto interfaceType = as(inst)) + { + if (auto decor = interfaceType->findDecoration()) + { + if (decor->getName() == "IDifferentiablePtrType") + { + return interfaceType; + } + } + } + } + return nullptr; +} + AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst* inModuleInst) : moduleInst(inModuleInst), targetProgram(target) { @@ -328,14 +379,27 @@ AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + differentialAssocTypeWitnessTableType = findDifferentialTypeWitnessTableType(); zeroMethodStructKey = findZeroMethodStructKey(); + zeroMethodType = cast(getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementVal()); addMethodStructKey = findAddMethodStructKey(); + addMethodType = cast(getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementVal()); mulMethodStructKey = findMulMethodStructKey(); nullDifferentialStructType = findNullDifferentialStructType(); nullDifferentialWitness = findNullDifferentialWitness(); - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; + isInterfaceAvailable = true; + } + + differentiableRefInterfaceType = as(findDifferentiableRefInterface(inModuleInst)); + + if (differentiableRefInterfaceType) + { + differentialAssocRefTypeStructKey = findDifferentialPtrTypeStructKey(); + differentialAssocRefTypeWitnessStructKey = findDifferentialPtrTypeWitnessStructKey(); + differentialAssocRefTypeWitnessTableType = findDifferentialPtrTypeWitnessTableType(); + + isPtrInterfaceAvailable = true; } } @@ -404,14 +468,14 @@ IRInst* AutoDiffSharedContext::findNullDifferentialWitness() } -IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +IRInterfaceRequirementEntry* AutoDiffSharedContext::getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index) { - if (as(moduleInst) && differentiableInterfaceType) + if (as(moduleInst) && interface) { // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as(differentiableInterfaceType->getOperand(index))) - return as(entry->getRequirementKey()); + // SLANG_ASSERT(interface->getOperandCount() == 5); + if (auto entry = as(interface->getOperand(index))) + return entry; else { SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); @@ -421,6 +485,43 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde return nullptr; } +// Extracts conformance interface from a witness inst while accounting for some +// quirks in the type system around interfaces that conform to other interfaces. +// +IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWitness(IRInst* witness) +{ + IRInterfaceType* diffInterfaceType = nullptr; + if (auto witnessTableType = as(witness->getDataType())) + { + diffInterfaceType = cast(witnessTableType->getConformanceType()); + } + else if (auto structKey = as(witness)) + { + // We currently assume that a struct key is used uniquely for a single interface-requirement-entry. + // Find that entry + for (IRUse* use = structKey->firstUse; use; use = use->nextUse) + { + if (auto entry = as(use->getUser())) + { + auto innerWitnessTableType = cast(entry->getRequirementVal()); + diffInterfaceType = cast(innerWitnessTableType->getConformanceType()); + break; + } + } + } + else if (auto interfaceRequirementEntry = as(witness)) + { + auto innerWitnessTableType = cast(interfaceRequirementEntry->getRequirementVal()); + diffInterfaceType = cast(innerWitnessTableType->getConformanceType()); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } + + return diffInterfaceType; +} + void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; @@ -434,7 +535,15 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { if (auto item = as(child)) { - auto existingItem = differentiableWitnessDictionary.tryGetValue(item->getConcreteType()); + IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); + + SLANG_ASSERT( + diffInterfaceType == sharedContext->differentiableInterfaceType + || diffInterfaceType == sharedContext->differentiableRefInterfaceType); + + //lookUpConformanceForType(item->getConcreteType()); + // TODO: need to consider ref type. + auto existingItem = differentiableValueTypeWitnessDictionary.tryGetValue(item->getConcreteType()); if (existingItem) { *existingItem = item->getWitness(); @@ -458,20 +567,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { auto element = concreteType->getOperand(i); auto elementWitness = witnessPack->getOperand(i); - differentiableWitnessDictionary.addIfNotExists( - (IRType*)element, - _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + + if (diffInterfaceType == sharedContext->differentiableInterfaceType) + addTypeToDictionary( + (IRType*)element, + _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind())); + else if (diffInterfaceType == sharedContext->differentiableRefInterfaceType) + addTypeToDictionary( + (IRType*)element, + _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocRefTypeStructKey, subBuilder.getTypeKind())); } return; } } - differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness()); if (!as(item->getConcreteType())) { - differentiableWitnessDictionary.addIfNotExists( - (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey), + addTypeToDictionary( + (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind()), item->getWitness()); } @@ -480,29 +595,55 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) // For differential pair types, register the differential type as well. IRBuilder builder(diffPairType); builder.setInsertAfter(diffPairType->getWitness()); - auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey); - auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey); - if (diffType && diffWitness) - { - differentiableWitnessDictionary.addIfNotExists((IRType*)diffType, diffWitness); - } + + // TODO(sai): lot of this logic is duplicated. need to refactor. + auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey, builder.getTypeKind()) : + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocRefTypeStructKey, builder.getTypeKind()); + auto diffWitness = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType) : + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + + addTypeToDictionary((IRType*)diffType, diffWitness); } } } } } -IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type, DiffConformanceKind kind) { IRInst* foundResult = nullptr; - differentiableWitnessDictionary.tryGetValue(type, foundResult); + + switch (kind) + { + case DiffConformanceKind::Any: + differentiableValueTypeWitnessDictionary.tryGetValue(type, foundResult); + if (!foundResult) + differentiablePtrTypeWitnessDictionary.tryGetValue(type, foundResult); + case DiffConformanceKind::Value: + differentiableValueTypeWitnessDictionary.tryGetValue(type, foundResult); + break; + case DiffConformanceKind::Ptr: + differentiablePtrTypeWitnessDictionary.tryGetValue(type, foundResult); + break; + } + return foundResult; } -IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType) { - if (auto conformance = tryGetDifferentiableWitness(builder, origType)) - return _lookupWitness(builder, conformance, key); + if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any)) + return _lookupWitness(builder, conformance, key, resultType); return nullptr; } @@ -514,7 +655,8 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { - return _getDiffTypeFromPairType(sharedContext, builder, type); + return this->differentiateType(builder, type->getValueType()); + //return _getDiffTypeFromPairType(sharedContext, builder, type); } IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) @@ -525,20 +667,40 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); } IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey, sharedContext->addMethodType); +} + +void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRInst* witness) +{ + //auto witnessType = cast(witness->getDataType()); + auto conformanceType = getConformanceTypeFromWitness(witness); + if (sharedContext->isInterfaceAvailable && + conformanceType == sharedContext->differentiableInterfaceType) + { + differentiableValueTypeWitnessDictionary.addIfNotExists(type, witness); + } + else if (sharedContext->isPtrInterfaceAvailable && + conformanceType == sharedContext->differentiableRefInterfaceType) + { + differentiablePtrTypeWitnessDictionary.addIfNotExists(type, witness); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } } IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable) { SLANG_RELEASE_ASSERT(interfaceType); - List lookupKeyPath = findDifferentiableInterfaceLookupPath( + List lookupKeyPath = findInterfaceLookupPath( sharedContext->differentiableInterfaceType, interfaceType); IRInst* differentialTypeWitness = witnessTable; @@ -549,6 +711,7 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface { differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); // Lookup insts are always primal values. + builder->markInstAsPrimal(differentialTypeWitness); } return differentialTypeWitness; @@ -557,10 +720,10 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface return nullptr; } -// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. -static bool _findDifferentiableInterfaceLookupPathImpl( +// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `supType`. +static bool _findInterfaceLookupPathImpl( HashSet& processedTypes, - IRInterfaceType* idiffType, + IRInterfaceType* supType, IRInterfaceType* type, List& currentPath) { @@ -576,13 +739,13 @@ static bool _findDifferentiableInterfaceLookupPathImpl( if (auto wt = as(entry->getRequirementVal())) { currentPath.add(entry); - if (wt->getConformanceType() == idiffType) + if (wt->getConformanceType() == supType) { return true; } else if (auto subInterfaceType = as(wt->getConformanceType())) { - if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) + if (_findInterfaceLookupPathImpl(processedTypes, supType, subInterfaceType, currentPath)) return true; } currentPath.removeLast(); @@ -591,11 +754,11 @@ static bool _findDifferentiableInterfaceLookupPathImpl( return false; } -List DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type) +List DifferentiableTypeConformanceContext::findInterfaceLookupPath(IRInterfaceType *supType, IRInterfaceType *type) { List currentPath; HashSet processedTypes; - _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); + _findInterfaceLookupPathImpl(processedTypes, supType, type, currentPath); return currentPath; } @@ -722,7 +885,8 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { if (auto pairType = as(globalInst)) { - differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness()); + addTypeToDictionary(pairType->getValueType(), pairType->getWitness()); + //differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness()); } } } @@ -762,9 +926,8 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build case kIROp_DifferentialPairType: { auto primalPairType = as(primalType); - return getOrCreateDiffPairType( - builder, - getDiffTypeFromPairType(builder, primalPairType), + return builder->getDifferentialPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), getDiffTypeWitnessFromPairType(builder, primalPairType)); } @@ -776,6 +939,14 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build getDiffTypeWitnessFromPairType(builder, primalPairType)); } + case kIROp_DifferentialPtrPairType: + { + auto primalPairType = as(primalType); + return builder->getDifferentialPtrPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), + getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + case kIROp_FuncType: { SLANG_UNIMPLEMENTED_X("Impl"); @@ -817,12 +988,12 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build } } -IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) +IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType, DiffConformanceKind kind) { if (isNoDiffType((IRType*)primalType)) return nullptr; - - IRInst* witness = lookUpConformanceForType((IRType*)primalType); + + IRInst* witness = lookUpConformanceForType((IRType*)primalType, kind); if (witness) { SLANG_RELEASE_ASSERT(witness || as(primalType)); @@ -834,31 +1005,60 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil witness = nullptr; } - if (!witness) + if (witness) + return witness; + + // If a witness is not already mapped, build one if possible. + SLANG_RELEASE_ASSERT(primalType); + if (auto primalPairType = as(primalType)) { - SLANG_RELEASE_ASSERT(primalType); - if (auto primalPairType = as(primalType)) - { - witness = getOrCreateDifferentiablePairWitness(builder, primalPairType); - } - else if (auto arrayType = as(primalType)) - { - witness = getArrayWitness(builder, arrayType); - } - else if (auto extractExistential = as(primalType)) - { - witness = getExtractExistensialTypeWitness(builder, extractExistential); - } - else if (auto typePack = as(primalType)) + witness = buildDifferentiablePairWitness(builder, primalPairType, kind); + } + else if (auto arrayType = as(primalType)) + { + witness = buildArrayWitness(builder, arrayType, kind); + } + else if (auto extractExistential = as(primalType)) + { + witness = buildExtractExistensialTypeWitness(builder, extractExistential, kind); + } + else if (auto typePack = as(primalType)) + { + witness = buildTupleWitness(builder, typePack, kind); + } + else if (auto tupleType = as(primalType)) + { + witness = buildTupleWitness(builder, tupleType, kind); + } + else if (auto lookup = as(primalType)) + { + // For types that are lookups from a table, we can simply lookup the witness from the same table + if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey) { - witness = getTupleWitness(builder, typePack); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocTypeWitnessStructKey); } - else if (auto tupleType = as(primalType)) + + if (lookup->getRequirementKey() == sharedContext->differentialAssocRefTypeStructKey) { - witness = getTupleWitness(builder, tupleType); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocRefTypeWitnessStructKey); } } - return witness; + + // If we created a witness, register it. + if (witness) + { + addTypeToDictionary((IRType*)primalType, witness); + return witness; + } + + // Failed. Type is either non-differentiable, or unhandled. + return nullptr; } IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) @@ -868,77 +1068,97 @@ IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* witness); } -IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType) +IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( + IRBuilder* builder, + IRDifferentialPairTypeBase* pairType, + DiffConformanceKind target) { - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); - - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - bool isUserCodeType = as(pairType) ? true : false; - - // Fill in differential method implementations. - auto elementType = as(pairType)->getValueType(); - auto innerWitness = as(pairType)->getWitness(); - - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); - b.emitBlock(); - auto p0 = b.emitParam(diffDiffPairType); - auto p1 = b.emitParam(diffDiffPairType); - - // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - IRInst* argsPrimal[2] = { - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; - auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); - IRInst* argsDiff[2] = { - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; - auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) - : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); - b.emitReturn(retVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); - b.emitBlock(); - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) - : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); - b.emitReturn(retVal); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + bool isUserCodeType = as(pairType) ? true : false; + + // Fill in differential method implementations. + auto elementType = as(pairType)->getValueType(); + auto innerWitness = as(pairType)->getWitness(); + + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); + b.emitBlock(); + auto p0 = b.emitParam(diffDiffPairType); + auto p1 = b.emitParam(diffDiffPairType); + + // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + IRInst* argsPrimal[2] = { + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; + auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); + IRInst* argsDiff[2] = { + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; + auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) + : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); + b.emitReturn(retVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); + b.emitBlock(); + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) + : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + table = builder->createWitnessTable( + sharedContext->differentiableRefInterfaceType, + (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } - - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)pairType] = table; return table; } -IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType) +IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( + IRBuilder* builder, + IRArrayType* arrayType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType); @@ -946,70 +1166,89 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder if (!diffArrayType) return nullptr; - auto innerWitness = tryGetDifferentiableWitness(builder, as(arrayType)->getElementType()); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType)); + auto innerWitness = tryGetDifferentiableWitness(builder, as(arrayType)->getElementType(), DiffConformanceKind::Value); - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - auto elementType = as(diffArrayType)->getElementType(); + auto elementType = as(diffArrayType)->getElementType(); - // Fill in differential method implementations. + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffArrayType, diffArrayType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); + b.emitBlock(); + auto p0 = b.emitParam(diffArrayType); + auto p1 = b.emitParam(diffArrayType); + + // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto resultVar = b.emitVar(diffArrayType); + IRBlock* loopBodyBlock = nullptr; + IRBlock* loopBreakBlock = nullptr; + auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); + b.setInsertBefore(loopBodyBlock->getTerminator()); + + IRInst* args[2] = { + b.emitElementExtract(p0, loopCounter), + b.emitElementExtract(p1, loopCounter) }; + auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); + auto addr = b.emitElementAddress(resultVar, loopCounter); + b.emitStore(addr, elementResult); + b.setInsertInto(loopBreakBlock); + b.emitReturn(b.emitLoad(resultVar)); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); + b.emitBlock(); + + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffArrayType, diffArrayType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); - b.emitBlock(); - auto p0 = b.emitParam(diffArrayType); - auto p1 = b.emitParam(diffArrayType); + SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType)); - // Since we are already dealing with a DiffPair.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto resultVar = b.emitVar(diffArrayType); - IRBlock* loopBodyBlock = nullptr; - IRBlock* loopBreakBlock = nullptr; - auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); - b.setInsertBefore(loopBodyBlock->getTerminator()); + table = builder->createWitnessTable(sharedContext->differentiableRefInterfaceType, (IRType*)arrayType); - IRInst* args[2] = { - b.emitElementExtract(p0, loopCounter), - b.emitElementExtract(p1, loopCounter) }; - auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); - auto addr = b.emitElementAddress(resultVar, loopCounter); - b.emitStore(addr, elementResult); - b.setInsertInto(loopBreakBlock); - b.emitReturn(b.emitLoad(resultVar)); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } + else { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); - b.emitBlock(); - - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); - b.emitReturn(retVal); + SLANG_UNEXPECTED("Invalid conformance kind for synthesis"); } - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)arrayType] = table; - return table; } -IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType) +IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( + IRBuilder* builder, + IRInst* inTupleType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType); @@ -1017,100 +1256,116 @@ IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder if (!diffTupleType) return nullptr; - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - // Fill in differential method implementations. - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffTupleType, diffTupleType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); - b.emitBlock(); - auto p0 = b.emitParam(diffTupleType); - auto p1 = b.emitParam(diffTupleType); - List results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType)); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffTupleType, diffTupleType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); + b.emitBlock(); + auto p0 = b.emitParam(diffTupleType); + auto p1 = b.emitParam(diffTupleType); + List results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto iVal = b.getIntValue(b.getIntType(), i); + IRInst* args[2] = { + b.emitGetTupleElement(diffElementType, p0, iVal), + b.emitGetTupleElement(diffElementType, p1, iVal) }; + elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto iVal = b.getIntValue(b.getIntType(), i); - IRInst* args[2] = { - b.emitGetTupleElement(diffElementType, p0, iVal), - b.emitGetTupleElement(diffElementType, p1, iVal) }; - elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); - b.emitBlock(); - List results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.emitBlock(); + List results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); } + else if (target == DiffConformanceKind::Ptr) + { + SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType)); + + table = builder->createWitnessTable(sharedContext->differentiableRefInterfaceType, (IRType*)inTupleType); - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)inTupleType] = table; + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); + } return table; } -IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( +IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness( IRBuilder* builder, - IRExtractExistentialType* extractExistentialType) + IRExtractExistentialType* extractExistentialType, + DiffConformanceKind target) { + SLANG_UNUSED(target); // logic is the same for both value and ptr + // Check that the type's base is differentiable if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType())) { @@ -1310,11 +1565,19 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* if (context.isDifferentiableType((IRType*)typeInst)) return true; // Look for equivalent types. - for (auto type : context.differentiableWitnessDictionary) + for (auto type : context.differentiableValueTypeWitnessDictionary) + { + if (isTypeEqual(type.key, (IRType*)typeInst)) + { + context.differentiableValueTypeWitnessDictionary[(IRType*)typeInst] = type.value; + return true; + } + } + for (auto type : context.differentiablePtrTypeWitnessDictionary) { if (isTypeEqual(type.key, (IRType*)typeInst)) { - context.differentiableWitnessDictionary[(IRType*)typeInst] = type.value; + context.differentiablePtrTypeWitnessDictionary[(IRType*)typeInst] = type.value; return true; } } @@ -1671,7 +1934,7 @@ struct AutoDiffPass : public InstPassBase IRBuilder keyBuilder = builder; keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); auto diffKey = keyBuilder.createStructKey(); - auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey); + auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey, builder.getTypeKind()); info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); info.witness = diffFieldWitness; builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); @@ -1694,7 +1957,11 @@ struct AutoDiffPass : public InstPassBase List fieldVals; for (auto info : diffFields) { - auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey); + auto innerZeroMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->zeroMethodStructKey, + autodiffContext->zeroMethodType); IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr); fieldVals.add(val); } @@ -1718,7 +1985,11 @@ struct AutoDiffPass : public InstPassBase List fieldVals; for (auto info : diffFields) { - auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey); + auto innerAddMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->addMethodStructKey, + autodiffContext->addMethodType); IRInst* args[2] = { builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()), builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()), diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 812471fe3d..d85271ff31 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -57,6 +57,14 @@ struct DiffTranscriberSet AutoDiffTranscriberBase* backwardTranscriber = nullptr; }; + +enum class DiffConformanceKind +{ + Any = 0, // Perform actions for any conformance (infer from context) + Ptr = 1, // Perform actions for IDifferentiablePtrType + Value = 2 // Perform actions for IDifferentiable +}; + struct AutoDiffSharedContext { TargetProgram* targetProgram = nullptr; @@ -78,6 +86,7 @@ struct AutoDiffSharedContext // The struct key for the witness that `Differential` associated type conforms to // `IDifferential`. IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + IRWitnessTableType* differentialAssocTypeWitnessTableType = nullptr; // The struct key for the 'zero()' associated type @@ -85,12 +94,14 @@ struct AutoDiffSharedContext // implementation of zero() for a given type. // IRStructKey* zeroMethodStructKey = nullptr; + IRFuncType* zeroMethodType = nullptr; // The struct key for the 'add()' associated type // defined inside IDifferential. We use this to lookup the // implementation of add() for a given type. // IRStructKey* addMethodStructKey = nullptr; + IRFuncType* addMethodType = nullptr; IRStructKey* mulMethodStructKey = nullptr; @@ -104,12 +115,27 @@ struct AutoDiffSharedContext // IRInst* nullDifferentialWitness = nullptr; + + // A reference to the builtin IDifferentiablePtrType interface type. + IRInterfaceType* differentiableRefInterfaceType = nullptr; + + // The struct key for the 'Differential' associated type + // defined inside IDifferentialPtrType. We use this to lookup the differential + // type in the conformance table associated with the concrete type. + // + IRStructKey* differentialAssocRefTypeStructKey = nullptr; + + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferentialPtrType`. + IRStructKey* differentialAssocRefTypeWitnessStructKey = nullptr; + IRWitnessTableType* differentialAssocRefTypeWitnessTableType = nullptr; // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. // Set to false to indicate that we are uninitialized. // bool isInterfaceAvailable = false; + bool isPtrInterfaceAvailable = false; List followUpFunctionsToTranscribe; @@ -127,38 +153,71 @@ struct AutoDiffSharedContext IRStructKey* findDifferentialTypeStructKey() { - return getIDifferentiableStructKeyAtIndex(0); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 0)->getRequirementKey()); } IRStructKey* findDifferentialTypeWitnessStructKey() { - return getIDifferentiableStructKeyAtIndex(1); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementKey()); + } + + IRWitnessTableType* findDifferentialTypeWitnessTableType() + { + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementVal()); } IRStructKey* findZeroMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(2); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementKey()); } IRStructKey* findAddMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(3); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementKey()); } IRStructKey* findMulMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(4); + return cast( + getInterfaceEntryAtIndex(differentiableInterfaceType, 4)->getRequirementKey()); + } + + + IRStructKey* findDifferentialPtrTypeStructKey() + { + return cast( + getInterfaceEntryAtIndex(differentiableRefInterfaceType, 0)->getRequirementKey()); } - IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); + IRStructKey* findDifferentialPtrTypeWitnessStructKey() + { + return cast( + getInterfaceEntryAtIndex(differentiableRefInterfaceType, 1)->getRequirementKey()); + } + + IRWitnessTableType* findDifferentialPtrTypeWitnessTableType() + { + return cast( + getInterfaceEntryAtIndex(differentiableRefInterfaceType, 1)->getRequirementVal()); + } + + //IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); + IRInterfaceRequirementEntry* getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index); }; + struct DifferentiableTypeConformanceContext { AutoDiffSharedContext* sharedContext; IRGlobalValueWithCode* parentFunc = nullptr; - OrderedDictionary differentiableWitnessDictionary; + OrderedDictionary differentiableValueTypeWitnessDictionary; + OrderedDictionary differentiablePtrTypeWitnessDictionary; IRFunc* existentialDAddFunc = nullptr; @@ -167,7 +226,7 @@ struct DifferentiableTypeConformanceContext { // Populate dictionary with null differential type. if (sharedContext->nullDifferentialStructType) - differentiableWitnessDictionary.add( + differentiableValueTypeWitnessDictionary.add( sharedContext->nullDifferentialStructType, sharedContext->nullDifferentialWitness); } @@ -179,21 +238,13 @@ struct DifferentiableTypeConformanceContext // Lookup a witness table for the concreteType. One should exist if concreteType // inherits (successfully) from IDifferentiable. // - IRInst* lookUpConformanceForType(IRInst* type); + IRInst* lookUpConformanceForType(IRInst* type, DiffConformanceKind kind); - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType = nullptr); IRType* differentiateType(IRBuilder* builder, IRInst* primalType); - IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); - - IRInst* getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType); - - IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType); - - IRInst* getTupleWitness(IRBuilder* builder, IRInst* tupleType); - - IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType); + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); @@ -207,17 +258,21 @@ struct DifferentiableTypeConformanceContext IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + void addTypeToDictionary(IRType* type, IRInst* witness); + + IRInterfaceType* getConformanceTypeFromWitness(IRInst* witness); + IRInst* tryExtractConformanceFromInterfaceType( IRBuilder* builder, IRInterfaceType* interfaceType, IRWitnessTable* witnessTable); - List findDifferentiableInterfaceLookupPath( - IRInterfaceType* idiffType, + List findInterfaceLookupPath( + IRInterfaceType* supType, IRInterfaceType* type); // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. + // in order to conform to the IDifferentiable/IDifferentiablePtrType interfaces // Note that inside a generic block, this will be a witness table lookup instruction // that gets resolved during the specialization pass. // @@ -227,8 +282,10 @@ struct DifferentiableTypeConformanceContext { case kIROp_InterfaceType: { - if (isDifferentiableType(origType)) + if (isDifferentiableValueType(origType)) return this->sharedContext->differentiableInterfaceType; + else if (isDifferentiablePtrType(origType)) + return this->sharedContext->differentiableRefInterfaceType; else return nullptr; } @@ -254,12 +311,29 @@ struct DifferentiableTypeConformanceContext auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType); return builder->getDifferentialPairUserCodeType((IRType*)diffType, diffWitness); } + case kIROp_DifferentialPtrPairType: + { + auto diffPairType = as(origType); + auto diffType = getDiffTypeFromPairType(builder, diffPairType); + auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType); + return builder->getDifferentialPtrPairType((IRType*)diffType, diffWitness); + } default: - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); + if (isDifferentiableValueType(origType)) + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey, builder->getTypeKind()); + else if (isDifferentiablePtrType(origType)) + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocRefTypeStructKey, builder->getTypeKind()); + else + return nullptr; } } bool isDifferentiableType(IRType* origType) + { + return isDifferentiableValueType(origType) || isDifferentiablePtrType(origType); + } + + bool isDifferentiableValueType(IRType* origType) { for (; origType;) { @@ -279,7 +353,27 @@ struct DifferentiableTypeConformanceContext origType = (IRType*)origType->getOperand(0); continue; default: - return lookUpConformanceForType(origType) != nullptr; + return lookUpConformanceForType(origType, DiffConformanceKind::Value) != nullptr; + } + } + return false; + } + + bool isDifferentiablePtrType(IRType* origType) + { + for (; origType;) + { + switch (origType->getOp()) + { + case kIROp_VectorType: + case kIROp_ArrayType: + case kIROp_PtrType: + case kIROp_OutType: + case kIROp_InOutType: + origType = (IRType*)origType->getOperand(0); + continue; + default: + return lookUpConformanceForType(origType, DiffConformanceKind::Ptr) != nullptr; } } return false; @@ -287,13 +381,13 @@ struct DifferentiableTypeConformanceContext IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) { - auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); return result; } IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) { - auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey, sharedContext->addMethodType); return result; } @@ -307,8 +401,28 @@ struct DifferentiableTypeConformanceContext IRFunc* getOrCreateExistentialDAddMethod(); + IRInst* buildDifferentiablePairWitness( + IRBuilder* builder, + IRDifferentialPairTypeBase* pairType, + DiffConformanceKind target); + + IRInst* buildArrayWitness( + IRBuilder* builder, + IRArrayType* pairType, + DiffConformanceKind target); + + IRInst* buildTupleWitness( + IRBuilder* builder, + IRInst* tupleType, + DiffConformanceKind target); + + IRInst* buildExtractExistensialTypeWitness( + IRBuilder* builder, + IRExtractExistentialType* extractExistentialType, + DiffConformanceKind target); }; + struct DifferentialPairTypeBuilder { DifferentialPairTypeBuilder() = default; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 8b4886a2cf..cae47fffde 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -625,7 +625,7 @@ struct CheckDifferentiabilityPassContext : public InstPassBase } } - if (!sharedContext.isInterfaceAvailable) + if (!sharedContext.isInterfaceAvailable && !sharedContext.isPtrInterfaceAvailable) return; for (auto inst : module->getGlobalInsts()) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index afc09f4801..4105fb18c0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -61,7 +61,8 @@ INST(Nop, nop, 0, 0) INST(DifferentialPairType, DiffPair, 1, HOISTABLE) INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE) - INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPairUserCodeType) + INST(DifferentialPtrPairType, DiffRefPair, 1, HOISTABLE) + INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPtrPairType) INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) @@ -325,15 +326,18 @@ INST(DefaultConstruct, defaultConstruct, 0, 0) INST(MakeDifferentialPair, MakeDiffPair, 2, 0) INST(MakeDifferentialPairUserCode, MakeDiffPairUserCode, 2, 0) -INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPairUserCode) +INST(MakeDifferentialPtrPair, MakeDiffRefPair, 2, 0) +INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPtrPair) INST(DifferentialPairGetDifferential, GetDifferential, 1, 0) INST(DifferentialPairGetDifferentialUserCode, GetDifferentialUserCode, 1, 0) -INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPairGetDifferentialUserCode) +INST(DifferentialPtrPairGetDifferential, GetDifferentialPtr, 1, 0) +INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPtrPairGetDifferential) INST(DifferentialPairGetPrimal, GetPrimal, 1, 0) INST(DifferentialPairGetPrimalUserCode, GetPrimalUserCode, 1, 0) -INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPairGetPrimalUserCode) +INST(DifferentialPtrPairGetPrimal, GetPrimalRef, 1, 0) +INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPtrPairGetPrimal) INST(Specialize, specialize, 2, HOISTABLE) INST(LookupWitness, lookupWitness, 2, HOISTABLE) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a0ed8ff0e7..56412ff989 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2938,6 +2938,10 @@ struct IRMakeDifferentialPairUserCode : IRMakeDifferentialPairBase { IR_LEAF_ISA(MakeDifferentialPairUserCode) }; +struct IRMakeDifferentialPtrPair : IRMakeDifferentialPairBase +{ + IR_LEAF_ISA(MakeDifferentialPtrPair) +}; struct IRDifferentialPairGetDifferentialBase : IRInst { @@ -2952,6 +2956,10 @@ struct IRDifferentialPairGetDifferentialUserCode : IRDifferentialPairGetDifferen { IR_LEAF_ISA(DifferentialPairGetDifferentialUserCode) }; +struct IRDifferentialPtrPairGetDifferential : IRDifferentialPairGetDifferentialBase +{ + IR_LEAF_ISA(DifferentialPtrPairGetDifferential) +}; struct IRDifferentialPairGetPrimalBase : IRInst { @@ -2966,6 +2974,10 @@ struct IRDifferentialPairGetPrimalUserCode : IRDifferentialPairGetPrimalBase { IR_LEAF_ISA(DifferentialPairGetPrimalUserCode) }; +struct IRDifferentialPtrPairGetPrimal : IRDifferentialPairGetPrimalBase +{ + IR_LEAF_ISA(DifferentialPtrPairGetPrimal) +}; struct IRDetachDerivative : IRInst { @@ -3636,6 +3648,10 @@ struct IRBuilder IRDifferentialPairType* getDifferentialPairType( IRType* valueType, IRInst* witnessTable); + + IRDifferentialPtrPairType* getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable); IRDifferentialPairUserCodeType* getDifferentialPairUserCodeType( IRType* valueType, @@ -3777,6 +3793,7 @@ struct IRBuilder IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential); IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); @@ -3959,8 +3976,11 @@ struct IRBuilder IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value); IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetPrimal(IRInst* diffPair); IRInst* emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair); IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair); IRInst* emitMakeVector( diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9305d17830..da0c584203 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3022,6 +3022,17 @@ namespace Slang operands); } + IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPtrPairType*)getType( + kIROp_DifferentialPtrPairType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType( IRType* valueType, IRInst* witnessTable) @@ -3515,6 +3526,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential) + { + SLANG_RELEASE_ASSERT(as(type)); + SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); + + IRInst* args[] = {primal, differential}; + auto inst = createInstWithTrailingArgs( + this, kIROp_MakeDifferentialPtrPair, type, 2, args); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as(type)); @@ -4230,6 +4253,17 @@ namespace Slang &diffPair); } + + IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + SLANG_ASSERT(as(diffPair->getDataType())); + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPtrPairGetDifferential, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) { auto valueType = cast(diffPair->getDataType())->getValueType(); @@ -4249,6 +4283,25 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair) + { + auto valueType = cast(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair) + { + return emitIntrinsicInst( + primalType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as(diffPair->getDataType())); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 375107d1d4..14dde200f4 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1660,6 +1660,11 @@ struct IRDifferentialPairType : IRDifferentialPairTypeBase IR_LEAF_ISA(DifferentialPairType) }; +struct IRDifferentialPtrPairType : IRDifferentialPairTypeBase +{ + IR_LEAF_ISA(DifferentialPtrPairType) +}; + struct IRDifferentialPairUserCodeType : IRDifferentialPairTypeBase { IR_LEAF_ISA(DifferentialPairUserCodeType) diff --git a/tests/autodiff/diff-ptr-type-smoke.slang b/tests/autodiff/diff-ptr-type-smoke.slang new file mode 100644 index 0000000000..e7e03c5e37 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-smoke.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + +[BackwardDifferentiable] +float test(MyPtrType b, uint idx) +{ + return load(b, idx) + load(b, idx + 1); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, id, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} \ No newline at end of file From 53b9cfb92630e6561bbfc68bbba6b6f27a26954c Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Fri, 6 Sep 2024 15:52:45 -0400 Subject: [PATCH 03/14] Fix unused vars --- source/slang/slang-ir-autodiff-fwd.cpp | 18 ++---------------- source/slang/slang-ir-autodiff-rev.cpp | 2 +- .../slang-ir-autodiff-transcriber-base.cpp | 2 +- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 53d36af461..fbc60d90d7 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -423,14 +423,14 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load); return InstPair(primalElement, diffElement); } - else if (auto diffRefPairType = as(primalPtrType->getValueType())) + else if (auto diffPtrPairType = as(primalPtrType->getValueType())) { auto load = builder->emitLoad(primalPtr); builder->markInstAsPrimal(load); auto primalElement = builder->emitDifferentialPtrPairGetPrimal(load); auto diffElement = builder->emitDifferentialPtrPairGetDifferential( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load); + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPtrPairType), load); builder->markInstAsPrimal(primalElement); builder->markInstAsPrimal(diffElement); return InstPair(primalElement, diffElement); @@ -790,20 +790,6 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto pairValType = as( pairPtrType ? pairPtrType->getValueType() : pairType); - DiffConformanceKind kind = DiffConformanceKind::Any; - if (as(pairValType)) - { - kind = DiffConformanceKind::Ptr; - } - else if (as(pairValType)) - { - kind = DiffConformanceKind::Value; - } - else - { - SLANG_ASSERT(!"unreachable"); - } - auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType); if (auto ptrParamType = as(diffParamType)) { diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 02bc8190cd..347649538f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -156,7 +156,7 @@ namespace Slang // Unwrap any ref pairs. We need this special case for trivial funcs. for (Int i = 0; i < params.getCount(); i++) { - if (auto diffPairType = as(params[i]->getDataType())) + if (as(params[i]->getDataType())) { params[i] = builder->emitDifferentialPtrPairGetPrimal(params[i]); } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 011f7c923b..e4656d485f 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1167,7 +1167,7 @@ void AutoDiffTranscriberBase::markDiffPairTypeInst(IRBuilder* builder, IRInst* d SLANG_ASSERT(pairType); SLANG_ASSERT(as(pairType)); - if (auto diffPairType = as(pairType)) + if (as(pairType)) { builder->markInstAsMixedDifferential(diffPairInst, pairType); } From 2cbf5f5276b9f8dd70d2b5560a6a9f669b452c41 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Tue, 10 Sep 2024 19:13:57 -0400 Subject: [PATCH 04/14] More tests + fix switch case fallthrough. --- source/slang/slang-ir-autodiff.cpp | 3 ++ tests/autodiff/diff-ptr-type-call.slang | 57 ++++++++++++++++++++++ tests/autodiff/diff-ptr-type-loop.slang | 63 +++++++++++++++++++++++++ 3 files changed, 123 insertions(+) create mode 100644 tests/autodiff/diff-ptr-type-call.slang create mode 100644 tests/autodiff/diff-ptr-type-loop.slang diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 93174b3913..facb7f026c 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -626,9 +626,12 @@ IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* t switch (kind) { case DiffConformanceKind::Any: + { differentiableValueTypeWitnessDictionary.tryGetValue(type, foundResult); if (!foundResult) differentiablePtrTypeWitnessDictionary.tryGetValue(type, foundResult); + break; + } case DiffConformanceKind::Value: differentiableValueTypeWitnessDictionary.tryGetValue(type, foundResult); break; diff --git a/tests/autodiff/diff-ptr-type-call.slang b/tests/autodiff/diff-ptr-type-call.slang new file mode 100644 index 0000000000..258a4477b5 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-call.slang @@ -0,0 +1,57 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + +// ------ +[Differentiable] +float reduce(MyPtrType a) +{ + return load(a, 0) + load(a, 1); +} + +[Differentiable] +float test(MyPtrType b) +{ + return reduce(b); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} \ No newline at end of file diff --git a/tests/autodiff/diff-ptr-type-loop.slang b/tests/autodiff/diff-ptr-type-loop.slang new file mode 100644 index 0000000000..423fee881b --- /dev/null +++ b/tests/autodiff/diff-ptr-type-loop.slang @@ -0,0 +1,63 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + uint offset; + + float load(uint idx) { return outputBuffer[offset + idx]; } + void accumulate(uint idx, float value) { outputBuffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + + +// ------ +[Differentiable] +float reduce(MyPtrType a, uint num) +{ + float sum = 0; + [MaxIters(3)] + for (uint i = 0; i < num; i++) + { + sum += load(a, i); + } + + return sum; +} + +[Differentiable] +float test(MyPtrType b, uint num) +{ + return reduce(b, num); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair( { 0 }, { 2 } ); + + bwd_diff(test)(b, 2, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} \ No newline at end of file From 96fa52dba332b2aa2a1b408b2f9eb20ff74d5e73 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Tue, 10 Sep 2024 19:24:48 -0400 Subject: [PATCH 05/14] Update slang-ir-autodiff.cpp --- source/slang/slang-ir-autodiff.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index facb7f026c..97b0c2fddb 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -514,6 +514,13 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit auto innerWitnessTableType = cast(interfaceRequirementEntry->getRequirementVal()); diffInterfaceType = cast(innerWitnessTableType->getConformanceType()); } + else if (auto tupleType = as(witness->getDataType())) + { + SLANG_ASSERT(tupleType->getOperandCount() >= 1); + auto operand = tupleType->getOperand(0); + auto innerWitnessTableType = cast(operand); + return cast(innerWitnessTableType->getConformanceType()); + } else { SLANG_UNEXPECTED("Unexpected witness type"); @@ -571,11 +578,11 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) if (diffInterfaceType == sharedContext->differentiableInterfaceType) addTypeToDictionary( (IRType*)element, - _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind())); + elementWitness); else if (diffInterfaceType == sharedContext->differentiableRefInterfaceType) addTypeToDictionary( (IRType*)element, - _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocRefTypeStructKey, subBuilder.getTypeKind())); + elementWitness); } return; } From 7c93b9d55201844558f4ee993eb46245cea232da Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Tue, 10 Sep 2024 19:45:54 -0400 Subject: [PATCH 06/14] Update diff-ptr-type-loop.slang --- tests/autodiff/diff-ptr-type-loop.slang | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/autodiff/diff-ptr-type-loop.slang b/tests/autodiff/diff-ptr-type-loop.slang index 423fee881b..4354837112 100644 --- a/tests/autodiff/diff-ptr-type-loop.slang +++ b/tests/autodiff/diff-ptr-type-loop.slang @@ -1,4 +1,5 @@ -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; From 96a2791498c6068420eb70d91b919800e3d78a6b Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 11 Sep 2024 16:01:44 -0400 Subject: [PATCH 07/14] Add optimization to allow more complex pair types --- source/slang/slang-ir-autodiff-fwd.cpp | 12 ++++++++++++ tests/autodiff/diff-ptr-type-loop.slang | 8 ++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fbc60d90d7..4a6ee5f7f8 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -25,6 +25,18 @@ IRInst* emitMakeDifferentialPair(IRBuilder* builder, IRType* pairType, IRInst* p } else if (as(pairType)) { + // Quick optimization: + // If primalVal and diffVal are extracted from the same pointer-pair, + // we can just use the pointer-pair directly. + // + if (auto primalPtrVal = as(primalVal)) + { + if (auto diffPtrVal = as(diffVal)) + { + if (primalPtrVal->getBase() == diffPtrVal->getBase()) + return primalPtrVal->getBase(); + } + } return builder->emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); } else diff --git a/tests/autodiff/diff-ptr-type-loop.slang b/tests/autodiff/diff-ptr-type-loop.slang index 4354837112..17d0be714f 100644 --- a/tests/autodiff/diff-ptr-type-loop.slang +++ b/tests/autodiff/diff-ptr-type-loop.slang @@ -1,4 +1,3 @@ -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer @@ -9,10 +8,11 @@ struct MyPtrType : IDifferentiablePtrType { typealias Differential = MyPtrType; + RWStructuredBuffer buffer; uint offset; - float load(uint idx) { return outputBuffer[offset + idx]; } - void accumulate(uint idx, float value) { outputBuffer[offset + idx] += value; } + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } } [BackwardDerivative(load_bwd)] @@ -54,7 +54,7 @@ void computeMain(uint id : SV_DispatchThreadID) outputBuffer[1] = 2; // CHECK: 2 // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. - var b = DifferentialPtrPair( { 0 }, { 2 } ); + var b = DifferentialPtrPair( { outputBuffer, 0 }, { outputBuffer, 2 } ); bwd_diff(test)(b, 2, 1.5f); From 6919026d9a64a843e3eed0f277b2f5a1e7d27ab0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 11 Sep 2024 16:36:48 -0400 Subject: [PATCH 08/14] Update slang-ir-autodiff-primal-hoist.cpp --- source/slang/slang-ir-autodiff-primal-hoist.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 9fe4ec70b6..d4dd453d3d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1980,6 +1980,7 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_MakeArrayFromElement: case kIROp_MakeDifferentialPair: case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: case kIROp_MakeOptionalNone: case kIROp_MakeOptionalValue: case kIROp_MakeExistential: @@ -1987,6 +1988,8 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferentialUserCode: case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetDifferential: + case kIROp_DifferentialPtrPairGetPrimal: case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialWitnessTable: From 65c9f2c1f8c5c9506b1614a68cff3a2c32a8d725 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 11 Sep 2024 16:36:53 -0400 Subject: [PATCH 09/14] Update diff-ptr-type-loop.slang --- tests/autodiff/diff-ptr-type-loop.slang | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/autodiff/diff-ptr-type-loop.slang b/tests/autodiff/diff-ptr-type-loop.slang index 17d0be714f..a57c69b760 100644 --- a/tests/autodiff/diff-ptr-type-loop.slang +++ b/tests/autodiff/diff-ptr-type-loop.slang @@ -1,3 +1,4 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer From 0d4a61508808a31478db69a8f48802de0f2220d5 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 11 Sep 2024 17:44:42 -0400 Subject: [PATCH 10/14] Update slang-ir-autodiff-primal-hoist.cpp --- source/slang/slang-ir-autodiff-primal-hoist.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index d4dd453d3d..7f1cf9bb58 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -891,6 +891,16 @@ void applyToInst( } } SLANG_ASSERT(replacement); + + // If the replacement and inst are not the exact same type, use an int-cast + // (e.g. uint vs. int) + // + if (replacement->getDataType() != inst->getDataType()) + { + setInsertAfterOrdinaryInst(builder, replacement); + replacement = builder->emitCast(inst->getDataType(), replacement); + } + cloneCtx->cloneEnv.mapOldValToNew[inst] = replacement; cloneCtx->registerClonedInst(builder, inst, replacement); return; From 7e79d00001375ccad7ace0edd1ec07f07160dad9 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Tue, 17 Sep 2024 14:23:25 -0400 Subject: [PATCH 11/14] More fixes to address reviews --- source/slang/slang-check-conformance.cpp | 10 +++++++--- source/slang/slang-check-expr.cpp | 17 +++++++++-------- source/slang/slang-check-impl.h | 2 +- source/slang/slang-ir-autodiff.cpp | 2 -- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index fb170222d9..9d9047e417 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -274,10 +274,14 @@ namespace Slang return isInterfaceType(type); } - bool SemanticsVisitor::isTypeDifferentiable(Type* type) + SubtypeWitness* SemanticsVisitor::isTypeDifferentiable(Type* type) { - return isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None) || - isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None); + if (auto valueWitness = isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None)) + return valueWitness; + else if (auto ptrWitness = isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None)) + return ptrWitness; + + return nullptr; } bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 557d0345c8..9b742825e4 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2909,15 +2909,16 @@ namespace Slang // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (auto conformanceValWitness = as( - isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None))) + if (auto conformanceWitness = isTypeDifferentiable(primalType)) { - return m_astBuilder->getDifferentialPairType(primalType, conformanceValWitness); - } - else if (auto conformancePtrWitness = as( - isSubtype(primalType, differentiableRefInterface, IsSubTypeOptions::None))) - { - return m_astBuilder->getDifferentialPtrPairType(primalType, conformancePtrWitness); + if (conformanceWitness->getSup() == differentiableInterface) + { + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + } + else if (conformanceWitness->getSup() == differentiableRefInterface) + { + return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness); + } } else return primalType; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ad3539a217..a9726069c3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2190,7 +2190,7 @@ namespace Slang bool isValidGenericConstraintType(Type* type); - bool isTypeDifferentiable(Type* type); + SubtypeWitness* isTypeDifferentiable(Type* type); bool doesTypeHaveTag(Type* type, TypeTag tag); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 97b0c2fddb..662d9eb293 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -666,7 +666,6 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { return this->differentiateType(builder, type->getValueType()); - //return _getDiffTypeFromPairType(sharedContext, builder, type); } IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) @@ -896,7 +895,6 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() if (auto pairType = as(globalInst)) { addTypeToDictionary(pairType->getValueType(), pairType->getWitness()); - //differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness()); } } } From 4fc4af9fa64a18727b0df43b6c7b39bcc29eace2 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Tue, 17 Sep 2024 14:24:06 -0400 Subject: [PATCH 12/14] Update slang-check-expr.cpp --- source/slang/slang-check-expr.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 9b742825e4..b7bb7111c8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2920,8 +2920,7 @@ namespace Slang return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness); } } - else - return primalType; + return primalType; } Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) From 66507ecf0460f74687afbce6135392ebc3bbfcd6 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 18 Sep 2024 13:17:56 -0400 Subject: [PATCH 13/14] Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` --- .../slang-ir-autodiff-transcriber-base.cpp | 31 ++++++++++++------- source/slang/slang-ir-autodiff.cpp | 18 +++++------ source/slang/slang-ir-autodiff.h | 10 +++--- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index e4656d485f..6662e352f5 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -199,7 +199,7 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI return builder->getDifferentialPairType((IRType*)primalType, witness); } else if (autoDiffSharedContext->isPtrInterfaceAvailable && - conformanceType == autoDiffSharedContext->differentiableRefInterfaceType) + conformanceType == autoDiffSharedContext->differentiablePtrInterfaceType) { return builder->getDifferentialPtrPairType((IRType*)primalType, witness); } @@ -235,7 +235,7 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Value)) return autoDiffSharedContext->differentiableInterfaceType; else if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Ptr)) - return autoDiffSharedContext->differentiableRefInterfaceType; + return autoDiffSharedContext->differentiablePtrInterfaceType; else return nullptr; } @@ -457,17 +457,24 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* auto interfaceType = as(unwrapAttributedType(origType->getOperand(0)->getDataType())); if (!interfaceType) return nullptr; - List lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath( - autoDiffSharedContext->differentiableInterfaceType, interfaceType); - List lookupPathPtrType = differentiableTypeConformanceContext.findInterfaceLookupPath( - autoDiffSharedContext->differentiableRefInterfaceType, interfaceType); - SLANG_ASSERT(!(lookupPathValueType.getCount() && lookupPathPtrType.getCount())); + List lookupKeyPath; + IRStructKey* diffStructKey = nullptr; - auto lookupKeyPath = lookupPathValueType.getCount() ? lookupPathValueType : lookupPathPtrType; - auto diffStructKey = lookupPathValueType.getCount() ? - autoDiffSharedContext->differentialAssocTypeStructKey : - autoDiffSharedContext->differentialAssocRefTypeStructKey; + List lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath( + autoDiffSharedContext->differentiableInterfaceType, interfaceType); + if (lookupPathValueType.getCount() > 0) + { + lookupKeyPath = lookupPathValueType; + diffStructKey = autoDiffSharedContext->differentialAssocTypeStructKey; + } + else + { + // Try IDifferentiablePtrType + lookupKeyPath = differentiableTypeConformanceContext.findInterfaceLookupPath( + autoDiffSharedContext->differentiablePtrInterfaceType, interfaceType); + diffStructKey = autoDiffSharedContext->differentialAssocRefTypeStructKey; + } if (lookupKeyPath.getCount()) { @@ -582,7 +589,7 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui return InstPair(primal, diffWitness); } - else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiableRefInterfaceType) + else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiablePtrInterfaceType) { auto primalDiffType = builder->emitLookupInterfaceMethodInst( builder->getTypeKind(), diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 662d9eb293..3e896ed251 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -74,7 +74,7 @@ static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext* sharedContext, IR if (as(type) || as(type)) return sharedContext->differentiableInterfaceType; else if (as(type)) - return sharedContext->differentiableRefInterfaceType; + return sharedContext->differentiablePtrInterfaceType; else SLANG_UNEXPECTED("Unexpected differential pair type"); } @@ -391,9 +391,9 @@ AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst isInterfaceAvailable = true; } - differentiableRefInterfaceType = as(findDifferentiableRefInterface(inModuleInst)); + differentiablePtrInterfaceType = as(findDifferentiableRefInterface(inModuleInst)); - if (differentiableRefInterfaceType) + if (differentiablePtrInterfaceType) { differentialAssocRefTypeStructKey = findDifferentialPtrTypeStructKey(); differentialAssocRefTypeWitnessStructKey = findDifferentialPtrTypeWitnessStructKey(); @@ -546,7 +546,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) SLANG_ASSERT( diffInterfaceType == sharedContext->differentiableInterfaceType - || diffInterfaceType == sharedContext->differentiableRefInterfaceType); + || diffInterfaceType == sharedContext->differentiablePtrInterfaceType); //lookUpConformanceForType(item->getConcreteType()); // TODO: need to consider ref type. @@ -579,7 +579,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) addTypeToDictionary( (IRType*)element, elementWitness); - else if (diffInterfaceType == sharedContext->differentiableRefInterfaceType) + else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType) addTypeToDictionary( (IRType*)element, elementWitness); @@ -695,7 +695,7 @@ void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRI differentiableValueTypeWitnessDictionary.addIfNotExists(type, witness); } else if (sharedContext->isPtrInterfaceAvailable && - conformanceType == sharedContext->differentiableRefInterfaceType) + conformanceType == sharedContext->differentiablePtrInterfaceType) { differentiablePtrTypeWitnessDictionary.addIfNotExists(type, witness); } @@ -1152,7 +1152,7 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); table = builder->createWitnessTable( - sharedContext->differentiableRefInterfaceType, + sharedContext->differentiablePtrInterfaceType, (IRType*)pairType); // And place it in the synthesized witness table. @@ -1239,7 +1239,7 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( { SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType)); - table = builder->createWitnessTable(sharedContext->differentiableRefInterfaceType, (IRType*)arrayType); + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)arrayType); // And place it in the synthesized witness table. builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffArrayType); @@ -1357,7 +1357,7 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( { SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType)); - table = builder->createWitnessTable(sharedContext->differentiableRefInterfaceType, (IRType*)inTupleType); + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType); // And place it in the synthesized witness table. builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffTupleType); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index d85271ff31..ce64ba6a48 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -117,7 +117,7 @@ struct AutoDiffSharedContext // A reference to the builtin IDifferentiablePtrType interface type. - IRInterfaceType* differentiableRefInterfaceType = nullptr; + IRInterfaceType* differentiablePtrInterfaceType = nullptr; // The struct key for the 'Differential' associated type // defined inside IDifferentialPtrType. We use this to lookup the differential @@ -191,19 +191,19 @@ struct AutoDiffSharedContext IRStructKey* findDifferentialPtrTypeStructKey() { return cast( - getInterfaceEntryAtIndex(differentiableRefInterfaceType, 0)->getRequirementKey()); + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 0)->getRequirementKey()); } IRStructKey* findDifferentialPtrTypeWitnessStructKey() { return cast( - getInterfaceEntryAtIndex(differentiableRefInterfaceType, 1)->getRequirementKey()); + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementKey()); } IRWitnessTableType* findDifferentialPtrTypeWitnessTableType() { return cast( - getInterfaceEntryAtIndex(differentiableRefInterfaceType, 1)->getRequirementVal()); + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementVal()); } //IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); @@ -285,7 +285,7 @@ struct DifferentiableTypeConformanceContext if (isDifferentiableValueType(origType)) return this->sharedContext->differentiableInterfaceType; else if (isDifferentiablePtrType(origType)) - return this->sharedContext->differentiableRefInterfaceType; + return this->sharedContext->differentiablePtrInterfaceType; else return nullptr; } From b12dbfc9d02a8cf3cd38d74079a967fc7c4325f8 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 18 Sep 2024 17:58:20 -0400 Subject: [PATCH 14/14] Move pair logic to ir-builder, unify the type dictionaries. --- source/slang/slang-ir-autodiff-fwd.cpp | 104 +++---------------------- source/slang/slang-ir-autodiff.cpp | 70 +++++++---------- source/slang/slang-ir-autodiff.h | 5 +- source/slang/slang-ir-insts.h | 10 ++- source/slang/slang-ir.cpp | 88 ++++++++++++++++++++- 5 files changed, 132 insertions(+), 145 deletions(-) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 4a6ee5f7f8..609bcd8a33 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -16,86 +16,6 @@ namespace Slang { - -IRInst* emitMakeDifferentialPair(IRBuilder* builder, IRType* pairType, IRInst* primalVal, IRInst* diffVal) -{ - if (as(pairType)) - { - return builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); - } - else if (as(pairType)) - { - // Quick optimization: - // If primalVal and diffVal are extracted from the same pointer-pair, - // we can just use the pointer-pair directly. - // - if (auto primalPtrVal = as(primalVal)) - { - if (auto diffPtrVal = as(diffVal)) - { - if (primalPtrVal->getBase() == diffPtrVal->getBase()) - return primalPtrVal->getBase(); - } - } - return builder->emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); - } - else - { - SLANG_ASSERT(!"unreachable"); - return nullptr; - } -} - -IRInst* emitDifferentialPairGetDifferential(IRBuilder* builder, IRType* diffType, IRInst* pairVal) -{ - if (as(pairVal->getDataType())) - { - return builder->emitDifferentialPairGetDifferential(diffType, pairVal); - } - else if (as(pairVal->getDataType())) - { - return builder->emitDifferentialPtrPairGetDifferential(diffType, pairVal); - } - else - { - SLANG_ASSERT(!"unreachable"); - return nullptr; - } -} - -IRInst* emitDifferentialPairGetPrimal(IRBuilder* builder, IRInst* pairVal) -{ - if (as(pairVal->getDataType())) - { - return builder->emitDifferentialPairGetPrimal(pairVal); - } - else if (as(pairVal->getDataType())) - { - return builder->emitDifferentialPtrPairGetPrimal(pairVal); - } - else - { - SLANG_ASSERT(!"unreachable"); - return nullptr; - } -} - -IRInst* emitDifferentialPairGetPrimal(IRBuilder* builder, IRType* primalType, IRInst* pairVal) -{ - if (as(pairVal->getDataType())) - { - return builder->emitDifferentialPairGetPrimal(primalType, pairVal); - } - else if (as(pairVal->getDataType())) - { - return builder->emitDifferentialPtrPairGetPrimal(primalType, pairVal); - } - else - { - SLANG_ASSERT(!"unreachable"); - return nullptr; - } -} IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) { @@ -822,7 +742,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig diffArgVal = argBuilder.emitLoad(diffArg); markDiffTypeInst(&argBuilder, diffArgVal, pairValType->getValueType()); } - auto initVal = emitMakeDifferentialPair(&argBuilder, pairValType, primalVal, diffArgVal); + auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal); markDiffPairTypeInst(&argBuilder, initVal, pairValType); auto store = argBuilder.emitStore(srcVar, initVal); markDiffPairTypeInst(&argBuilder, store, pairValType); @@ -832,12 +752,12 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig // Read back new value. auto newVal = afterBuilder.emitLoad(srcVar); markDiffPairTypeInst(&afterBuilder, newVal, pairValType); - auto newPrimalVal = emitDifferentialPairGetPrimal(&afterBuilder, pairValType->getValueType(), newVal); + auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal); afterBuilder.emitStore(primalArg, newPrimalVal); if (diffArg) { - auto newDiffVal = emitDifferentialPairGetDifferential(&afterBuilder, (IRType*)diffType, newVal); + auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal); markDiffTypeInst(&afterBuilder, newDiffVal, pairValType->getValueType()); auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal); @@ -856,7 +776,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig // If a pair type can be formed, this must be non-null. SLANG_RELEASE_ASSERT(diffArg); - auto diffPair = emitMakeDifferentialPair(&argBuilder, pairType, primalArg, diffArg); + auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg); markDiffPairTypeInst(&argBuilder, diffPair, pairType); args.add(diffPair); @@ -891,9 +811,9 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (as(diffReturnType) || as(diffReturnType)) { - IRInst* primalResultValue = emitDifferentialPairGetPrimal(&afterBuilder, callInst); + IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst); auto diffType = differentiateType(&afterBuilder, origCall->getFullType()); - IRInst* diffResultValue = emitDifferentialPairGetDifferential(&afterBuilder, diffType, callInst); + IRInst* diffResultValue = afterBuilder.emitDifferentialPairGetDifferential(diffType, callInst); return InstPair(primalResultValue, diffResultValue); } else @@ -1860,8 +1780,7 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr auto diffVal = builder.emitLoad(writeBack.value.differential); markDiffTypeInst(&builder, diffVal, primalVal->getFullType()); - valToStore = emitMakeDifferentialPair( - &builder, pairValType, primalVal, diffVal); + valToStore = builder.emitMakeDifferentialPair(pairValType, primalVal, diffVal); markDiffPairTypeInst(&builder, valToStore, pairValType); } @@ -2153,9 +2072,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam if (as(diffPairType) || as(diffPairType)) { return InstPair( - emitDifferentialPairGetPrimal(builder, diffPairParam), - emitDifferentialPairGetDifferential( - builder, + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( builder, as(diffPairType)), @@ -2183,8 +2101,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam auto initVal = builder->emitLoad(diffPairParam); markDiffPairTypeInst(builder, initVal, ptrInnerPairType); - primalInitVal = emitDifferentialPairGetPrimal(builder, initVal); - diffInitVal = emitDifferentialPairGetDifferential(builder, diffType, initVal); + primalInitVal = builder->emitDifferentialPairGetPrimal(initVal); + diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal); } markDiffTypeInst(builder, diffInitVal, ptrInnerPairType->getValueType()); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 3e896ed251..ddef011774 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -548,9 +548,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) diffInterfaceType == sharedContext->differentiableInterfaceType || diffInterfaceType == sharedContext->differentiablePtrInterfaceType); - //lookUpConformanceForType(item->getConcreteType()); - // TODO: need to consider ref type. - auto existingItem = differentiableValueTypeWitnessDictionary.tryGetValue(item->getConcreteType()); + auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType()); if (existingItem) { *existingItem = item->getWitness(); @@ -629,25 +627,22 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type, DiffConformanceKind kind) { IRInst* foundResult = nullptr; - - switch (kind) - { - case DiffConformanceKind::Any: + differentiableTypeWitnessDictionary.tryGetValue(type, foundResult); + if (!foundResult) + return nullptr; + + if (kind == DiffConformanceKind::Any) + return foundResult; + + if (auto baseType = getConformanceTypeFromWitness(foundResult)) { - differentiableValueTypeWitnessDictionary.tryGetValue(type, foundResult); - if (!foundResult) - differentiablePtrTypeWitnessDictionary.tryGetValue(type, foundResult); - break; - } - case DiffConformanceKind::Value: - differentiableValueTypeWitnessDictionary.tryGetValue(type, foundResult); - break; - case DiffConformanceKind::Ptr: - differentiablePtrTypeWitnessDictionary.tryGetValue(type, foundResult); - break; + if (baseType == sharedContext->differentiableInterfaceType && kind == DiffConformanceKind::Value) + return foundResult; + else if (baseType == sharedContext->differentiablePtrInterfaceType && kind == DiffConformanceKind::Ptr) + return foundResult; } - return foundResult; + return nullptr; } IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType) @@ -687,22 +682,16 @@ IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBui void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRInst* witness) { - //auto witnessType = cast(witness->getDataType()); auto conformanceType = getConformanceTypeFromWitness(witness); - if (sharedContext->isInterfaceAvailable && - conformanceType == sharedContext->differentiableInterfaceType) - { - differentiableValueTypeWitnessDictionary.addIfNotExists(type, witness); - } - else if (sharedContext->isPtrInterfaceAvailable && - conformanceType == sharedContext->differentiablePtrInterfaceType) - { - differentiablePtrTypeWitnessDictionary.addIfNotExists(type, witness); - } - else - { - SLANG_UNEXPECTED("Unexpected witness type"); - } + + if (!sharedContext->isInterfaceAvailable && !sharedContext->isPtrInterfaceAvailable) + return; + + SLANG_ASSERT( + conformanceType == sharedContext->differentiableInterfaceType || + conformanceType == sharedContext->differentiablePtrInterfaceType); + + differentiableTypeWitnessDictionary.addIfNotExists(type, witness); } IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable) @@ -1572,20 +1561,13 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* if (context.isDifferentiableType((IRType*)typeInst)) return true; + // Look for equivalent types. - for (auto type : context.differentiableValueTypeWitnessDictionary) - { - if (isTypeEqual(type.key, (IRType*)typeInst)) - { - context.differentiableValueTypeWitnessDictionary[(IRType*)typeInst] = type.value; - return true; - } - } - for (auto type : context.differentiablePtrTypeWitnessDictionary) + for (auto type : context.differentiableTypeWitnessDictionary) { if (isTypeEqual(type.key, (IRType*)typeInst)) { - context.differentiablePtrTypeWitnessDictionary[(IRType*)typeInst] = type.value; + context.differentiableTypeWitnessDictionary[(IRType*)typeInst] = type.value; return true; } } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index ce64ba6a48..ad2486aad4 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -216,8 +216,7 @@ struct DifferentiableTypeConformanceContext AutoDiffSharedContext* sharedContext; IRGlobalValueWithCode* parentFunc = nullptr; - OrderedDictionary differentiableValueTypeWitnessDictionary; - OrderedDictionary differentiablePtrTypeWitnessDictionary; + OrderedDictionary differentiableTypeWitnessDictionary; IRFunc* existentialDAddFunc = nullptr; @@ -226,7 +225,7 @@ struct DifferentiableTypeConformanceContext { // Populate dictionary with null differential type. if (sharedContext->nullDifferentialStructType) - differentiableValueTypeWitnessDictionary.add( + differentiableTypeWitnessDictionary.add( sharedContext->nullDifferentialStructType, sharedContext->nullDifferentialWitness); } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 0b0c810f70..53bb4215bf 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3803,8 +3803,9 @@ struct IRBuilder IRInst* emitGetTorchCudaStream(); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); - IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential); IRInst* emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); @@ -3986,12 +3987,19 @@ struct IRBuilder IRInst* emitGetOptionalValue(IRInst* optValue); IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value); IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); + IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialValuePairGetPrimal(IRInst* diffPair); IRInst* emitDifferentialPtrPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair); IRInst* emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair); IRInst* emitMakeVector( diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index da0c584203..1f6b8b211d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3514,7 +3514,7 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) + IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as(type)); SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); @@ -3537,6 +3537,86 @@ namespace Slang addInst(inst); return inst; } + + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal) + { + if (as(pairType)) + { + return emitMakeDifferentialValuePair(pairType, primalVal, diffVal); + } + else if (as(pairType)) + { + // Quick optimization: + // If primalVal and diffVal are extracted from the same pointer-pair, + // we can just use the pointer-pair directly. + // + if (auto primalPtrVal = as(primalVal)) + { + if (auto diffPtrVal = as(diffVal)) + { + if (primalPtrVal->getBase() == diffPtrVal->getBase()) + return primalPtrVal->getBase(); + } + } + return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetDifferential(diffType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetDifferential(diffType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal) + { + if (as(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(primalType, pairVal); + } + else if (as(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(primalType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) { @@ -4243,7 +4323,7 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args); } - IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as(diffPair->getDataType())); return emitIntrinsicInst( @@ -4264,7 +4344,7 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair) { auto valueType = cast(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( @@ -4274,7 +4354,7 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair) { return emitIntrinsicInst( primalType,