Skip to content

Commit

Permalink
Hack handling of primal insts that has a function type. (#2728)
Browse files Browse the repository at this point in the history
* Update diff-bwd material test

* Minor update

* Hack handling of primal insts that has a function type.

---------

Co-authored-by: winmad <winmad.wlf@gmail.com>
Co-authored-by: Yong He <yhe@nvidia.com>
  • Loading branch information
3 people authored Mar 24, 2023
1 parent 50e7d97 commit 6e4eae1
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 6 deletions.
7 changes: 5 additions & 2 deletions source/slang/slang-ir-autodiff-transcriber-base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1116,9 +1116,12 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
}
else
{
if (!pair.primal->findDecoration<IRAutodiffInstDecoration>()
&& !as<IRConstant>(pair.differential))
if (!pair.primal->findDecoration<IRAutodiffInstDecoration>())
{
if (as<IRConstant>(pair.differential))
break;
if (as<IRType>(pair.differential))
break;
auto mixedType = (IRType*)(pair.primal->getDataType());
builder->markInstAsMixedDifferential(pair.primal, mixedType);
}
Expand Down
27 changes: 25 additions & 2 deletions source/slang/slang-ir-autodiff-transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -950,16 +950,32 @@ struct DiffTransposePass
// Slang doesn't support function values. So if we see a func-typed inst
// it's proabably a reference to a function.
//
if (as<IRFuncType>(child->getDataType()))
switch (child->getOp())
{
/*
TODO: need a better way to move specialize, lookupwitness, extractExistentialType/Value/Witness
insts to a proper location that dominates all their use sites. Create copies of these insts
when necessary.
case kIROp_Specialize:
case kIROp_LookupWitness:
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialWitnessTable:
*/
case kIROp_ForwardDifferentiate:
case kIROp_BackwardDifferentiate:
case kIROp_BackwardDifferentiatePrimal:
case kIROp_BackwardDifferentiatePropagate:
typeInsts.add(child);
break;
}
}

for (auto inst : typeInsts)
{
inst->insertAtEnd(revBlock);
}


// Then, go backwards through the regular instructions, and transpose them into the new
// rev block.
// Note the 'reverse' traversal here.
Expand Down Expand Up @@ -2221,6 +2237,10 @@ struct DiffTransposePass
case kIROp_ifElse:
case kIROp_loop:
case kIROp_Switch:
case kIROp_LookupWitness:
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialWitnessTable:
{
// Ignore. transposeBlock() should take care of adding the
// appropriate branch instruction.
Expand Down Expand Up @@ -3474,6 +3494,9 @@ struct DiffTransposePass
List<IRUse*> primalUsesToHoist;

Dictionary<IRStore*, IRBlock*> mapStoreToDefBlock;

IRCloneEnv typeInstCloneEnv = {};

};


Expand Down
10 changes: 9 additions & 1 deletion tests/autodiff/material/GlossyMaterialInstance.slang
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,26 @@ struct GlossyBSDF : IBSDF
{
float3 albedo;

[BackwardDifferentiable]
float3 getAlbedo()
{
return albedo;
}

[ForwardDerivativeOf(getAlbedo)]
[TreatAsDifferentiable]
DifferentialPair<float3> __fwd_d_getAlbedo()
{
return diffPair(albedo, float3(1.f));
}

[BackwardDerivativeOf(getAlbedo)]
[TreatAsDifferentiable]
void __bwd_d_getAlbedo(float3 dOut)
{
[unroll]
for (int j = 0; j < 3; j++) outputBuffer[j + 3] += dOut[j];
}

[BackwardDifferentiable]
float3 eval(const float3 wiLocal, const float3 woLocal)
{
Expand Down
3 changes: 3 additions & 0 deletions tests/autodiff/material/IBSDF.slang
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//TEST_IGNORE_FILE:

//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

interface IBSDF
{
[BackwardDifferentiable]
Expand Down
33 changes: 33 additions & 0 deletions tests/autodiff/material/diff-bwd-falcor-material-system.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type

// outputBuffer is defined in IBSDF.slang
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer

import MaterialSystem;
import DiffuseMaterial;
import DiffuseMaterialInstance;
import GlossyMaterial;
import GlossyMaterialInstance;

//TEST_INPUT: type_conformance DiffuseMaterial:IMaterial = 0
//TEST_INPUT: type_conformance GlossyMaterial:IMaterial = 1

[BackwardDifferentiable]
float3 evalBSDF(int type)
{
float3 wi = normalize(float3(0.5, 0.2, 0.8));
float3 wo = normalize(float3(-0.1, -0.3, 0.9));

IMaterial material = createMaterialClassConformance(type, float3(0.9f, 0.6f, 0.2f));
let mi = material.setupMaterialInstance();
float3 f = mi.eval(wi, wo);
return f;
}

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
__bwd_diff(evalBSDF)(0, float3(1.f));
__bwd_diff(evalBSDF)(1, float3(1.f));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
type: float
0.0
0.0
0.0
1.8
1.2
0.4
2 changes: 1 addition & 1 deletion tests/autodiff/material/diff-falcor-material-system.slang
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type

// outputBuffer is defined in IBSDF.slang
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

import MaterialSystem;
import DiffuseMaterial;
Expand Down

0 comments on commit 6e4eae1

Please sign in to comment.