Skip to content

Commit fe77f07

Browse files
Fix non-square matrix derivatives (shader-slang#6282)
1 parent 78a6389 commit fe77f07

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

source/slang/slang-ir-autodiff-transpose.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -1812,8 +1812,8 @@ struct DiffTransposePass
18121812
{
18131813
List<RevGradient> gradients;
18141814
auto matrixType = as<IRMatrixType>(fwdMakeMatrix->getDataType());
1815-
auto row = as<IRIntLit>(matrixType->getRowCount());
18161815
auto colCount = matrixType->getColumnCount();
1816+
auto colCountVal = as<IRIntLit>(matrixType->getColumnCount())->getValue();
18171817
IRType* rowVectorType = nullptr;
18181818
for (UIndex ii = 0; ii < fwdMakeMatrix->getOperandCount(); ii++)
18191819
{
@@ -1828,9 +1828,8 @@ struct DiffTransposePass
18281828
}
18291829
else
18301830
{
1831-
SLANG_RELEASE_ASSERT(row);
1832-
UInt rowIndex = ii / (UInt)row->getValue();
1833-
UInt colIndex = ii % (UInt)row->getValue();
1831+
UInt rowIndex = ii / (UInt)colCountVal;
1832+
UInt colIndex = ii % (UInt)colCountVal;
18341833
if (!rowVectorType)
18351834
rowVectorType = builder->getVectorType(matrixType->getElementType(), colCount);
18361835
auto revRow = builder->emitElementExtract(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
2+
3+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
[Differentiable]
7+
float h_outer(float a, float b, float c)
8+
{
9+
const float3x2 m2 = float3x2(2 * a, 0.0,
10+
0.0, 3 * b,
11+
0.0, 5 * c);
12+
13+
const float3x3 m1 = float3x3(1.f);
14+
return h(m1, m2);
15+
}
16+
17+
[Differentiable]
18+
float h(float3x3 x, float3x2 y)
19+
{
20+
let res = mul(x, y);
21+
return dot(mul(res, float2(1.0)), float3(1.0));
22+
}
23+
24+
[numthreads(1, 1, 1)]
25+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
26+
{
27+
// Do a bwd_diff test for h_outer
28+
var dpa = diffPair(1.f, 0.f);
29+
var dpb = diffPair(1.f, 0.f);
30+
var dpc = diffPair(1.f, 0.f);
31+
bwd_diff(h_outer)(dpa, dpb, dpc, 1.f);
32+
33+
outputBuffer[1] = dpa.d;
34+
outputBuffer[2] = dpb.d;
35+
outputBuffer[3] = dpc.d;
36+
37+
// CHECK: type: float
38+
// CHECK: 6.0
39+
// CHECK: 9.0
40+
// CHECK: 15.0
41+
}

0 commit comments

Comments
 (0)