Skip to content

Commit d48cd13

Browse files
authored
Correct IR generation for no-diff pointer type (shader-slang#5976)
* Correct IR generation for no-diff pointer type Close shader-slang#5805 There is an issue on checking whether a pointer type parameter is no_diff, we should first check whether this parameter is an Attribute type first, then check the data type. In the back-propagate pass, for the pointer type parameter, we should load this parameter to a temp variable, then pass it to the primal function call. Otherwise, the temp variable will no be initialized, which will cause the following calculation wrong.
1 parent e3b71cf commit d48cd13

5 files changed

+69
-10
lines changed

source/slang/slang-ir-autodiff-rev.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -512,11 +512,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
512512
{
513513
// If primal parameter is mutable, we need to pass in a temp var.
514514
auto tempVar = builder.emitVar(primalParamPtrType->getValueType());
515-
if (primalParamPtrType->getOp() == kIROp_InOutType)
516-
{
517-
// If the primal parameter is inout, we need to set the initial value.
518-
builder.emitStore(tempVar, primalArg);
519-
}
515+
516+
// We also need to setup the initial value of the temp var, otherwise
517+
// the temp var will be uninitialized which could cause undefined behavior
518+
// in the primal function.
519+
builder.emitStore(tempVar, primalArg);
520+
520521
primalArgs.add(tempVar);
521522
}
522523
else

source/slang/slang-ir-autodiff-transcriber-base.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,9 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
565565
// If this is a PtrType (out, inout, etc..), then create diff pair from
566566
// value type and re-apply the appropropriate PtrType wrapper.
567567
//
568+
if (isNoDiffType(originalType))
569+
return nullptr;
570+
568571
if (auto origPtrType = as<IRPtrTypeBase>(originalType))
569572
{
570573
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))

source/slang/slang-ir-autodiff.cpp

+14-5
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,22 @@ static IRInst* _getDiffTypeWitnessFromPairType(
126126

127127
bool isNoDiffType(IRType* paramType)
128128
{
129-
while (auto ptrType = as<IRPtrTypeBase>(paramType))
130-
paramType = ptrType->getValueType();
131-
while (auto attrType = as<IRAttributedType>(paramType))
129+
while (paramType)
132130
{
133-
if (attrType->findAttr<IRNoDiffAttr>())
131+
if (auto attrType = as<IRAttributedType>(paramType))
134132
{
135-
return true;
133+
if (attrType->findAttr<IRNoDiffAttr>())
134+
return true;
135+
136+
paramType = attrType->getBaseType();
137+
}
138+
else if (auto ptrType = as<IRPtrTypeBase>(paramType))
139+
{
140+
paramType = ptrType->getValueType();
141+
}
142+
else
143+
{
144+
return false;
136145
}
137146
}
138147
return false;

tests/autodiff/nodiff-ptr.slang

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
2+
[Differentiable]
3+
float sumOfSquares(float x, float y, no_diff float4* test)
4+
{
5+
return x * x + y * y * (test->x + test->y + test->z);
6+
}
7+
8+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly
9+
10+
//TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0], stride=4)
11+
uniform float* ptr;
12+
13+
//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer
14+
RWStructuredBuffer<float> outputBuffer;
15+
16+
[shader("compute")]
17+
[numthreads(1, 1, 1)]
18+
void computeMain()
19+
{
20+
float4* testPtr = (float4*)ptr;
21+
22+
let result = sumOfSquares(2.0, 3.0, testPtr);
23+
24+
// Use forward differentiation to compute the gradient of the output w.r.t. x only.
25+
let diffX = fwd_diff(sumOfSquares)(diffPair(2.0, 1.0), diffPair(3.0, 0.0), testPtr);
26+
27+
// Create a differentiable pair to pass in the primal value and to receive the gradient.
28+
var dpX = diffPair(2.0);
29+
var dpY = diffPair(3.0);
30+
31+
// Propagate the gradient of the output (1.0f) to the input parameters.
32+
bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0);
33+
34+
outputBuffer[0] = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58
35+
outputBuffer[1] = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4
36+
outputBuffer[2] = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58
37+
outputBuffer[3] = dpX.d; // 2*x = 4
38+
39+
outputBuffer[4] = dpY.d; // 2*y * (1 + 2 +3) = 36
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
type: float
2+
58.000000
3+
4.000000
4+
58.000000
5+
4.000000
6+
36.000000

0 commit comments

Comments
 (0)