Skip to content

Commit cfb76f5

Browse files
authored
Merge branch 'master' into fix-optix-raytracing-test
2 parents fd92d8a + 9580e31 commit cfb76f5

11 files changed

+376
-15
lines changed

source/slang/diff.meta.slang

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
/// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode
32
/// derivative implementation.
43
/// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing
@@ -80,7 +79,6 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
8079
/// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified
8180
/// name.
8281
///
83-
///
8482
__attributeTarget(FunctionDeclBase)
8583
attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
8684

@@ -2150,17 +2148,19 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin
21502148
{
21512149
return DifferentialPair<T>(
21522150
clamp(dpx.p, dpMin.p, dpMax.p),
2153-
dpx.p < dpMin.p ? dpMin.d : (dpx.p > dpMax.p ? dpMax.d : dpx.d));
2151+
(dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dpx.d : (dpx.p < dpMin.p ? dpMin.d : dpMax.d));
21542152
}
21552153
__generic<T : __BuiltinFloatingPointType>
21562154
[BackwardDifferentiable]
21572155
[PreferRecompute]
21582156
[BackwardDerivativeOf(clamp)]
21592157
void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut)
21602158
{
2161-
dpx = diffPair(dpx.p, dpx.p > dpMin.p && dpx.p < dpMax.p ? dOut : T.dzero());
2162-
dpMin = diffPair(dpMin.p, dpx.p <= dpMin.p ? dOut : T.dzero());
2163-
dpMax = diffPair(dpMax.p, dpx.p >= dpMax.p ? dOut : T.dzero());
2159+
// Propagate the derivative to x if x is within [min, max] (including the boundaries).
2160+
dpx = diffPair(dpx.p, (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dOut : T.dzero());
2161+
// If x is strictly below min or above max, the gradient is instead applied to the clamp bounds
2162+
dpMin = diffPair(dpMin.p, dpx.p < dpMin.p ? dOut : T.dzero());
2163+
dpMax = diffPair(dpMax.p, dpx.p > dpMax.p ? dOut : T.dzero());
21642164
}
21652165
VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp)
21662166

source/slang/slang-emit-hlsl.cpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -1252,23 +1252,30 @@ void HLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst)
12521252
{
12531253
case IRConstant::FloatKind::Nan:
12541254
{
1255-
m_writer->emit("(0.0 / 0.0)");
1255+
m_writer->emit("(0.0f / 0.0f)");
12561256
return;
12571257
}
12581258
case IRConstant::FloatKind::PositiveInfinity:
12591259
{
1260-
m_writer->emit("(1.0 / 0.0)");
1260+
m_writer->emit("(1.0f / 0.0f)");
12611261
return;
12621262
}
12631263
case IRConstant::FloatKind::NegativeInfinity:
12641264
{
1265-
m_writer->emit("(-1.0 / 0.0)");
1265+
m_writer->emit("(-1.0f / 0.0f)");
12661266
return;
12671267
}
12681268
default:
1269-
break;
1269+
{
1270+
m_writer->emit(constantInst->value.floatVal);
1271+
// Add 'f' suffix for 32-bit float literals to ensure DXC treats them as float
1272+
if (constantInst->getDataType()->getOp() == kIROp_FloatType)
1273+
{
1274+
m_writer->emit("f");
1275+
}
1276+
return;
1277+
}
12701278
}
1271-
break;
12721279
}
12731280

12741281
default:

source/slang/slang-ir-lower-binding-query.cpp

