-
Notifications
You must be signed in to change notification settings - Fork 242
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Fixed inout struct and added testing for calls to non-differenti…
…able functions (#2505) * Added non-differentiable call test * Extended testing for nondifferentiable calls * Fixed subtle issue with extensions on generic types not applying the correct substitutions, leading to unspecialized generics at the emit stage * More fixes. inout struct params now work fine * Update inout-struct-parameters-jvp.slang * Update slang-ir.cpp * Fixed hoisting lookup_interface_method * Fixed non-diff call return value * Fixed issue with phi nodes * Fixed problem with IRSpecialize preventing hoisitng of DifferentialPairType * Fixed non-diff call test to conform to the new 'no_diff' system
- Loading branch information
1 parent
d58e08f
commit 545de51
Showing
10 changed files
with
309 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type | ||
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type | ||
|
||
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer | ||
RWStructuredBuffer<float> outputBuffer; | ||
|
||
typedef DifferentialPair<float> dpfloat; | ||
|
||
struct A : IDifferentiable | ||
{ | ||
float p; | ||
float3 q; | ||
} | ||
|
||
[ForwardDifferentiable] | ||
void g(A a, inout A aout) | ||
{ | ||
float t = a.p + a.q.y * a.q.x; | ||
aout.p = aout.p + t; | ||
aout.q = aout.q * t; | ||
} | ||
|
||
[numthreads(1, 1, 1)] | ||
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) | ||
{ | ||
float p = 1.0; | ||
float3 q = float3(1.0, 2.0, 3.0); | ||
|
||
float dp = 1.0; | ||
float3 dq = float3(1.0, 0.5, 0.25); | ||
|
||
DifferentialPair<A> dpa = DifferentialPair<A>({p, q}, {dp, dq}); | ||
|
||
__fwd_diff(g)(DifferentialPair<A>( { p, q }, { dp, dq }), dpa); | ||
|
||
outputBuffer[0] = dpa.p.p; // Expect: 4.0 | ||
outputBuffer[1] = dpa.d.q.x; // Expect: 6.5 | ||
outputBuffer[2] = dpa.d.q.y; // Expect: 8.5 | ||
outputBuffer[3] = dpa.d.q.z; // Expect: 11.25 | ||
|
||
} |
5 changes: 5 additions & 0 deletions
5
tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
type: float | ||
4.000000 | ||
6.500000 | ||
8.500000 | ||
11.25000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type | ||
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type | ||
|
||
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer | ||
RWStructuredBuffer<float> outputBuffer; | ||
|
||
typedef DifferentialPair<float> dpfloat; | ||
typedef DifferentialPair<float3> dpfloat3; | ||
|
||
[ForwardDifferentiable] | ||
float f(float x) | ||
{ | ||
return x * x + x * x * x; | ||
} | ||
|
||
[ForwardDifferentiable] | ||
float f2(float x) | ||
{ | ||
return f(x); | ||
} | ||
|
||
float g(float x) | ||
{ | ||
return x * x + x * x * x; | ||
} | ||
|
||
[ForwardDifferentiable] | ||
float g2(float x) | ||
{ | ||
return no_diff(g(x)); | ||
} | ||
|
||
struct A | ||
{ | ||
float o; | ||
|
||
[ForwardDifferentiable] | ||
float doSomethingDifferentiable(float b) | ||
{ | ||
return o + b; | ||
} | ||
|
||
float doSomethingNotDifferentiable(float b) | ||
{ | ||
return o * b; | ||
} | ||
} | ||
|
||
[ForwardDifferentiable] | ||
float h2(A a, float k) | ||
{ | ||
float v = k * k; | ||
return no_diff(a.doSomethingNotDifferentiable(k)) + v; | ||
} | ||
|
||
[numthreads(1, 1, 1)] | ||
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) | ||
{ | ||
{ | ||
outputBuffer[0] = f2(1.0); // Expect: 2.0 | ||
outputBuffer[1] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).d; // Expect: 5.0 | ||
outputBuffer[2] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).p; // Expect: 2.0 | ||
outputBuffer[3] = __fwd_diff(g2)(dpfloat(1.0, 1.0)).d; // Expect: 0.0 | ||
outputBuffer[4] = __fwd_diff(h2)({1.0}, DifferentialPair<float>(1.0, 2.0)).d; // Expect: 4.0 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
type: float | ||
2.000000 | ||
5.000000 | ||
2.000000 | ||
0.000000 | ||
4.000000 |