Skip to content

Commit 7ef980f

Browse files
Fix unzipping logic for inout non-diff parameters and adjust tests (shader-slang#4090)
* Fix unzipping logic for inout non-diff parameters and adjust tests + Removed `-g0` from `struct-this-parameter.slang` test. Works correctly with the new unzipping logic. + Removed `-g0` from `was/warped-sampling-1d.slang` test. Works correctly with DX12 & CS_5_1. CS_5_0 appears to run into an FXC compiler bug with detecting infinite loops where there don't appear to be any. * Update slang-ir-autodiff-unzip.h * Update warped-sampling-1d.slang
1 parent 6b30957 commit 7ef980f

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

source/slang/slang-ir-autodiff-unzip.h

+25-3
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,32 @@ struct DiffUnzipPass
372372
}
373373
else
374374
{
375-
// For non differentiable arguments, we can simply pass the argument as is
376-
// if this isn't a `out` parameter, in which case it is removed from propagate call.
377-
if (!as<IROutType>(arg->getDataType()))
375+
if (auto inOutType = as<IRInOutType>(resolvedPrimalFuncType->getParamType(ii)))
376+
{
377+
// For 'inout' parameter we need to create a temp var to hold the value
378+
// before the primal call. This logic is similar to the 'inout' case for differentiable params
379+
// only we don't need to deal with pair types.
380+
//
381+
auto tempPrimalVar = primalBuilder->emitVar(as<IRPtrTypeBase>(arg->getDataType())->getValueType());
382+
383+
auto storeUse = findUniqueStoredVal(cast<IRVar>(arg));
384+
auto storeInst = cast<IRStore>(storeUse->getUser());
385+
auto storedVal = storeInst->getVal();
386+
387+
primalBuilder->emitStore(tempPrimalVar, storedVal);
388+
389+
diffArgs.add(tempPrimalVar);
390+
}
391+
else
392+
{
393+
// For pure 'in' type. Simply re-use the original argument inst.
394+
//
395+
// For 'out' type parameters, it doesn't really matter what we pass in here, since
396+
// the tranposition logic will discard the argument anyway (we'll pass in the old arg,
397+
// just to keep the number of arguments consistent)
398+
//
378399
diffArgs.add(arg);
400+
}
379401
}
380402
}
381403

tests/autodiff/struct-this-parameter.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0
2-
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -g0
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
33

44
//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
55
RWStructuredBuffer<float> outputBuffer;

tests/autodiff/was/warped-sampling-1d.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -profile cs_5_1 -dx12
22

33
//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out,name=endpointDifferentialBuffer
44
RWStructuredBuffer<float> endpointDifferentialBuffer;

0 commit comments

Comments
 (0)