Skip to content

Commit 112bca6

Browse files
committed
Merge remote-tracking branch 'official/master' into perf
2 parents a4ca08d + 0101e5a commit 112bca6

9 files changed

+217
-74
lines changed

source/slang/diff.meta.slang

+6-6
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
20742074
{
20752075
return DifferentialPair<T>(
20762076
max(dpx.p, dpy.p),
2077-
dpx.p > dpy.p ? dpx.d : dpy.d
2077+
dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
20782078
);
20792079
}
20802080

@@ -2084,8 +2084,8 @@ __generic<T : __BuiltinFloatingPointType>
20842084
[BackwardDerivativeOf(max)]
20852085
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
20862086
{
2087-
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero());
2088-
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero());
2087+
dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
2088+
dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
20892089
}
20902090

20912091
VECTOR_MATRIX_BINARY_DIFF_IMPL(max)
@@ -2099,7 +2099,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
20992099
{
21002100
return DifferentialPair<T>(
21012101
min(dpx.p, dpy.p),
2102-
dpx.p < dpy.p ? dpx.d : dpy.d
2102+
dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
21032103
);
21042104
}
21052105

@@ -2109,8 +2109,8 @@ __generic<T : __BuiltinFloatingPointType>
21092109
[BackwardDerivativeOf(min)]
21102110
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
21112111
{
2112-
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero());
2113-
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero());
2112+
dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
2113+
dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
21142114
}
21152115

21162116
VECTOR_MATRIX_BINARY_DIFF_IMPL(min)

source/slang/slang-emit-metal.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,15 @@ void MetalSourceEmitter::_emitHLSLTextureType(IRTextureTypeBase* texType)
136136
switch (texType->getAccess())
137137
{
138138
case SLANG_RESOURCE_ACCESS_READ:
139-
m_writer->emit("access::sample");
140-
break;
139+
{
140+
// Metal does not support access::sample for texture buffers, so we need to emit
141+
// access::read instead.
142+
if (texType->GetBaseShape() == SLANG_TEXTURE_BUFFER)
143+
m_writer->emit("access::read");
144+
else
145+
m_writer->emit("access::sample");
146+
break;
147+
}
141148

142149
case SLANG_RESOURCE_ACCESS_WRITE:
143150
m_writer->emit("access::write");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
2+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
typedef DifferentialPair<float> dpfloat;
8+
typedef DifferentialPair<float2> dpfloat2;
9+
10+
[BackwardDifferentiable]
11+
float diffMax(float x, float y)
12+
{
13+
return max(x, y);
14+
}
15+
16+
[BackwardDifferentiable]
17+
float2 diffMax(float2 x, float2 y)
18+
{
19+
return max(x, y);
20+
}
21+
22+
[BackwardDifferentiable]
23+
float diffMin(float x, float y)
24+
{
25+
return min(x, y);
26+
}
27+
28+
[BackwardDifferentiable]
29+
float2 diffMin(float2 x, float2 y)
30+
{
31+
return min(x, y);
32+
}
33+
34+
[numthreads(1, 1, 1)]
35+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
36+
{
37+
// Test max() with x < y
38+
{
39+
dpfloat dpx = dpfloat(2.0, 1.0);
40+
dpfloat dpy = dpfloat(5.0, -2.0);
41+
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
42+
outputBuffer[0] = res.p; // Expect: 5.000000
43+
outputBuffer[1] = res.d; // Expect: -2.000000
44+
}
45+
46+
// Test max() with x == y
47+
{
48+
dpfloat dpx = dpfloat(3.0, 1.0);
49+
dpfloat dpy = dpfloat(3.0, -2.0);
50+
dpfloat res = __fwd_diff(diffMax)(dpx, dpy);
51+
outputBuffer[2] = res.p; // Expect: 3.000000
52+
outputBuffer[3] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
53+
}
54+
55+
// Test min() with x > y
56+
{
57+
dpfloat dpx = dpfloat(5.0, 1.0);
58+
dpfloat dpy = dpfloat(2.0, -2.0);
59+
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
60+
outputBuffer[4] = res.p; // Expect: 2.000000
61+
outputBuffer[5] = res.d; // Expect: -2.000000
62+
}
63+
64+
// Test min() with x == y
65+
{
66+
dpfloat dpx = dpfloat(3.0, 1.0);
67+
dpfloat dpy = dpfloat(3.0, -2.0);
68+
dpfloat res = __fwd_diff(diffMin)(dpx, dpy);
69+
outputBuffer[6] = res.p; // Expect: 3.000000
70+
outputBuffer[7] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0)
71+
}
72+
73+
// Test backward-mode max() with x == y
74+
{
75+
dpfloat dpx = dpfloat(3.0, 0.0);
76+
dpfloat dpy = dpfloat(3.0, 0.0);
77+
__bwd_diff(diffMax)(dpx, dpy, 1.0);
78+
outputBuffer[8] = dpx.d; // Expect: 0.500000 (half of gradient)
79+
outputBuffer[9] = dpy.d; // Expect: 0.500000 (half of gradient)
80+
}
81+
82+
// Test backward-mode min() with x == y
83+
{
84+
dpfloat dpx = dpfloat(3.0, 0.0);
85+
dpfloat dpy = dpfloat(3.0, 0.0);
86+
__bwd_diff(diffMin)(dpx, dpy, 1.0);
87+
outputBuffer[10] = dpx.d; // Expect: 0.500000 (half of gradient)
88+
outputBuffer[11] = dpy.d; // Expect: 0.500000 (half of gradient)
89+
}
90+
91+
// Test vector max() with x == y
92+
{
93+
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
94+
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
95+
dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy);
96+
outputBuffer[12] = res.p[0]; // Expect: 3.000000
97+
outputBuffer[13] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
98+
outputBuffer[14] = res.p[1]; // Expect: 4.000000
99+
outputBuffer[15] = res.d[1]; // Expect: 2.000000
100+
}
101+
102+
// Test vector min() with x == y
103+
{
104+
dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0));
105+
dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0));
106+
dpfloat2 res = __fwd_diff(diffMin)(dpx, dpy);
107+
outputBuffer[16] = res.p[0]; // Expect: 3.000000
108+
outputBuffer[17] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0)
109+
outputBuffer[18] = res.p[1]; // Expect: 2.000000
110+
outputBuffer[19] = res.d[1]; // Expect: -3.000000
111+
}
112+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
type: float
2+
5.000000
3+
-2.000000
4+
3.000000
5+
-0.500000
6+
2.000000
7+
-2.000000
8+
3.000000
9+
-0.500000
10+
0.500000
11+
0.500000
12+
0.500000
13+
0.500000
14+
3.000000
15+
-0.500000
16+
4.000000
17+
2.000000
18+
3.000000
19+
-0.500000
20+
2.000000
21+
-3.000000

