From 545de51298ddda52ac51ded03ad489c98bdda397 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:29:57 -0500 Subject: [PATCH] WIP: Fixed inout struct and added testing for calls to non-differentiable functions (#2505) * Added non-differentiable call test * Extended testing for nondifferentiable calls * Fixed subtle issue with extensions on generic types not applying the correct substitutions, leading to unspecialized generics at the emit stage * More fixes. inout struct params now work fine * Update inout-struct-parameters-jvp.slang * Update slang-ir.cpp * Fixed hoisting lookup_interface_method * Fixed non-diff call return value * Fixed issue with phi nodes * Fixed problem with IRSpecialize preventing hoisitng of DifferentialPairType * Fixed non-diff call test to conform to the new 'no_diff' system --- source/slang/slang-check-decl.cpp | 7 + source/slang/slang-emit.cpp | 1 + source/slang/slang-ir-diff-jvp.cpp | 222 +++++++++++------- source/slang/slang-ir-diff-jvp.h | 1 + source/slang/slang-ir-insts.h | 2 + source/slang/slang-ir.cpp | 49 +++- .../inout-struct-parameters-jvp.slang | 41 ++++ ...t-struct-parameters-jvp.slang.expected.txt | 5 + tests/autodiff/nondiff-call.slang | 66 ++++++ .../autodiff/nondiff-call.slang.expected.txt | 6 + 10 files changed, 309 insertions(+), 91 deletions(-) create mode 100644 tests/autodiff/inout-struct-parameters-jvp.slang create mode 100644 tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt create mode 100644 tests/autodiff/nondiff-call.slang create mode 100644 tests/autodiff/nondiff-call.slang.expected.txt diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 009d0a9871..5a1218abea 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6024,6 +6024,13 @@ namespace Slang // without any additional substitutions. if (extDecl->targetType->equals(type)) { + /* + auto subst = trySolveConstraintSystem( + &constraints, + DeclRef(extGenericDecl, nullptr).as(), + as(as(type)->declRef.substitutions.substitutions)); + return DeclRef(extDecl, subst).as(); + */ return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef).as(); } diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index cd5f58925e..69ea29c7ad 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -358,6 +358,7 @@ Result linkAndOptimizeIR( // perform specialization of functions based on parameter // values that need to be compile-time constants. // + dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE"); if (!codeGenContext->isSpecializationDisabled()) specializeModule(irModule); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 4ee16aafcb..c9ca687e41 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -92,16 +92,33 @@ struct DifferentialPairTypeBuilder IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) { - auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); - if (baseTypeInfo.isTrivial) + IRInst* pairType = nullptr; + if (auto basePtrType = as(baseInst->getDataType())) { - if (key == globalPrimalKey) - return baseInst; - else - return builder->getDifferentialBottom(); + auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType()); + + // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types, + // especially since we probably don't need diff bottom anymore. + // + SLANG_ASSERT(!baseTypeInfo.isTrivial); + + pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType); + } + else + { + auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); + if (baseTypeInfo.isTrivial) + { + if (key == globalPrimalKey) + return baseInst; + else + return builder->getDifferentialBottom(); + } + + pairType = baseTypeInfo.loweredType; } - if (auto basePairStructType = as(baseTypeInfo.loweredType)) + if (auto basePairStructType = as(pairType)) { return as(builder->emitFieldExtract( findField(basePairStructType, key)->getFieldType(), @@ -109,7 +126,7 @@ struct DifferentialPairTypeBuilder key )); } - else if (auto ptrType = as(baseTypeInfo.loweredType)) + else if (auto ptrType = as(pairType)) { if (auto ptrInnerSpecializedType = as(ptrType->getValueType())) { @@ -135,7 +152,7 @@ struct DifferentialPairTypeBuilder key)); } } - else if (auto specializedType = as(baseTypeInfo.loweredType)) + else if (auto specializedType = as(pairType)) { // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's // type, emit the specialization type. @@ -333,7 +350,9 @@ struct JVPTranscriber JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) - {} + { + + } DiagnosticSink* getSink() { @@ -449,6 +468,17 @@ struct JVPTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } + IRWitnessTable* getDifferentialBottomWitness() + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); + auto result = + as(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + SLANG_ASSERT(result); + return result; + } + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) { @@ -456,23 +486,20 @@ struct JVPTranscriber builder.setInsertInto(inDiffPairType->parent); auto diffPairType = as(inDiffPairType); SLANG_ASSERT(diffPairType); - auto diffType = differentiateType(&builder, diffPairType->getValueType()); - IRInst* tableInst = nullptr; - if (!differentiableTypeConformanceContext.differentiableWitnessDictionary.TryGetValue(diffPairType, tableInst)) - { - IRWitnessTable* table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - // The witness that `diffType` - auto differentialType = builder.getDifferentialPairType( - diffType, - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffType] - .GetValue()); - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - tableInst = table; - } - return as(tableInst); + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = differentiateType(&builder, diffPairType); + + // And place it in the synthesized witness table. + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + // Record this in the context for future lookups + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + + return table; } IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) @@ -490,10 +517,19 @@ struct JVPTranscriber builder.setInsertInto(primalType->parent); auto witness = as( differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - if (!witness && as(primalType)) + + if (!witness) { - witness = getDifferentialPairWitness(primalType); + if (auto primalPairType = as(primalType)) + { + witness = getDifferentialPairWitness(primalPairType); + } + else + { + witness = getDifferentialBottomWitness(); + } } + return builder.getDifferentialPairType( (IRType*)primalType, witness); @@ -630,8 +666,8 @@ struct JVPTranscriber builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); SLANG_ASSERT(diffPairParam); - - if (auto pairType = as(diffPairParam->getDataType())) + + if (auto pairType = as(diffPairType)) { return InstPair( builder->emitDifferentialPairGetPrimal(diffPairParam), @@ -639,16 +675,23 @@ struct JVPTranscriber (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), diffPairParam)); } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); + else if (auto pairPtrType = as(diffPairType)) + { + auto ptrInnerPairType = as(pairPtrType->getValueType()); + + return InstPair( + builder->emitDifferentialPairAddressPrimal(diffPairParam), + builder->emitDifferentialPairAddressDifferential( + builder->getPtrType( + kIROp_PtrType, + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), + diffPairParam)); + } } + return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); + cloneInst(&cloneEnv, builder, origParam), + nullptr); } else { @@ -660,6 +703,7 @@ struct JVPTranscriber } return InstPair(primal, diff); } + } // Returns "d" to use as a name hint for variables and parameters. @@ -784,6 +828,7 @@ struct JVPTranscriber { // Special case load from an `out` param, which will not have corresponding `diff` and // `primal` insts yet. + auto load = builder->emitLoad(primalPtr); auto primalElement = builder->emitDifferentialPairGetPrimal(load); auto diffElement = builder->emitDifferentialPairGetDifferential( @@ -1401,30 +1446,25 @@ struct JVPTranscriber InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) { - // The loop comes with three blocks.. we just need to transcribe each one - // and assemble the new loop instruction. + // IfElse Statements come with 4 blocks. We transcribe each block into it's + // linear form, and then wire them up in the same way as the original if-else - // Transcribe the target block (this is the 'condition' part of the loop, which - // will branch into the loop body). - // Note that for the condition we use the primal inst (condition values should not have a - // differential) + // Transcribe condition block auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); SLANG_ASSERT(primalConditionBlock); - // Transcribe the break block (this is the block after the exiting the loop) + // Transcribe 'true' block (condition block branches into this if true) auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); SLANG_ASSERT(diffTrueBlock); - // Transcribe the continue block (this is the 'update' part of the loop, which will - // branch into the condition block) + // Transcribe 'false' block (condition block branches into this if true) + // TODO (sai): What happens if there's no false block? auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); SLANG_ASSERT(diffFalseBlock); - // Transcribe the continue block (this is the 'update' part of the loop, which will - // branch into the condition block) + // Transcribe 'after' block (true and false blocks branch into this) auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); SLANG_ASSERT(diffAfterBlock); - List diffIfElseArgs; diffIfElseArgs.add(primalConditionBlock); @@ -2462,6 +2502,9 @@ struct JVPDerivativeContext : public InstPassBase sharedBuilder->init(module); sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + // TODO(sai): Move this call. + transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); + IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; @@ -2477,6 +2520,9 @@ struct JVPDerivativeContext : public InstPassBase // modified |= simplifyDifferentialBottomType(builder); + // De-duplicate any remaining types. + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + modified |= processPairTypes(builder, module->getModuleInst()); modified |= eliminateDifferentialBottomType(builder); @@ -2665,7 +2711,13 @@ struct JVPDerivativeContext : public InstPassBase { if (auto getDiffInst = as(inst)) { - if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), nullptr)) + auto pairType = getDiffInst->getBase()->getDataType(); + if (auto pairPtrType = as(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) { builder->setInsertBefore(getDiffInst); IRInst* diffFieldExtract = nullptr; @@ -2677,7 +2729,13 @@ struct JVPDerivativeContext : public InstPassBase } else if (auto getPrimalInst = as(inst)) { - if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), nullptr)) + auto pairType = getPrimalInst->getBase()->getDataType(); + if (auto pairPtrType = as(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) { builder->setInsertBefore(getPrimalInst); @@ -2695,41 +2753,29 @@ struct JVPDerivativeContext : public InstPassBase bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; - // Hoist and deduplicate all pair types to global scope when possible. - // This avoids emitting different struct types for equivalent pair types. + // Hoist all pair types to global scope when possible. auto moduleInst = module->getModuleInst(); - Dictionary diffPairTypes; - for (;;) - { - bool changed = false; - sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); - processInstsOfType(kIROp_DifferentialPairType, [&](IRDifferentialPairType* originalPairType) + processInstsOfType(kIROp_DifferentialPairType, [&](IRInst* originalPairType) + { + if (originalPairType->parent != moduleInst) { - IRInst* finalType = nullptr; - if (diffPairTypes.TryGetValue(originalPairType->getValueType(), finalType)) - { - if (finalType != originalPairType) - { - originalPairType->replaceUsesWith(finalType); - originalPairType->removeAndDeallocate(); - changed = true; - return; - } - } - diffPairTypes[originalPairType->getValueType()] = originalPairType; - if (originalPairType->parent != moduleInst) + originalPairType->removeFromParent(); + ShortList operands; + for (UInt i = 0; i < originalPairType->getOperandCount(); i++) { - if (originalPairType->getValueType()->getParent() != originalPairType->getParent()) - { - originalPairType->insertAfter(originalPairType->getValueType()); - changed = true; - return; - } + operands.add(originalPairType->getOperand(i)); } - }); - if (!changed) - break; - } + auto newPairType = builder->findOrEmitHoistableInst( + originalPairType->getFullType(), + originalPairType->getOp(), + originalPairType->getOperandCount(), + operands.getArrayView().getBuffer()); + originalPairType->replaceUsesWith(newPairType); + originalPairType->removeAndDeallocate(); + } + }); + + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); processAllInsts([&](IRInst* inst) { @@ -3138,4 +3184,14 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b return nullptr; } +void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() +{ + for (auto globalInst : sharedContext->moduleInst->getChildren()) + { + if (auto pairType = as(globalInst)) + { + differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness()); + } + } +} } diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h index 9e0f9cfcc8..5e2a7f44f3 100644 --- a/source/slang/slang-ir-diff-jvp.h +++ b/source/slang/slang-ir-diff-jvp.h @@ -121,6 +121,7 @@ namespace Slang void setFunc(IRGlobalValueWithCode* func); + void buildGlobalWitnessDictionary(); // Lookup a witness table for the concreteType. One should exist if concreteType // inherits (successfully) from IDifferentiable. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fcdeed17ad..4434210c9b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2696,6 +2696,8 @@ struct IRBuilder IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPairAddressPrimal(IRInst* diffPair); IRInst* emitMakeVector( IRType* type, UInt argCount, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index de86a6a52f..c128723201 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3147,6 +3147,9 @@ namespace Slang IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) { + SLANG_RELEASE_ASSERT(as(type)); + SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); + IRInst* args[] = {primal, differential}; auto inst = createInstWithTrailingArgs( this, kIROp_MakeDifferentialPair, type, 2, args); @@ -3160,6 +3163,18 @@ namespace Slang UInt argCount, IRInst* const* args) { + auto innerReturnVal = findInnerMostGenericReturnVal(as(genericVal)); + + if (as(innerReturnVal)) + { + return findOrEmitHoistableInst( + type, + kIROp_Specialize, + genericVal, + argCount, + args); + } + auto inst = createInstWithTrailingArgs( this, kIROp_Specialize, @@ -3186,15 +3201,13 @@ namespace Slang // SLANG_ASSERT(witnessTableVal->getOp() != kIROp_StructKey); - auto inst = createInst( - this, - kIROp_lookup_interface_method, - type, - witnessTableVal, - interfaceMethodVal); + IRInst* args[] = {witnessTableVal, interfaceMethodVal}; - addInst(inst); - return inst; + return findOrEmitHoistableInst( + type, + kIROp_lookup_interface_method, + 2, + args); } IRInst* IRBuilder::emitGetSequentialIDInst(IRInst* rttiObj) @@ -3467,6 +3480,15 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair) + { + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPairGetDifferential, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) { auto valueType = as(diffPair->getDataType())->getValueType(); @@ -3477,6 +3499,17 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPairAddressPrimal(IRInst* diffPair) + { + auto valueType = as( + as(diffPair->getDataType())->getValueType())->getValueType(); + return emitIntrinsicInst( + this->getPtrType(kIROp_PtrType, valueType), + kIROp_DifferentialPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitMakeMatrix( IRType* type, UInt argCount, diff --git a/tests/autodiff/inout-struct-parameters-jvp.slang b/tests/autodiff/inout-struct-parameters-jvp.slang new file mode 100644 index 0000000000..80ff57b7d7 --- /dev/null +++ b/tests/autodiff/inout-struct-parameters-jvp.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; + +struct A : IDifferentiable +{ + float p; + float3 q; +} + +[ForwardDifferentiable] +void g(A a, inout A aout) +{ + float t = a.p + a.q.y * a.q.x; + aout.p = aout.p + t; + aout.q = aout.q * t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float p = 1.0; + float3 q = float3(1.0, 2.0, 3.0); + + float dp = 1.0; + float3 dq = float3(1.0, 0.5, 0.25); + + DifferentialPair dpa = DifferentialPair({p, q}, {dp, dq}); + + __fwd_diff(g)(DifferentialPair( { p, q }, { dp, dq }), dpa); + + outputBuffer[0] = dpa.p.p; // Expect: 4.0 + outputBuffer[1] = dpa.d.q.x; // Expect: 6.5 + outputBuffer[2] = dpa.d.q.y; // Expect: 8.5 + outputBuffer[3] = dpa.d.q.z; // Expect: 11.25 + +} \ No newline at end of file diff --git a/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt b/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt new file mode 100644 index 0000000000..4cc3c313d7 --- /dev/null +++ b/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +4.000000 +6.500000 +8.500000 +11.25000 diff --git a/tests/autodiff/nondiff-call.slang b/tests/autodiff/nondiff-call.slang new file mode 100644 index 0000000000..d62de1b78e --- /dev/null +++ b/tests/autodiff/nondiff-call.slang @@ -0,0 +1,66 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; +typedef DifferentialPair dpfloat3; + +[ForwardDifferentiable] +float f(float x) +{ + return x * x + x * x * x; +} + +[ForwardDifferentiable] +float f2(float x) +{ + return f(x); +} + +float g(float x) +{ + return x * x + x * x * x; +} + +[ForwardDifferentiable] +float g2(float x) +{ + return no_diff(g(x)); +} + +struct A +{ + float o; + + [ForwardDifferentiable] + float doSomethingDifferentiable(float b) + { + return o + b; + } + + float doSomethingNotDifferentiable(float b) + { + return o * b; + } +} + +[ForwardDifferentiable] +float h2(A a, float k) +{ + float v = k * k; + return no_diff(a.doSomethingNotDifferentiable(k)) + v; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + outputBuffer[0] = f2(1.0); // Expect: 2.0 + outputBuffer[1] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).d; // Expect: 5.0 + outputBuffer[2] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).p; // Expect: 2.0 + outputBuffer[3] = __fwd_diff(g2)(dpfloat(1.0, 1.0)).d; // Expect: 0.0 + outputBuffer[4] = __fwd_diff(h2)({1.0}, DifferentialPair(1.0, 2.0)).d; // Expect: 4.0 + } +} diff --git a/tests/autodiff/nondiff-call.slang.expected.txt b/tests/autodiff/nondiff-call.slang.expected.txt new file mode 100644 index 0000000000..8f85913bcc --- /dev/null +++ b/tests/autodiff/nondiff-call.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +5.000000 +2.000000 +0.000000 +4.000000 \ No newline at end of file