Skip to content

Commit 05547e2

Browse files
Warnings function parameters (shader-slang#4626)
* Handle out/inout functions with separate consideration * Fixing bug with passing aliasable instructions * Handle autodiff functions (fwd and rev) in warning system * Handling interface methods * Handling ref parameters like out/inout * Temporary fix to remaining bugs * Refactoring methods and tests * Recursive check for empty structs * Using default initializable interface in tests * Resolving CI fail
1 parent b5174b4 commit 05547e2

19 files changed

+347
-177
lines changed

source/slang/slang-ir-use-uninitialized-values.cpp

+107-23
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ namespace Slang
2828
return false;
2929
}
3030

31-
// Casting to IRUndefined is currently vacuous
32-
// (e.g. any IRInst can be cast to IRUndefined)
33-
static bool isUndefinedValue(IRInst* inst)
31+
static bool isUninitializedValue(IRInst* inst)
3432
{
35-
return (inst->m_op == kIROp_undefined);
33+
// Also consider var since it does not
34+
// automatically mean it will be initialized
35+
// (at least not as the user may have intended)
36+
return (inst->m_op == kIROp_undefined)
37+
|| (inst->m_op == kIROp_Var);
3638
}
3739

3840
static bool isUndefinedParam(IRParam* param)
@@ -98,31 +100,56 @@ namespace Slang
98100
return false;
99101
}
100102

101-
static bool canIgnoreType(IRType* type)
103+
static IRInst* resolveSpecialization(IRSpecialize* spec)
102104
{
105+
IRInst* base = spec->getBase();
106+
IRGeneric* generic = as<IRGeneric>(base);
107+
return findInnerMostGenericReturnVal(generic);
108+
}
109+
110+
// The `upper` field contains the struct that the type is
111+
// is contained in. It is used to check for empty structs.
112+
static bool canIgnoreType(IRType* type, IRType* upper)
113+
{
114+
// In case specialization returns a function instead
115+
if (!type)
116+
return true;
117+
103118
if (as<IRVoidType>(type))
104119
return true;
105120

106121
// For structs, ignore if its empty
107-
if (as<IRStructType>(type))
108-
return (type->getFirstChild() == nullptr);
122+
if (auto str = as<IRStructType>(type))
123+
{
124+
int count = 0;
125+
for (auto field : str->getFields())
126+
{
127+
IRType* ftype = field->getFieldType();
128+
count += !canIgnoreType(ftype, type);
129+
}
130+
131+
return (count == 0);
132+
}
109133

110134
// Nothing to initialize for a pure interface
111135
if (as<IRInterfaceType>(type))
112136
return true;
113137

114138
// For pointers, check the value type (primarily for globals)
115139
if (auto ptr = as<IRPtrType>(type))
116-
return canIgnoreType(ptr->getValueType());
140+
{
141+
// Avoid the recursive step if its a
142+
// recursive structure like a linked list
143+
IRType* ptype = ptr->getValueType();
144+
return (ptype != upper) && canIgnoreType(ptype, upper);
145+
}
117146

118147
// In the case of specializations, check returned type
119148
if (auto spec = as<IRSpecialize>(type))
120149
{
121-
IRInst* base = spec->getBase();
122-
IRGeneric* generic = as<IRGeneric>(base);
123-
IRInst* inner = findInnerMostGenericReturnVal(generic);
150+
IRInst* inner = resolveSpecialization(spec);
124151
IRType* innerType = as<IRType>(inner);
125-
return canIgnoreType(innerType);
152+
return canIgnoreType(innerType, upper);
126153
}
127154

128155
return false;
@@ -146,8 +173,54 @@ namespace Slang
146173

147174
return addresses;
148175
}
176+
177+
static void checkCallUsage(List<IRInst*>& stores, List<IRInst*>& loads, IRCall* call, IRInst* inst)
178+
{
179+
IRInst* callee = call->getCallee();
180+
181+
// Resolve the actual function
182+
IRFunc* ftn = nullptr;
183+
IRFuncType* ftype = nullptr;
184+
if (auto spec = as<IRSpecialize>(callee))
185+
ftn = as<IRFunc>(resolveSpecialization(spec));
186+
else if (auto fwd = as<IRForwardDifferentiate>(callee))
187+
ftn = as<IRFunc>(fwd->getBaseFn());
188+
else if (auto rev = as<IRBackwardDifferentiate>(callee))
189+
ftn = as<IRFunc>(rev->getBaseFn());
190+
else if (auto wit = as<IRLookupWitnessMethod>(callee))
191+
ftype = as<IRFuncType>(callee->getFullType());
192+
else
193+
ftn = as<IRFunc>(callee);
194+
195+
// Find the argument index so we can fetch the type
196+
int index = 0;
197+
198+
auto args = call->getArgsList();
199+
for (int i = 0; i < args.getCount(); i++)
200+
{
201+
if (args[i] == inst)
202+
{
203+
index = i;
204+
break;
205+
}
206+
}
207+
208+
if (ftn)
209+
ftype = as<IRFuncType>(ftn->getFullType());
210+
211+
if (!ftype)
212+
return;
213+
214+
// Consider it as a store if its passed
215+
// as an out/inout/ref parameter
216+
IRType* type = ftype->getParamType(index);
217+
if (as<IROutType>(type) || as<IRInOutType>(type) || as<IRRefType>(type))
218+
stores.add(call);
219+
else
220+
loads.add(call);
221+
}
149222