tests/autodiff-dstdlib/dstdlib-max.slang

-52
This file was deleted.

tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt

-11
This file was deleted.

tests/metal/test_buffer.slang

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Test that Buffer<T> maps to texture_buffer<uint, access::read> in Metal
2+
3+
//TEST:SIMPLE(filecheck=METAL): -stage compute -entry computeMain -target metal
4+
5+
6+
// METAL: texture_buffer<uint, access::read> inputBuffer_{{.*}}
7+
Buffer<uint> inputBuffer;
8+
9+
RWStructuredBuffer<uint> outputBuffer;
10+
11+
[numthreads(4, 1, 1)]
12+
void computeMain(uint3 dtid : SV_DispatchThreadID)
13+
{
14+
uint idx = dtid.x;
15+
// Load values from the buffer to verify correct access
16+
outputBuffer[idx] = inputBuffer.Load(idx);
17+
}

tools/gfx/metal/metal-device.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -653,8 +653,8 @@ Result DeviceImpl::createTextureView(
653653
MTL::PixelFormat pixelFormat = desc.format == Format::Unknown
654654
? textureImpl->m_pixelFormat
655655
: MetalUtil::translatePixelFormat(desc.format);
656-
NS::Range levelRange(sr.baseArrayLayer, sr.layerCount);
657-
NS::Range sliceRange(sr.mipLevel, sr.mipLevelCount);
656+
NS::Range sliceRange(sr.baseArrayLayer, sr.layerCount);
657+
NS::Range levelRange(sr.mipLevel, sr.mipLevelCount);
658658

659659
viewImpl->m_textureView = NS::TransferPtr(textureImpl->m_texture->newTextureView(
660660
pixelFormat,

tools/render-test/render-test-main.cpp

+50-1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class RenderTestApp
131131
ComPtr<IBuffer> m_vertexBuffer;
132132
ComPtr<IShaderProgram> m_shaderProgram;
133133
ComPtr<IPipeline> m_pipeline;
134+
ComPtr<IShaderTable> m_shaderTable;
134135
ComPtr<ITexture> m_depthBuffer;
135136
ComPtr<ITextureView> m_depthBufferView;
136137
ComPtr<ITexture> m_colorBuffer;
@@ -648,6 +649,7 @@ SlangResult RenderTestApp::initialize(
648649
m_pipeline = device->createRenderPipeline(desc);
649650
}
650651
break;
652+
651653
case Options::ShaderProgramType::GraphicsMeshCompute:
652654
case Options::ShaderProgramType::GraphicsTaskMeshCompute:
653655
{
@@ -660,6 +662,33 @@ SlangResult RenderTestApp::initialize(
660662
desc.depthStencil.format = Format::D32_FLOAT;
661663
m_pipeline = device->createRenderPipeline(desc);
662664
}
665+
break;
666+
667+
case Options::ShaderProgramType::RayTracing:
668+
{
669+
RayTracingPipelineDesc desc;
670+
desc.program = m_shaderProgram;
671+
672+
m_pipeline = device->createRayTracingPipeline(desc);
673+
674+
const char* raygenNames[] = {"raygenMain"};
675+
676+
// We don't define a miss shader for this test. OptiX allows
677+
// passing nullptr to indicate no miss shader, but something in
678+
// slang-rhi assumes that the miss shader always has a name. To
679+
// work around that, use a dummy name.
680+
const char* missNames[] = {"missNull"};
681+
682+
ShaderTableDesc shaderTableDesc = {};
683+
shaderTableDesc.program = m_shaderProgram;
684+
shaderTableDesc.rayGenShaderCount = 1;
685+
shaderTableDesc.rayGenShaderEntryPointNames = raygenNames;
686+
shaderTableDesc.missShaderCount = 1;
687+
shaderTableDesc.missShaderEntryPointNames = missNames;
688+
SLANG_RETURN_ON_FAIL(
689+
device->createShaderTable(shaderTableDesc, m_shaderTable.writeRef()));
690+
}
691+
break;
663692
}
664693
}
665694
// If success must have a pipeline state
@@ -972,6 +1001,25 @@ Result RenderTestApp::update()
9721001
m_options.computeDispatchSize[2]);
9731002
passEncoder->end();
9741003
}
1004+
else if (m_options.shaderType == Options::ShaderProgramType::RayTracing)
1005+
{
1006+
auto rootObject = m_device->createRootShaderObject(m_pipeline);
1007+
applyBinding(rootObject);
1008+
rootObject->finalize();
1009+
1010+
auto passEncoder = encoder->beginRayTracingPass();
1011+
RayTracingState state;
1012+
state.pipeline = static_cast<IRayTracingPipeline*>(m_pipeline.get());
1013+
state.rootObject = rootObject;
1014+
state.shaderTable = m_shaderTable;
1015+
passEncoder->setRayTracingState(state);
1016+
passEncoder->dispatchRays(
1017+
0,
1018+
m_options.computeDispatchSize[0],
1019+
m_options.computeDispatchSize[1],
1020+
m_options.computeDispatchSize[2]);
1021+
passEncoder->end();
1022+
}
9751023
else
9761024
{
9771025
auto rootObject = m_device->createRootShaderObject(m_pipeline);
@@ -1072,7 +1120,8 @@ Result RenderTestApp::update()
10721120
if (m_options.shaderType == Options::ShaderProgramType::Compute ||
10731121
m_options.shaderType == Options::ShaderProgramType::GraphicsCompute ||
10741122
m_options.shaderType == Options::ShaderProgramType::GraphicsMeshCompute ||
1075-
m_options.shaderType == Options::ShaderProgramType::GraphicsTaskMeshCompute)
1123+
m_options.shaderType == Options::ShaderProgramType::GraphicsTaskMeshCompute ||
1124+
m_options.shaderType == Options::ShaderProgramType::RayTracing)
10761125
{
10771126
SLANG_RETURN_ON_FAIL(writeBindingOutput(m_options.outputPath));
10781127
}

0 commit comments

Comments
 (0)