Skip to content

Commit 0298a04

Browse files
author
Tim Foley
authored
IR: support CompileTimeForStmt (shader-slang#286)
This statement type is a bit of a hack to support loops that *must* be unrolled. The AST-to-AST pass handles them by cloning the AST for the loop body N times, and it was easy enough to do the same thing for the IR: emit the instructions for the body N times. The only thing that requires a bit of care is that we might now see the same variable declarations multiple times, so we need to play it safe and overwrite existing entries in our map from declarations to their IR values (the `declValues.Add(...)` calls become `declValues[...] = ...` assignments). Of course, a better long-term answer would be to do the actual unrolling in the IR. This is especially true because we might someday want to support compile-time/must-unroll loops in functions, where the loop counter comes in as a parameter (but must still be compile-time-constant at every call site).
1 parent 0e3d9ba commit 0298a04

File tree

3 files changed

+128
-9
lines changed

3 files changed

+128
-9
lines changed

source/slang/lower-to-ir.cpp

+35-9
Original file line numberDiff line numberDiff line change
@@ -1637,9 +1637,37 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
16371637
SLANG_UNEXPECTED("`case` or `default` not under `switch`");
16381638
}
16391639

1640-
void visitCompileTimeForStmt(CompileTimeForStmt*)
1640+
void visitCompileTimeForStmt(CompileTimeForStmt* stmt)
16411641
{
1642-
SLANG_UNIMPLEMENTED_X("IR lowering of CompileTimeForStmt");
1642+
// The user is asking us to emit code for the loop
1643+
// body for each value in the given integer range.
1644+
// For now, we will handle this by repeatedly lowering
1645+
// the body statement, with the loop variable bound
1646+
// to a different integer literal value each time.
1647+
//
1648+
// TODO: eventually we might handle this as just an
1649+
// ordinary loop, with an `[unroll]` attribute on
1650+
// it that we would respect.
1651+
1652+
auto rangeBeginVal = GetIntVal(stmt->rangeBeginVal);
1653+
auto rangeEndVal = GetIntVal(stmt->rangeEndVal);
1654+
1655+
if (rangeBeginVal >= rangeEndVal)
1656+
return;
1657+
1658+
auto varDecl = stmt->varDecl;
1659+
auto varType = varDecl->type;
1660+
1661+
for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii)
1662+
{
1663+
auto constVal = getBuilder()->getIntValue(
1664+
varType,
1665+
ii);
1666+
1667+
context->shared->declValues[varDecl] = LoweredValInfo::simple(constVal);
1668+
1669+
lowerStmt(context, stmt->body);
1670+
}
16431671
}
16441672

16451673
// Create a basic block in the current function,
@@ -2590,9 +2618,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
25902618
// A global variable's SSA value is a *pointer* to
25912619
// the underlying storage.
25922620
auto globalVal = LoweredValInfo::ptr(irGlobal);
2593-
context->shared->declValues.Add(
2594-
DeclRef<VarDeclBase>(decl, nullptr),
2595-
globalVal);
2621+
context->shared->declValues[
2622+
DeclRef<VarDeclBase>(decl, nullptr)] = globalVal;
25962623

25972624
if( auto initExpr = decl->initExpr )
25982625
{
@@ -2667,9 +2694,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
26672694
assign(context, varVal, initVal);
26682695
}
26692696

2670-
context->shared->declValues.Add(
2671-
DeclRef<VarDeclBase>(decl, nullptr),
2672-
varVal);
2697+
context->shared->declValues[
2698+
DeclRef<VarDeclBase>(decl, nullptr)] = varVal;
26732699

26742700
return varVal;
26752701
}
@@ -3214,7 +3240,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
32143240
if( auto paramDecl = paramInfo.decl )
32153241
{
32163242
DeclRef<VarDeclBase> paramDeclRef = makeDeclRef(paramDecl);
3217-
subContext->shared->declValues.Add(paramDeclRef, paramVal);
3243+
subContext->shared->declValues[paramDeclRef] = paramVal;
32183244
}
32193245

32203246
if (paramInfo.isThisParam)

tests/compute/compile-time-loop.slang

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//TEST(compute):COMPARE_RENDER_COMPUTE:-xslang -use-ir
2+
3+
//TEST_INPUT: Texture2D(size=4, content = one) : dxbinding(0),glbinding(0)
4+
//TEST_INPUT: Sampler : dxbinding(0),glbinding(0)
5+
6+
//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(1),glbinding(0),out
7+
8+
Texture2D t;
9+
SamplerState s;
10+
RWStructuredBuffer<float> outputBuffer;
11+
12+
cbuffer Uniforms
13+
{
14+
float4x4 modelViewProjection;
15+
}
16+
17+
struct AssembledVertex
18+
{
19+
float3 position;
20+
float3 color;
21+
float2 uv;
22+
};
23+
24+
struct CoarseVertex
25+
{
26+
float3 color;
27+
float2 uv;
28+
};
29+
30+
struct Fragment
31+
{
32+
float4 color;
33+
};
34+
35+
// Vertex Shader
36+
37+
struct VertexStageInput
38+
{
39+
AssembledVertex assembledVertex : A;
40+
};
41+
42+
struct VertexStageOutput
43+
{
44+
CoarseVertex coarseVertex : CoarseVertex;
45+
float4 sv_position : SV_Position;
46+
};
47+
48+
VertexStageOutput vertexMain(VertexStageInput input)
49+
{
50+
VertexStageOutput output;
51+
52+
float3 position = input.assembledVertex.position;
53+
float3 color = input.assembledVertex.color;
54+
55+
output.coarseVertex.color = color;
56+
output.sv_position = mul(modelViewProjection, float4(position, 1.0));
57+
output.coarseVertex.uv = input.assembledVertex.uv;
58+
return output;
59+
}
60+
61+
// Fragment Shader
62+
63+
struct FragmentStageInput
64+
{
65+
CoarseVertex coarseVertex : CoarseVertex;
66+
};
67+
68+
struct FragmentStageOutput
69+
{
70+
Fragment fragment : SV_Target;
71+
};
72+
73+
FragmentStageOutput fragmentMain(FragmentStageInput input)
74+
{
75+
FragmentStageOutput output;
76+
77+
float3 color = input.coarseVertex.color;
78+
float2 uv = input.coarseVertex.uv;
79+
output.fragment.color = float4(color, 1.0);
80+
81+
82+
float4 result = 0;
83+
$for(i in Range(0,5))
84+
{
85+
float4 v = t.Sample(s, uv, int2(i - 2, 0));
86+
result += v;
87+
}
88+
89+
outputBuffer[0] = result.x;
90+
91+
return output;
92+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
40A00000

0 commit comments

Comments
 (0)