150-
static void collectLoadStore(List<IRInst*>& stores, List<IRInst*>& loads, IRInst* user)
223+
static void collectLoadStore(List<IRInst*>& stores, List<IRInst*>& loads, IRInst* user, IRInst* inst)
151224
{
152225
// Meta intrinsics (which evaluate on type) do nothing
153226
if (isMetaOp(user))
@@ -163,13 +236,17 @@ namespace Slang
163236
case kIROp_unconditionalBranch:
164237
// TODO: Ignore branches for now
165238
return;
239+
240+
case kIROp_Call:
241+
// Function calls can be either
242+
// stores or loads depending on
243+
// whether the callee takes it
244+
// in as a out parameter or not
245+
return checkCallUsage(stores, loads, as<IRCall>(user), inst);
166246

167247
// These instructions will store data...
168248
case kIROp_Store:
169249
case kIROp_SwizzledStore:
170-
// TODO: for calls, should make check that the
171-
// function is passing as an out param
172-
case kIROp_Call:
173250
case kIROp_SPIRVAsm:
174251
case kIROp_GenericAsm:
175252
// For now assume that __intrinsic_asm blocks will do the right thing...
@@ -187,6 +264,11 @@ namespace Slang
187264
// For specializing generic structs
188265
stores.add(user);
189266
break;
267+
268+
// Miscellaenous cases
269+
case kIROp_ManagedPtrAttach:
270+
stores.add(user);
271+
break;
190272

191273
// ... and the rest will load/use them
192274
default:
@@ -225,7 +307,7 @@ namespace Slang
225307
for (auto use = alias->firstUse; use; use = use->nextUse)
226308
{
227309
IRInst* user = use->getUser();
228-
collectLoadStore(stores, loads, user);
310+
collectLoadStore(stores, loads, user, alias);
229311
}
230312
}
231313

@@ -257,7 +339,7 @@ namespace Slang
257339
for (auto use = alias->firstUse; use; use = use->nextUse)
258340
{
259341
IRInst* user = use->getUser();
260-
collectLoadStore(stores, loads, user);
342+
collectLoadStore(stores, loads, user, alias);
261343
}
262344
}
263345

@@ -297,11 +379,11 @@ namespace Slang
297379
// Check ordinary instructions
298380
for (auto inst = firstBlock->getFirstInst(); inst; inst = inst->getNextInst())
299381
{
300-
if (!isUndefinedValue(inst))
382+
if (!isUninitializedValue(inst))
301383
continue;
302384

303385
IRType* type = inst->getFullType();
304-
if (canIgnoreType(type))
386+
if (canIgnoreType(type, nullptr))
305387
continue;
306388

307389
auto loads = getUnresolvedVariableLoads(reachability, inst);
@@ -317,7 +399,7 @@ namespace Slang
317399
static void checkUninitializedGlobals(IRGlobalVar* variable, DiagnosticSink* sink)
318400
{
319401
IRType* type = variable->getFullType();
320-
if (canIgnoreType(type))
402+
if (canIgnoreType(type, nullptr))
321403
return;
322404

323405
// Check for semantic decorations
@@ -331,7 +413,7 @@ namespace Slang
331413
if (as<IRBlock>(inst))
332414
return;
333415
}
334-
416+
335417
auto addresses = getAliasableInstructions(variable);
336418

337419
List<IRInst*> stores;
@@ -342,12 +424,14 @@ namespace Slang
342424
for (auto use = alias->firstUse; use; use = use->nextUse)
343425
{
344426
IRInst* user = use->getUser();
345-
collectLoadStore(stores, loads, user);
427+
collectLoadStore(stores, loads, user, alias);
346428

347429
// Disregard if there is at least one store,
348430
// since we cannot tell what the control flow is
349431
if (stores.getCount())
350432
return;
433+
434+
// TODO: see if we can do better here (another kind of reachability check?)
351435
}
352436
}
353437

