Skip to content

Commit ba594d0

Browse files
author
Tim Foley
authored
IR: Add support for out and inout parameters (shader-slang#289)
These were already being handled a little bit, by lowering an `out T` or `inout T` function parameter in the AST to a function parameter with type `T*` in the IR, and then emiting explicit loads/stores. The HLSL emit logic, however, couldn't tell the difference between an `out` parameter, an `inout`, or a true pointer (if we ever needed to support them...). The intention (not fully implemented) was that we'd use a hierarchy of types rooted at `PtrTypeBase`: - `PtrTypeBase` - `Ptr`: "real" pointers in the C/C++ sense - `OutTypeBase`: pointers used to represent by-reference parameter passing - `OutType`: IR level type for an `out` parameter - `InOutType`: IR level type for an `inout` or `in out` parameter Actually implementing this involved: - Adding a bit more flexibility to the `Session::getPtrType` logic to allow for creating any of the concrete types above - Making the `lower-to-ir` logic create the right type for function parameters (instead of just using `PtrType`) - Making the HLSL emit logic check for the `OutType` and `InOutType` cases rather than just `PtrType` - Changing a bunch of small places in the code so that they use `PtrTypeBase` instead of `PtrType` when they should handle any of the above cases, and also make a few places check for `OutTypeBase` instead of `PtrType` or `PtrTypeBase`, when they are really trying to capture by-reference parameters - Add a test case that uses all of the different cases we care about (without these fixes, this test case generates errors from fxc because of variables being used before being initialized, becaues parameters get declared `out` that should be `inout`). A minor point here is that we are playing a bit fast and loose right now because the IR does not actually enforce any type checks. From the standpoint of the front end, `Ptr<T>`, `Out<T>`, and `InOut<T>` are all unrelated types (each is just a `struct` declared in `core.meta.slang`), but this doesn't really matter because none of these are types our current users are explicitly using. In the IR it makes perfect sense to allow `Out<T>` or `InOut<T>` as the operand of a `load` or `store` instruction (and ditto for `getFieldAddr`, etc.) - there instructions just apply to any `PtrTypeBase`. The place where this potentially gets tricky is whether an `Out<T>` can be used where a `Ptr<T>` is expected, or vice vers (e.g., can I just pass my local variable's pointer directly to an `Out<T>` function parameter? I'm going to ignore these issues for now, since the code currently works for our test case.
1 parent 54bf54b commit ba594d0

9 files changed

+144
-43
lines changed

source/slang/compiler.h

+16-1
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,24 @@ namespace Slang
442442
// Should not be used in front-end code
443443
Type* getIRBasicBlockType();
444444

445-
// Construct pointer types on-demand
445+
// Construct the type `Ptr<valueType>`, where `Ptr`
446+
// is looked up as a builtin type.
446447
RefPtr<PtrType> getPtrType(RefPtr<Type> valueType);
447448

449+
// Construct the type `Out<valueType>`
450+
RefPtr<OutType> getOutType(RefPtr<Type> valueType);
451+
452+
// Construct the type `InOut<valueType>`
453+
RefPtr<InOutType> getInOutType(RefPtr<Type> valueType);
454+
455+
// Construct a pointer type like `Ptr<valueType>`, but where
456+
// the actual type name for the pointer type is given by `ptrTypeName`
457+
RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, char const* ptrTypeName);
458+
459+
// Construct a pointer type like `Ptr<valueType>`, but where
460+
// the generic declaration for the pointer type is `genericDecl`
461+
RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl);
462+
448463
RefPtr<ArrayExpressionType> getArrayType(
449464
Type* elementType,
450465
IntVal* elementCount);

source/slang/core.meta.slang

+10
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ __magic_type(PtrType)
101101
struct Ptr
102102
{};
103103