+112-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,118 @@ struct BindingQueryLoweringContext : public WorkListPass
288288
//
289289
OpaqueValueInfo computeOpaqueValueInfo(IRInst* opaqueValue)
290290
{
291-
if (auto globalParam = as<IRGlobalParam>(opaqueValue))
291+
if (auto getElement = as<IRGetElement>(opaqueValue))
292+
{
293+
IRInst* baseInst = getElement->getBase();
294+
IRInst* indexInst = getElement->getIndex();
295+
296+
IRInst* elementType = getElement->getDataType();
297+
298+
// TODO(JS): This a hack to make this work for arrays of resource type.
299+
// It won't work in the general case as it stands because we would need
300+
// to propogate layout kind types needed at usage sites.
301+
// Without knowing the resource kind that is being processed it's not possible
302+
// to accumulate the calculation.
303+
//
304+
// So presumably we need to request a binding query for a specific resource kind.
305+
// We could do this by making the type of the binding query hold the type.
306+
307+
// We need to add instructions which will work out the binding for the base
308+
OpaqueValueInfo baseInfo = findOrComputeOpaqueValueInfo(baseInst);
309+
310+
// If we couldn't find it we are done
311+
if (baseInfo.registerIndex == nullptr || baseInfo.registerSpace == nullptr)
312+
{
313+
return baseInfo;
314+
}
315+
316+
317+
LayoutResourceKind kind = LayoutResourceKind::None;
318+
Index stride = 1;
319+
320+
if (auto resourceType = as<IRResourceType>(elementType))
321+
{
322+
const auto shape = resourceType->getShape();
323+
324+
switch (shape)
325+
{
326+
case SLANG_TEXTURE_1D:
327+
case SLANG_TEXTURE_2D:
328+
case SLANG_TEXTURE_3D:
329+
case SLANG_TEXTURE_CUBE:
330+
case SLANG_STRUCTURED_BUFFER:
331+
case SLANG_BYTE_ADDRESS_BUFFER:
332+
case SLANG_TEXTURE_BUFFER:
333+
{
334+
const auto access = resourceType->getAccess();
335+
bool isReadOnly = (access == SLANG_RESOURCE_ACCESS_READ);
336+
337+
kind = isReadOnly ? LayoutResourceKind::ShaderResource
338+
: LayoutResourceKind::UnorderedAccess;
339+
break;
340+
}
341+
default:
342+
break;
343+
}
344+
}
345+
else if (as<IRSamplerStateTypeBase>(elementType))
346+
{
347+
kind = LayoutResourceKind::SamplerState;
348+
}
349+
else if (as<IRConstantBufferType>(elementType))
350+
{
351+
kind = LayoutResourceKind::ConstantBuffer;
352+
}
353+
354+
if (kind == LayoutResourceKind::None)
355+
{
356+
// Can't determine the kind
357+
return OpaqueValueInfo();
358+
}
359+
360+
// If the element type has type layout we can try and use that
361+
if (auto layoutDecoration = elementType->findDecoration<IRLayoutDecoration>())
362+
{
363+
// We have to calculate
364+
if (auto elementTypeLayout = as<IRTypeLayout>(layoutDecoration->getLayout()))
365+
{
366+
IRTypeSizeAttr* sizeAttr = elementTypeLayout->findSizeAttr(kind);
367+
sizeAttr = sizeAttr ? sizeAttr
368+
: elementTypeLayout->findSizeAttr(
369+
LayoutResourceKind::DescriptorTableSlot);
370+
371+
if (!sizeAttr)
372+
{
373+
// Couldn't work it out
374+
return OpaqueValueInfo();
375+
}
376+
377+
// TODO(JS): Perhaps we have to do something else if not finite?
378+
stride = sizeAttr->getFiniteSize();
379+
}
380+
}
381+
382+
SLANG_UNUSED(indexInst);
383+
384+
// Okay we need to create an instruction which is
385+
// base + stride * index
386+
387+
IRBuilder builder(module);
388+
389+
builder.setInsertBefore(opaqueValue);
390+
391+
auto calcRegisterInst = builder.emitAdd(
392+
indexType,
393+
builder.emitMul(indexType, builder.getIntValue(indexType, stride), indexInst),
394+
baseInfo.registerIndex);
395+
396+
OpaqueValueInfo finalInfo;
397+
finalInfo.registerIndex = calcRegisterInst;
398+
finalInfo.registerSpace = baseInfo.registerSpace;
399+
400+
return finalInfo;
401+
}
402+
else if (auto globalParam = as<IRGlobalParam>(opaqueValue))
292403
{
293404
// The simple/base case is when we have a global shader
294405
// parameter that has layout information attached.

source/slang/slang-ir-specialize-resources.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,10 @@ bool isIllegalGLSLParameterType(IRType* type)
13591359
return true;
13601360
if (as<IRDynamicResourceType>(type))
13611361
return true;
1362+
if (as<IRHLSLInputPatchType>(type))
1363+
return true;
1364+
if (as<IRHLSLOutputPatchType>(type))
1365+
return true;
13621366
return false;
13631367
}
13641368

source/slang/slang-ir.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -5259,6 +5259,11 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
52595259
SLANG_ASSERT(as<IRIntLit>(index));
52605260
type = (IRType*)tupleType->getOperand(getIntVal(index));
52615261
}
5262+
else if (auto hlslInputPatchType = as<IRHLSLInputPatchType>(valueType))
5263+
{
5264+
type = hlslInputPatchType->getElementType();
5265+
}
5266+
52625267
SLANG_RELEASE_ASSERT(type);
52635268
auto inst = createInst<IRGetElementPtr>(
52645269
this,

tests/autodiff-dstdlib/dstdlib-clamp.slang

+43-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
22
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
33

4-
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
55
RWStructuredBuffer<float> outputBuffer;
66

77
typedef DifferentialPair<float> dpfloat;
@@ -178,4 +178,46 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
178178
outputBuffer[28] = dpmax.d.y; // Expected: 0.0
179179
outputBuffer[29] = dpmax.d.z; // Expected: 0.3
180180
}
181+
182+
// New tests: Forward-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
183+
{
184+
// Lower edge: x exactly = 0
185+
dpfloat dpx = dpfloat(0.0, 0.4);
186+
dpfloat dpmin = dpfloat(0.0, 0.8);
187+
dpfloat dpmax = dpfloat(1.0, 0.5);
188+
dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
189+
outputBuffer[30] = res.d; // Expected: 0.4 (propagated from x)
190+
}
191+
192+
{
193+
// Upper edge: x exactly = 1
194+
dpfloat dpx = dpfloat(1.0, 0.7);
195+
dpfloat dpmin = dpfloat(0.0, 0.8);
196+
dpfloat dpmax = dpfloat(1.0, 0.9);
197+
dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
198+
outputBuffer[31] = res.d; // Expected: 0.7 (propagated from x)
199+
}
200+
201+
// Reverse-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
202+
{
203+
// Lower edge: x exactly = 0
204+
dpfloat dpx = dpfloat(0.0, 0.0);
205+
dpfloat dpmin = dpfloat(0.0, 0.0);
206+
dpfloat dpmax = dpfloat(1.0, 0.0);
207+
bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
208+
outputBuffer[32] = dpx.d; // Expected: 1.0 (propagated from x)
209+
outputBuffer[33] = dpmin.d; // Expected: 0.0
210+
outputBuffer[34] = dpmax.d; // Expected: 0.0
211+
}
212+
213+
{
214+
// Upper edge: x exactly = 1
215+
dpfloat dpx = dpfloat(1.0, 0.0);
216+
dpfloat dpmin = dpfloat(0.0, 0.0);
217+
dpfloat dpmax = dpfloat(1.0, 0.0);
218+
bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
219+
outputBuffer[35] = dpx.d; // Expected: 1.0 (propagated from x)
220+
outputBuffer[36] = dpmin.d; // Expected: 0.0
221+
outputBuffer[37] = dpmax.d; // Expected: 0.0
222+
}
181223
}

tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt

+8
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,11 @@ type: float
2929
0.000000
3030
0.000000
3131
0.300000
32+
0.400000
33+
0.700000
34+
1.000000
35+
0.000000
36+
0.000000
37+
1.000000
38+
0.000000
39+
0.000000

tests/expected-failure-github.txt

-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
tests/language-feature/spirv-asm/imageoperands-warning.slang (vk)
21
tests/language-feature/saturated-cooperation/simple.slang (vk)
32
tests/language-feature/saturated-cooperation/fuse3.slang (vk)
43
tests/language-feature/saturated-cooperation/fuse-product.slang (vk)
54
tests/language-feature/saturated-cooperation/fuse.slang (vk)
6-
tests/bugs/byte-address-buffer-interlocked-add-f32.slang (vk)
75
tests/render/render0.hlsl (mtl)
86
tests/render/multiple-stage-io-locations.slang (mtl)
97
tests/render/nointerpolation.hlsl (mtl)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//DISABLE_TEST:SIMPLE:-target dxil-assembly -entry fragmentMain -profile sm_6_5 -stage fragment -DNV_SHADER_EXTN_SLOT=u0
2+
3+
//TEST:SIMPLE(filecheck=SPIRV):-target spirv-assembly -entry fragmentMain -stage fragment
4+
//TEST:SIMPLE(filecheck=DXIL):-target dxil-assembly -entry fragmentMain -profile sm_6_5 -stage fragment -DNV_SHADER_EXTN_SLOT=u0
5+
//TEST:SIMPLE(filecheck=HLSL):-target hlsl -entry fragmentMain -stage fragment
6+
7+
//DISABLED_TEST:SIMPLE:-target spirv-assembly -entry fragmentMain -stage fragment
8+
//DISABLED_TEST:SIMPLE:-target dxil-assembly -entry fragmentMain -stage fragment
9+
//DISABLED_TEST:SIMPLE:-target hlsl -entry fragmentMain -stage fragment
10+
11+
uniform Texture2D textures[] : register(t2, space10);
12+
uniform SamplerState sampler;
13+
uniform RWStructuredBuffer<uint> outputBuffer;
14+
15+
static Texture2D _getBindlessTexture2d(uint texIdx)
16+
{
17+
return textures[NonUniformResourceIndex(texIdx)];
18+
}
19+
20+
void accumulate(inout uint r, uint u)
21+
{
22+
r = r ^ u;
23+
}
24+
25+
void accumulate(inout uint r, bool b)
26+
{
27+
accumulate(r, uint(b));
28+
}
29+
30+
void accumulate(inout uint r, uint2 u)
31+
{
32+
accumulate(r, u.x);
33+
accumulate(r, u.y);
34+
}
35+
36+
void accumulate(inout uint r, uint3 u)
37+
{
38+
accumulate(r, u.x);
39+
accumulate(r, u.y);
40+
accumulate(r, u.z);
41+
}
42+
43+
void accumulate(inout uint r, TextureFootprint2D f)
44+
{
45+
accumulate(r, f.anchor);
46+
accumulate(r, f.offset);
47+
accumulate(r, f.mask);
48+
accumulate(r, f.lod);
49+
accumulate(r, f.granularity);
50+
accumulate(r, f.isSingleLevel);
51+
}
52+
53+
cbuffer Uniforms
54+
{
55+
uniform float2 coords;
56+
uniform uint granularity;
57+
};
58+
59+
void fragmentMain(
60+
float v : VARYING)
61+
{
62+
uint index = uint(v);
63+
uint r = 0;
64+
65+
accumulate(r, _getBindlessTexture2d(index).queryFootprintCoarse(granularity, sampler, coords));
66+
67+
// SPIRV: Extension "SPV_NV_shader_image_footprint"
68+
// SPIRV: ImageSampleFootprintNV
69+
70+
// DXIL: struct struct.NvShaderExtnStruct
71+
72+
// HLSL: NvFootprintCoarse
73+
74+
outputBuffer[index] = r;
75+
}
76+

tests/hlsl/float-literal-suffix.slang

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//TEST:SIMPLE(filecheck=HLSL):-target hlsl -profile ps_6_6 -entry fragmentMain
2+
//TEST:SIMPLE(filecheck=DXIL):-target dxil -profile ps_6_6 -entry fragmentMain
3+
4+
float4 fragmentMain(float2 uv : TEXCOORD) : SV_Target
5+
{
6+
//HLSL:, ddx({{.*}}1.5f)
7+
//DXIL: = call float @dx.op.unary.f32(i32 {{.*}} ; DerivCoarseX(value)
8+
float val = 1.5;
9+
return float4(1.0, ddx(val), 0.0, 1.0);
10+
}

0 commit comments

Comments
 (0)