Skip to content

Commit

Permalink
WIP: Fixed inout struct and added testing for calls to non-differenti…
Browse files Browse the repository at this point in the history
…able functions (#2505)

* Added non-differentiable call test

* Extended testing for nondifferentiable calls

* Fixed subtle issue with extensions on generic types not applying the correct substitutions, leading to unspecialized generics at the emit stage

* More fixes. inout struct params now work fine

* Update inout-struct-parameters-jvp.slang

* Update slang-ir.cpp

* Fixed hoisting lookup_interface_method

* Fixed non-diff call return value

* Fixed issue with phi nodes

* Fixed problem with IRSpecialize preventing hoisitng of DifferentialPairType

* Fixed non-diff call test to conform to the new 'no_diff' system
  • Loading branch information
saipraveenb25 authored Nov 21, 2022
1 parent d58e08f commit 545de51
Show file tree
Hide file tree
Showing 10 changed files with 309 additions and 91 deletions.
7 changes: 7 additions & 0 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6024,6 +6024,13 @@ namespace Slang
// without any additional substitutions.
if (extDecl->targetType->equals(type))
{
/*
auto subst = trySolveConstraintSystem(
&constraints,
DeclRef<Decl>(extGenericDecl, nullptr).as<GenericDecl>(),
as<GenericSubstitution>(as<DeclRefType>(type)->declRef.substitutions.substitutions));
return DeclRef<Decl>(extDecl, subst).as<ExtensionDecl>();
*/
return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef).as<ExtensionDecl>();
}

Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ Result linkAndOptimizeIR(
// perform specialization of functions based on parameter
// values that need to be compile-time constants.
//

dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
if (!codeGenContext->isSpecializationDisabled())
specializeModule(irModule);
Expand Down
222 changes: 139 additions & 83 deletions source/slang/slang-ir-diff-jvp.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions source/slang/slang-ir-diff-jvp.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ namespace Slang

void setFunc(IRGlobalValueWithCode* func);

void buildGlobalWitnessDictionary();

// Lookup a witness table for the concreteType. One should exist if concreteType
// inherits (successfully) from IDifferentiable.
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -2696,6 +2696,8 @@ struct IRBuilder
IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair);
IRInst* emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairAddressPrimal(IRInst* diffPair);
IRInst* emitMakeVector(
IRType* type,
UInt argCount,
Expand Down
49 changes: 41 additions & 8 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3147,6 +3147,9 @@ namespace Slang

IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type));
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)->getValueType() != nullptr);

IRInst* args[] = {primal, differential};
auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>(
this, kIROp_MakeDifferentialPair, type, 2, args);
Expand All @@ -3160,6 +3163,18 @@ namespace Slang
UInt argCount,
IRInst* const* args)
{
auto innerReturnVal = findInnerMostGenericReturnVal(as<IRGeneric>(genericVal));

if (as<IRWitnessTable>(innerReturnVal))
{
return findOrEmitHoistableInst(
type,
kIROp_Specialize,
genericVal,
argCount,
args);
}

auto inst = createInstWithTrailingArgs<IRSpecialize>(
this,
kIROp_Specialize,
Expand All @@ -3186,15 +3201,13 @@ namespace Slang
//
SLANG_ASSERT(witnessTableVal->getOp() != kIROp_StructKey);

auto inst = createInst<IRLookupWitnessMethod>(
this,
kIROp_lookup_interface_method,
type,
witnessTableVal,
interfaceMethodVal);
IRInst* args[] = {witnessTableVal, interfaceMethodVal};

addInst(inst);
return inst;
return findOrEmitHoistableInst(
type,
kIROp_lookup_interface_method,
2,
args);
}

IRInst* IRBuilder::emitGetSequentialIDInst(IRInst* rttiObj)
Expand Down Expand Up @@ -3467,6 +3480,15 @@ namespace Slang
&diffPair);
}

IRInst* IRBuilder::emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair)
{
return emitIntrinsicInst(
diffType,
kIROp_DifferentialPairGetDifferential,
1,
&diffPair);
}

IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
{
auto valueType = as<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
Expand All @@ -3477,6 +3499,17 @@ namespace Slang
&diffPair);
}

IRInst* IRBuilder::emitDifferentialPairAddressPrimal(IRInst* diffPair)
{
auto valueType = as<IRDifferentialPairType>(
as<IRPtrTypeBase>(diffPair->getDataType())->getValueType())->getValueType();
return emitIntrinsicInst(
this->getPtrType(kIROp_PtrType, valueType),
kIROp_DifferentialPairGetPrimal,
1,
&diffPair);
}

IRInst* IRBuilder::emitMakeMatrix(
IRType* type,
UInt argCount,
Expand Down
41 changes: 41 additions & 0 deletions tests/autodiff/inout-struct-parameters-jvp.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

typedef DifferentialPair<float> dpfloat;

struct A : IDifferentiable
{
float p;
float3 q;
}

[ForwardDifferentiable]
void g(A a, inout A aout)
{
float t = a.p + a.q.y * a.q.x;
aout.p = aout.p + t;
aout.q = aout.q * t;
}

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
float p = 1.0;
float3 q = float3(1.0, 2.0, 3.0);

float dp = 1.0;
float3 dq = float3(1.0, 0.5, 0.25);

DifferentialPair<A> dpa = DifferentialPair<A>({p, q}, {dp, dq});

__fwd_diff(g)(DifferentialPair<A>( { p, q }, { dp, dq }), dpa);

outputBuffer[0] = dpa.p.p; // Expect: 4.0
outputBuffer[1] = dpa.d.q.x; // Expect: 6.5
outputBuffer[2] = dpa.d.q.y; // Expect: 8.5
outputBuffer[3] = dpa.d.q.z; // Expect: 11.25

}
5 changes: 5 additions & 0 deletions tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type: float
4.000000
6.500000
8.500000
11.25000
66 changes: 66 additions & 0 deletions tests/autodiff/nondiff-call.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float3> dpfloat3;

[ForwardDifferentiable]
float f(float x)
{
return x * x + x * x * x;
}

[ForwardDifferentiable]
float f2(float x)
{
return f(x);
}

float g(float x)
{
return x * x + x * x * x;
}

[ForwardDifferentiable]
float g2(float x)
{
return no_diff(g(x));
}

struct A
{
float o;

[ForwardDifferentiable]
float doSomethingDifferentiable(float b)
{
return o + b;
}

float doSomethingNotDifferentiable(float b)
{
return o * b;
}
}

[ForwardDifferentiable]
float h2(A a, float k)
{
float v = k * k;
return no_diff(a.doSomethingNotDifferentiable(k)) + v;
}

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
{
outputBuffer[0] = f2(1.0); // Expect: 2.0
outputBuffer[1] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).d; // Expect: 5.0
outputBuffer[2] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).p; // Expect: 2.0
outputBuffer[3] = __fwd_diff(g2)(dpfloat(1.0, 1.0)).d; // Expect: 0.0
outputBuffer[4] = __fwd_diff(h2)({1.0}, DifferentialPair<float>(1.0, 2.0)).d; // Expect: 4.0
}
}
6 changes: 6 additions & 0 deletions tests/autodiff/nondiff-call.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
type: float
2.000000
5.000000
2.000000
0.000000
4.000000

0 comments on commit 545de51

Please sign in to comment.