104+
__generic<T>
105+
__magic_type(OutType)
106+
struct Out
107+
{};
108+
109+
__generic<T>
110+
__magic_type(InOutType)
111+
struct InOut
112+
{};
113+
104114
${{{{
105115

106116

source/slang/core.meta.slang.h

+10
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,16 @@ sb << "__magic_type(PtrType)\n";
102102
sb << "struct Ptr\n";
103103
sb << "{};\n";
104104
sb << "\n";
105+
sb << "__generic<T>\n";
106+
sb << "__magic_type(OutType)\n";
107+
sb << "struct Out\n";
108+
sb << "{};\n";
109+
sb << "\n";
110+
sb << "__generic<T>\n";
111+
sb << "__magic_type(InOutType)\n";
112+
sb << "struct InOut\n";
113+
sb << "{};\n";
114+
sb << "\n";
105115
sb << "";
106116

107117

source/slang/emit.cpp

+8-12
Original file line numberDiff line numberDiff line change
@@ -5946,19 +5946,15 @@ emitDeclImpl(decl, nullptr);
59465946
// encoded as a parameter of pointer type, so
59475947
// we need to decode that here.
59485948
//
5949-
if( auto ptrType = type->As<PtrType>() )
5949+
if( auto outType = type->As<OutType>() )
59505950
{
5951-
// TODO: we need a way to distinguish `out`
5952-
// from `inout`. The easiest way to do
5953-
// that might be to have each be a distinct
5954-
// sub-case of `IRPtrType` - this would also
5955-
// ensure that they can be distinguished from
5956-
// real pointers when the user means to use
5957-
// them.
5958-
59595951
emit("out ");
5960-
5961-
type = ptrType->getValueType();
5952+
type = outType->getValueType();
5953+
}
5954+
else if( auto inOutType = type->As<InOutType>() )
5955+
{
5956+
emit("inout ");
5957+
type = inOutType->getValueType();
59625958
}
59635959

59645960
emitIRType(ctx, type, name);
@@ -6595,7 +6591,7 @@ emitDeclImpl(decl, nullptr);
65956591
{
65966592
emitIRUsedType(ctx, genericType->elementType);
65976593
}
6598-
else if( auto ptrType = type->As<PtrType>() )
6594+
else if( auto ptrType = type->As<PtrTypeBase>() )
65996595
{
66006596
emitIRUsedType(ctx, ptrType->getValueType());
66016597
}

source/slang/ir.cpp

+5-7
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ namespace Slang
956956
IRInst* IRBuilder::emitLoad(
957957
IRValue* ptr)
958958
{
959-
auto ptrType = ptr->getType()->As<PtrType>();
959+
auto ptrType = ptr->getType()->As<PtrTypeBase>();
960960
if( !ptrType )
961961
{
962962
// Bad!
@@ -2849,12 +2849,10 @@ namespace Slang
28492849
builder.curBlock = firstBlock;
28502850
builder.insertBeforeInst = firstBlock->getFirstInst();
28512851

2852-
// TODO: We need to distinguish any true pointers in the
2853-
// user's code from pointers that only exist for
2854-
// parameter-passing. This `PtrType` here should actually
2855-
// be `OutTypeBase`, but I'm not confident that all
2856-
// the other code is handling that correctly...
2857-
if(auto paramPtrType = paramType->As<PtrType>() )
2852+
// Is the parameter type a special pointer type
2853+
// that indicates the parameter is used for `out`
2854+
// or `inout` access?
2855+
if(auto paramPtrType = paramType->As<OutTypeBase>() )
28582856
{
28592857
// Okay, we have the more interesting case here,
28602858
// where the parameter was being passed by reference.

source/slang/lower-to-ir.cpp

+19-19
Original file line numberDiff line numberDiff line change
@@ -819,14 +819,6 @@ IRType* getIntType(
819819
return context->getSession()->getBuiltinType(BaseType::Int);
820820
}
821821

822-
// Get a pointer type to the given element type
823-
RefPtr<PtrType> getPtrType(
824-
IRGenContext* context,
825-
IRType* valueType)
826-
{
827-
return context->getSession()->getPtrType(valueType);
828-
}
829-
830822
RefPtr<IRFuncType> getFuncType(
831823
IRGenContext* context,
832824
UInt paramCount,
@@ -1089,7 +1081,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
10891081
RefPtr<Type> loweredBaseType = loweredBaseVal->getType();
10901082

10911083
if (loweredBaseType->As<PointerLikeType>()
1092-
|| loweredBaseType->As<PtrType>())
1084+
|| loweredBaseType->As<PtrTypeBase>())
10931085
{
10941086
// Note that we do *not* perform an actual `load` operation
10951087
// here, but rather just use the pointer value to construct
@@ -1461,7 +1453,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
14611453
case LoweredValInfo::Flavor::Ptr:
14621454
return LoweredValInfo::ptr(
14631455
builder->emitElementAddress(
1464-
getPtrType(context, getSimpleType(type)),
1456+
context->getSession()->getPtrType(getSimpleType(type)),
14651457
baseVal.val,
14661458
indexVal));
14671459

@@ -1498,7 +1490,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
14981490
IRValue* irBasePtr = base.val;
14991491
return LoweredValInfo::ptr(
15001492
getBuilder()->emitFieldAddress(
1501-
getPtrType(context, getSimpleType(fieldType)),
1493+
context->getSession()->getPtrType(getSimpleType(fieldType)),
15021494
irBasePtr,
15031495
getBuilder()->getDeclRefVal(field)));
15041496
}
@@ -3114,14 +3106,22 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
31143106
paramTypes.Add(irParamType);
31153107
break;
31163108

3117-
default:
3118-
// The parameter is being used for input/output purposes,
3119-
// so it will lower to an actual parameter with a pointer type.
3120-
//
3121-
// TODO: Is this the best representation we can use?
3109+
// If the parameter is declared `out` or `inout`,
3110+
// then we will represent it with a pointer type in
3111+
// the IR, but we will use a specialized pointer
3112+
// type that encodes the parameter direction information.
3113+
case kParameterDirection_Out:
3114+
paramTypes.Add(
3115+
context->getSession()->getOutType(irParamType));
3116+
break;
3117+
case kParameterDirection_InOut:
3118+
paramTypes.Add(
3119+
context->getSession()->getInOutType(irParamType));
3120+
break;
31223121

3123-
auto irPtrType = getPtrType(context, irParamType);
3124-
paramTypes.Add(irPtrType);
3122+
default:
3123+
SLANG_UNEXPECTED("unknown parameter direction");
3124+
break;
31253125
}
31263126
}
31273127

@@ -3190,7 +3190,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
31903190
//
31913191
// TODO: Is this the best representation we can use?
31923192

3193-
auto irPtrType = irParamType.As<PtrType>();
3193+
auto irPtrType = irParamType.As<PtrTypeBase>();
31943194

31953195
IRParam* irParamPtr = subBuilder->emitParam(irPtrType);
31963196
if(auto paramDecl = paramInfo.decl)

source/slang/syntax.cpp

+25-4
Original file line numberDiff line numberDiff line change
@@ -279,20 +279,41 @@ void Type::accept(IValVisitor* visitor, void* extra)
279279

280280
RefPtr<PtrType> Session::getPtrType(
281281
RefPtr<Type> valueType)
282+
{
283+
return getPtrType(valueType, "PtrType").As<PtrType>();
284+
}
285+
286+
// Construct the type `Out<valueType>`
287+
RefPtr<OutType> Session::getOutType(RefPtr<Type> valueType)
288+
{
289+
return getPtrType(valueType, "OutType").As<OutType>();
290+
}
291+
292+
RefPtr<InOutType> Session::getInOutType(RefPtr<Type> valueType)
293+
{
294+
return getPtrType(valueType, "InOutType").As<InOutType>();
295+
}
296+
297+
RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName)
282298
{
283299
auto genericDecl = findMagicDecl(
284-
this, "PtrType").As<GenericDecl>();
300+
this, ptrTypeName).As<GenericDecl>();
301+
return getPtrType(valueType, genericDecl);
302+
}
303+
304+
RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl)
305+
{
285306
auto typeDecl = genericDecl->inner;
286-
307+
287308
auto substitutions = new GenericSubstitution();
288-
substitutions->genericDecl = genericDecl.Ptr();
309+
substitutions->genericDecl = genericDecl;
289310
substitutions->args.Add(valueType);
290311

291312
auto declRef = DeclRef<Decl>(typeDecl.Ptr(), substitutions);
292313

293314
return DeclRefType::Create(
294315
this,
295-
declRef)->As<PtrType>();
316+
declRef)->As<PtrTypeBase>();
296317
}
297318

298319
RefPtr<ArrayExpressionType> Session::getArrayType(

tests/compute/inout.slang

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(0),glbinding(0),out
3+
4+
// Test that we correctly support both `out`
5+
// and `inout` function parameters.
6+
7+
void testOut(int x, out int y)
8+
{
9+
y = x;
10+
}
11+
12+
void testInOut(int x, in out int y)
13+
{
14+
y = y + x;
15+
}
16+
17+
void testInout(int x, inout int y)
18+
{
19+
y = y + x;
20+
}
21+
22+
int test(int inVal)
23+
{
24+
int x0 = inVal;
25+
int x1;
26+
27+
testOut(x0, x1);
28+
29+
int x2 = x0;
30+
testInOut(x1, x2);
31+
32+
int x3 = x0;
33+
testInout(x2, x3);
34+
35+
return x3;
36+
}
37+
38+
RWStructuredBuffer<int> outputBuffer : register(u0);
39+
40+
[numthreads(4, 1, 1)]
41+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
42+
{
43+
uint tid = dispatchThreadID.x;
44+
int inVal = outputBuffer[tid];
45+
int outVal = test(inVal);
46+
outputBuffer[tid] = outVal;
47+
}
+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
0
2+
3
3+
6
4+
9

0 commit comments

Comments
 (0)