tests/autodiff/material2/MxWeights.slang

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
public struct MxWeights<let TBsdfCount : int>
44
{
55
public float3 weights[TBsdfCount];
6+
7+
public __init()
8+
{
9+
for (int i = 0; i < TBsdfCount; i++)
10+
weights[i] = float3(0.0f);
11+
}
612
}
713

814
public interface IMxLayeredMaterialData

tests/bugs/generic-param-cast.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct A<let I : int>
88
int f() { return I; }
99
};
1010

11-
struct B<let U : uint>
11+
struct B<let U : uint> : IDefaultInitializable
1212
{
1313
A<U> a;
1414
};

tests/bugs/specialize-existential-in-generic.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ struct Impl : IFoo
1818
Assoc getValue() { Assoc r; return r; }
1919
}
2020

21-
struct GenType<T : IFoo>
21+
struct GenType<T : IFoo> : IDefaultInitializable
2222
{
2323
T obj;
2424
int doThing()

tests/compute/assoctype-func-param.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct GenStruct<T> : IBase
3737

3838
U.RetT test<U:IBase>(U.RetT val)
3939
{
40-
U obj;
40+
U obj = U();
4141
U.SubTypeT sb = obj.setVal(val);
4242
return obj.getVal(sb);
4343
}
@@ -50,4 +50,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
5050
/* wrong error message for the following line */
5151
float outVal = test<GenStruct<float> >(inVal);
5252
outputBuffer[tid] = outVal.x;
53-
}
53+
}

tests/compute/assoctype-nested-lookup.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ interface IFoo : IDefaultInitializable
1212
};
1313

1414

15-
struct FooPair<T : IFoo> : IFoo
15+
struct FooPair<T : IFoo> : IFoo, IDefaultInitializable
1616
{
1717
T a;
1818
T.Bar b;
@@ -41,4 +41,4 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
4141
{
4242
FooPair<ConcreteFoo>.Bar pair;
4343
test(pair);
44-
}
44+
}

tests/compute/dynamic-dispatch-17.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct FloatVal : IInterface
5252
float val;
5353
float run<Z:IReturnsZero>()
5454
{
55-
Z z;
55+
Z z = Z();
5656
return val + z.get();
5757
}
5858
};
@@ -62,7 +62,7 @@ struct Float4Val : IInterface
6262
Float4Struct val;
6363
float run<Z:IReturnsZero>()
6464
{
65-
Z z;
65+
Z z = Z();
6666
return val.val.x + val.val.y + z.get();
6767
}
6868
};

tests/compute/empty-struct2.slang

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ struct EmptyS : IEmptyS
2525
struct Empty<TT : IEmptyS> : IInterface
2626
{
2727
typedef TT T;
28-
TT value;
29-
float a;
28+
TT value = TT();
29+
float a = 0;
3030
TT getT()
3131
{
3232
return value;
@@ -51,4 +51,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
5151
Empty<EmptyS> obj;
5252
test(obj);
5353
outputBuffer[dispatchThreadID.x] = dispatchThreadID.x;
54-
}
54+
}

tests/compute/generic-closer.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct Gen0 : IGetter
1313
};
1414
struct Gen1<TGetter : IGetter> : IGetter
1515
{
16-
TGetter g;
16+
TGetter g = TGetter();
1717
int get() { return g.get(); }
1818
};
1919

@@ -39,4 +39,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
3939
int b = 5;
4040
if (a< b && b > a)
4141
outputBuffer[dispatchThreadID.x] = (g.get() >> 1) + g2.get() + g3.get();
42-
}
42+
}

tests/compute/generic-default-arg.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct Impl2 : ITest
3131
__generic<T : ITest = Impl1>
3232
struct GenStruct
3333
{
34-
T obj;
34+
T obj = T();
3535
};
3636

3737
int test(GenStruct gs, int val)
@@ -50,4 +50,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
5050
outVal += test(gs, tid);
5151

5252
outputBuffer[tid] = outVal;
53-
}
53+
}

tests/compute/transitive-interface.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct AssocImpl : IAssoc
4747

4848
int testAdd2<T:IAssoc>(T assoc)
4949
{
50-
T.AT obj;
50+
T.AT obj = T.AT();
5151
return obj.addf(1, 1);
5252
}
5353

@@ -69,4 +69,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
6969
outVal += testSub(s1, outVal);
7070

7171
outputBuffer[dispatchThreadID.x] = outVal;
72-
}
72+
}

tests/cuda/cuda-array-layout.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ struct PadLadenStruct
1010
};
1111

1212
// This is to check if the last half can be inserted 'inside' the spare padding of a. It should not be
13-
struct StructWithArray
13+
struct StructWithArray : IDefaultInitializable
1414
{
1515
PadLadenStruct a[1];
1616
uint8_t b;

0 commit comments

Comments
 (0)