Skip to content

Commit 051ae8a

Browse files
Fix crash during emitCast of attributed type, allow MaxIters to take linktime const. (#5791)
* Fix crash during emitCast of attributed type. * Allow [MaxIters] to take link time constants. --------- Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
1 parent 71e90a7 commit 051ae8a

7 files changed

+93
-14
lines changed

source/slang/slang-ast-modifier.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ class MaxItersAttribute : public Attribute
744744
{
745745
SLANG_AST_CLASS(MaxItersAttribute)
746746

747-
int32_t value = 0;
747+
IntVal* value = 0;
748748
};
749749

750750
// An inferred max iteration count on a loop.

source/slang/slang-check-modifier.cpp

+1-5
Original file line numberDiff line numberDiff line change
@@ -720,11 +720,7 @@ Modifier* SemanticsVisitor::validateAttribute(
720720
}
721721
else
722722
{
723-
auto cint = checkConstantIntVal(attr->args[0]);
724-
if (cint)
725-
{
726-
maxItersAttrs->value = (int32_t)cint->getValue();
727-
}
723+
maxItersAttrs->value = checkLinkTimeConstantIntVal(attr->args[0]);
728724
}
729725
}
730726
else if (const auto userDefAttr = as<UserDefinedAttribute>(attr))

source/slang/slang-ir-insts.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -4637,9 +4637,14 @@ struct IRBuilder
46374637
getIntValue(getIntType(), IRIntegerValue(mode)));
46384638
}
46394639

4640-
void addLoopMaxItersDecoration(IRInst* value, IntegerLiteralValue iters)
4640+
void addLoopMaxItersDecoration(IRInst* value, IRIntegerValue iters)
46414641
{
4642-
addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(getIntType(), iters));
4642+
addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(iters));
4643+
}
4644+
4645+
void addLoopMaxItersDecoration(IRInst* value, IRInst* iters)
4646+
{
4647+
addDecoration(value, kIROp_LoopMaxItersDecoration, iters);
46434648
}
46444649

46454650
void addLoopForceUnrollDecoration(IRInst* value, IntegerLiteralValue iters)

source/slang/slang-ir.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -3890,6 +3890,8 @@ enum class TypeCastStyle
38903890
};
38913891
static TypeCastStyle _getTypeStyleId(IRType* type)
38923892
{
3893+
type = (IRType*)unwrapAttributedType(type);
3894+
38933895
if (auto vectorType = as<IRVectorType>(type))
38943896
{
38953897
return _getTypeStyleId(vectorType->getElementType());

source/slang/slang-lower-to-ir.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -5938,7 +5938,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
59385938

59395939
if (auto maxItersAttr = stmt->findModifier<MaxItersAttribute>())
59405940
{
5941-
getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value);
5941+
auto iters = lowerVal(context, maxItersAttr->value);
5942+
getBuilder()->addLoopMaxItersDecoration(inst, getSimpleVal(context, iters));
59425943
}
59435944
else if (auto inferredMaxItersAttr = stmt->findModifier<InferredMaxItersAttribute>())
59445945
{
@@ -6028,12 +6029,15 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
60286029
{
60296030
if (auto maxIters = stmt->findModifier<MaxItersAttribute>())
60306031
{
6031-
if (inferredMaxIters->value < maxIters->value)
6032+
if (auto constIntVal = as<ConstantIntVal>(maxIters->value))
60326033
{
6033-
context->getSink()->diagnose(
6034-
maxIters,
6035-
Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters,
6036-
inferredMaxIters->value);
6034+
if (inferredMaxIters->value < constIntVal->getValue())
6035+
{
6036+
context->getSink()->diagnose(
6037+
maxIters,
6038+
Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters,
6039+
inferredMaxIters->value);
6040+
}
60376041
}
60386042
}
60396043
}

tests/bugs/gh-5781.slang

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target spirv
2+
// CHECK: OpEntryPoint
3+
4+
module test;
5+
6+
public enum class MaterialID : uint { invalid = 0xffffffff };
7+
8+
public struct Material : IDifferentiable
9+
{
10+
float x;
11+
}
12+
13+
public struct Hit
14+
{
15+
MaterialID material;
16+
}
17+
18+
public struct Scene
19+
{
20+
StructuredBuffer<Material> materials;
21+
RWStructuredBuffer<Material> grads;
22+
23+
[Differentiable]
24+
Material load(MaterialID id) { return materials[uint(id)]; }
25+
26+
void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; }
27+
28+
[Differentiable, BackwardDerivative(_get_material_bwd)]
29+
public Material get_material(MaterialID id) { return load(id); }
30+
31+
public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); }
32+
33+
[Differentiable]
34+
public Material get_material(Hit hit) { return get_material(hit.material); }
35+
}
36+
37+
[Differentiable]
38+
float trace(const Scene scene, Hit hit)
39+
{
40+
Material m = scene.get_material(hit);
41+
return m.x;
42+
}
43+
44+
45+
[shader("compute")]
46+
void main(
47+
uniform Scene scene,
48+
uniform StructuredBuffer<uint> input,
49+
uniform RWStructuredBuffer<float> output,
50+
uniform RWStructuredBuffer<float> grads
51+
)
52+
{
53+
Hit hit;
54+
hit.material = MaterialID(input[0]);
55+
output[0] = trace(scene, hit);
56+
bwd_diff(trace)(scene, hit, grads[0]);
57+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target spirv
2+
// CHECK: OpEntryPoint
3+
4+
extern static const int num = 10;
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
[numthreads(1,1,1)]
8+
void computeMain()
9+
{
10+
[MaxIters(num)]
11+
for (int i = 0; i < num; i++)
12+
{
13+
outputBuffer[i] = i;
14+
}
15+
}

0 commit comments

Comments
 (0)