Skip to content

Commit e7a8332

Browse files
author
Tim Foley
authored
Fix global atomic functions (shader-slang#582)
Fixes shader-slang#581 This change adds a new parameter passing mode `__ref` to exist alongisde `in`, `out`, and `inout`. The `__ref` modifier indicates true by-reference parameter passing (whereas `inout` is copy-in-copy-out). This is not intended to be something that users interact with directly, but rather a low-level feature that lets us provide a correct signature for the `Interlocked*()` operations in the standard library. Most of the support for passing what are logically addresses around already exists in the IR, so the majority of the work here is just in introducing the new type `Ref<T>` and then using it appropriately when lowering `__ref` parameters/arguments to the IR.
1 parent ace9a8d commit e7a8332

21 files changed

+221
-38
lines changed

source/slang/check.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -3011,6 +3011,11 @@ namespace Slang
30113011
// because there is no way for overload resolution to pick between them.
30123012
if (fstParam.getDecl()->HasModifier<OutModifier>() != sndParam.getDecl()->HasModifier<OutModifier>())
30133013
return false;
3014+
3015+
// If one parameter is `ref` and the other isn't, then they don't match.
3016+
//
3017+
if(fstParam.getDecl()->HasModifier<RefModifier>() != sndParam.getDecl()->HasModifier<RefModifier>())
3018+
return false;
30143019
}
30153020

30163021
// Note(tfoley): return type doesn't enter into it, because we can't take
@@ -7046,8 +7051,15 @@ namespace Slang
70467051
for (UInt pp = 0; pp < paramCount; ++pp)
70477052
{
70487053
auto paramType = funcType->getParamType(pp);
7049-
if (auto outParamType = paramType->As<OutTypeBase>())
7054+
if (paramType->As<OutTypeBase>() || paramType->As<RefType>())
70507055
{
7056+
// `out`, `inout`, and `ref` parameters currently require
7057+
// an *exact* match on the type of the argument.
7058+
//
7059+
// TODO: relax this requirement by allowing an argument
7060+
// for an `inout` parameter to be converted in both
7061+
// directions.
7062+
//
70517063
if( pp < expr->Arguments.Count() )
70527064
{
70537065
auto argExpr = expr->Arguments[pp];

source/slang/compiler.h

+3
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,9 @@ namespace Slang
520520
// Construct the type `InOut<valueType>`
521521
RefPtr<InOutType> getInOutType(RefPtr<Type> valueType);
522522

523+
// Construct the type `Ref<valueType>`
524+
RefPtr<RefType> getRefType(RefPtr<Type> valueType);
525+
523526
// Construct a pointer type like `Ptr<valueType>`, but where
524527
// the actual type name for the pointer type is given by `ptrTypeName`
525528
RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, char const* ptrTypeName);

source/slang/core.meta.slang

+6
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ __intrinsic_type($(kIROp_InOutType))
117117
struct InOut
118118
{};
119119

120+
__generic<T>
121+
__magic_type(RefType)
122+
__intrinsic_type($(kIROp_RefType))
123+
struct Ref
124+
{};
125+
120126
__magic_type(StringType)
121127
__intrinsic_type($(kIROp_StringType))
122128
struct String

source/slang/core.meta.slang.h

+9
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ SLANG_RAW(")\n")
126126
SLANG_RAW("struct InOut\n")
127127
SLANG_RAW("{};\n")
128128
SLANG_RAW("\n")
129+
SLANG_RAW("__generic<T>\n")
130+
SLANG_RAW("__magic_type(RefType)\n")
131+
SLANG_RAW("__intrinsic_type(")
132+
SLANG_SPLICE(kIROp_RefType
133+
)
134+
SLANG_RAW(")\n")
135+
SLANG_RAW("struct Ref\n")
136+
SLANG_RAW("{};\n")
137+
SLANG_RAW("\n")
129138
SLANG_RAW("__magic_type(StringType)\n")
130139
SLANG_RAW("__intrinsic_type(")
131140
SLANG_SPLICE(kIROp_StringType

source/slang/diagnostic-defs.h

+2
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ DIAGNOSTIC(40006, Error, needCompileTimeConstant, "expected a compile-time const
311311

312312
DIAGNOSTIC(40007, Internal, irValidationFailed, "IR validation failed: $0")
313313

314+
DIAGNOSTIC(40008, Error, invalidLValueForRefParameter, "the form of this l-value argument is not valid for a `ref` parameter")
315+
314316
// 41000 - IR-level validation issues
315317

316318
DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected")

source/slang/emit.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -4408,6 +4408,13 @@ struct EmitVisitor
44084408
emit("inout ");
44094409
type = inOutType->getValueType();
44104410
}
4411+
else if( auto refType = as<IRRefType>(type))
4412+
{
4413+
// Note: There is no HLSL/GLSL equivalent for by-reference parameters,
4414+
// so we don't actually expect to encounter these in user code.
4415+
emit("inout ");
4416+
type = inOutType->getValueType();
4417+
}
44114418

44124419
emitIRType(ctx, type, name);
44134420
}

source/slang/hlsl.meta.slang

+16-16
Original file line numberDiff line numberDiff line change
@@ -617,32 +617,32 @@ __target_intrinsic(glsl, "groupMemoryBarrier()); (barrier()")
617617
void GroupMemoryBarrierWithGroupSync();
618618

619619
// Atomics
620-
void InterlockedAdd(in out int dest, int value, out int original_value);
621-
void InterlockedAdd(in out uint dest, uint value, out uint original_value);
620+
void InterlockedAdd(__ref int dest, int value, out int original_value);
621+
void InterlockedAdd(__ref uint dest, uint value, out uint original_value);
622622

623-
void InterlockedAnd(in out int dest, int value, out int original_value);
624-
void InterlockedAnd(in out uint dest, uint value, out uint original_value);
623+
void InterlockedAnd(__ref int dest, int value, out int original_value);
624+
void InterlockedAnd(__ref uint dest, uint value, out uint original_value);
625625

626-
void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value);
627-
void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value);
626+
void InterlockedCompareExchange(__ref int dest, int compare_value, int value, out int original_value);
627+
void InterlockedCompareExchange(__ref uint dest, uint compare_value, uint value, out uint original_value);
628628

629-
void InterlockedCompareStore(in out int dest, int compare_value, int value);
630-
void InterlockedCompareStore(in out uint dest, uint compare_value, uint value);
629+
void InterlockedCompareStore(__ref int dest, int compare_value, int value);
630+
void InterlockedCompareStore(__ref uint dest, uint compare_value, uint value);
631631

632-
void InterlockedExchange(in out int dest, int value, out int original_value);
633-
void InterlockedExchange(in out uint dest, uint value, out uint original_value);
632+
void InterlockedExchange(__ref int dest, int value, out int original_value);
633+
void InterlockedExchange(__ref uint dest, uint value, out uint original_value);
634634

635-
void InterlockedMax(in out int dest, int value, out int original_value);
636-
void InterlockedMax(in out uint dest, uint value, out uint original_value);
635+
void InterlockedMax(__ref int dest, int value, out int original_value);
636+
void InterlockedMax(__ref uint dest, uint value, out uint original_value);
637637

638638
void InterlockedMin(in out int dest, int value, out int original_value);
639639
void InterlockedMin(in out uint dest, uint value, out uint original_value);
640640

641-
void InterlockedOr(in out int dest, int value, out int original_value);
642-
void InterlockedOr(in out uint dest, uint value, out uint original_value);
641+
void InterlockedOr(__ref int dest, int value, out int original_value);
642+
void InterlockedOr(__ref uint dest, uint value, out uint original_value);
643643

644-
void InterlockedXor(in out int dest, int value, out int original_value);
645-
void InterlockedXor(in out uint dest, uint value, out uint original_value);
644+
void InterlockedXor(__ref int dest, int value, out int original_value);
645+
void InterlockedXor(__ref uint dest, uint value, out uint original_value);
646646

647647
// Is floating-point value finite?
648648
__generic<T : __BuiltinFloatingPointType> bool isfinite(T x);

source/slang/hlsl.meta.slang.h

+16-16
Original file line numberDiff line numberDiff line change
@@ -650,32 +650,32 @@ SLANG_RAW("__target_intrinsic(glsl, \"groupMemoryBarrier()); (barrier()\")\n")
650650
SLANG_RAW("void GroupMemoryBarrierWithGroupSync();\n")
651651
SLANG_RAW("\n")
652652
SLANG_RAW("// Atomics\n")
653-
SLANG_RAW("void InterlockedAdd(in out int dest, int value, out int original_value);\n")
654-
SLANG_RAW("void InterlockedAdd(in out uint dest, uint value, out uint original_value);\n")
653+
SLANG_RAW("void InterlockedAdd(__ref int dest, int value, out int original_value);\n")
654+
SLANG_RAW("void InterlockedAdd(__ref uint dest, uint value, out uint original_value);\n")
655655
SLANG_RAW("\n")
656-
SLANG_RAW("void InterlockedAnd(in out int dest, int value, out int original_value);\n")
657-
SLANG_RAW("void InterlockedAnd(in out uint dest, uint value, out uint original_value);\n")
656+
SLANG_RAW("void InterlockedAnd(__ref int dest, int value, out int original_value);\n")
657+
SLANG_RAW("void InterlockedAnd(__ref uint dest, uint value, out uint original_value);\n")
658658
SLANG_RAW("\n")
659-
SLANG_RAW("void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value);\n")
660-
SLANG_RAW("void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value);\n")
659+
SLANG_RAW("void InterlockedCompareExchange(__ref int dest, int compare_value, int value, out int original_value);\n")
660+
SLANG_RAW("void InterlockedCompareExchange(__ref uint dest, uint compare_value, uint value, out uint original_value);\n")
661661
SLANG_RAW("\n")
662-
SLANG_RAW("void InterlockedCompareStore(in out int dest, int compare_value, int value);\n")
663-
SLANG_RAW("void InterlockedCompareStore(in out uint dest, uint compare_value, uint value);\n")
662+
SLANG_RAW("void InterlockedCompareStore(__ref int dest, int compare_value, int value);\n")
663+
SLANG_RAW("void InterlockedCompareStore(__ref uint dest, uint compare_value, uint value);\n")
664664
SLANG_RAW("\n")
665-
SLANG_RAW("void InterlockedExchange(in out int dest, int value, out int original_value);\n")
666-
SLANG_RAW("void InterlockedExchange(in out uint dest, uint value, out uint original_value);\n")
665+
SLANG_RAW("void InterlockedExchange(__ref int dest, int value, out int original_value);\n")
666+
SLANG_RAW("void InterlockedExchange(__ref uint dest, uint value, out uint original_value);\n")
667667
SLANG_RAW("\n")
668-
SLANG_RAW("void InterlockedMax(in out int dest, int value, out int original_value);\n")
669-
SLANG_RAW("void InterlockedMax(in out uint dest, uint value, out uint original_value);\n")
668+
SLANG_RAW("void InterlockedMax(__ref int dest, int value, out int original_value);\n")
669+
SLANG_RAW("void InterlockedMax(__ref uint dest, uint value, out uint original_value);\n")
670670
SLANG_RAW("\n")
671671
SLANG_RAW("void InterlockedMin(in out int dest, int value, out int original_value);\n")
672672
SLANG_RAW("void InterlockedMin(in out uint dest, uint value, out uint original_value);\n")
673673
SLANG_RAW("\n")
674-
SLANG_RAW("void InterlockedOr(in out int dest, int value, out int original_value);\n")
675-
SLANG_RAW("void InterlockedOr(in out uint dest, uint value, out uint original_value);\n")
674+
SLANG_RAW("void InterlockedOr(__ref int dest, int value, out int original_value);\n")
675+
SLANG_RAW("void InterlockedOr(__ref uint dest, uint value, out uint original_value);\n")
676676
SLANG_RAW("\n")
677-
SLANG_RAW("void InterlockedXor(in out int dest, int value, out int original_value);\n")
678-
SLANG_RAW("void InterlockedXor(in out uint dest, uint value, out uint original_value);\n")
677+
SLANG_RAW("void InterlockedXor(__ref int dest, int value, out int original_value);\n")
678+
SLANG_RAW("void InterlockedXor(__ref uint dest, uint value, out uint original_value);\n")
679679
SLANG_RAW("\n")
680680
SLANG_RAW("// Is floating-point value finite?\n")
681681
SLANG_RAW("__generic<T : __BuiltinFloatingPointType> bool isfinite(T x);\n")

source/slang/ir-inst-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ INST(Nop, nop, 0, 0)
6565

6666
/* PtrTypeBase */
6767
INST(PtrType, Ptr, 1, 0)
68+
INST(RefType, Ref, 1, 0)
6869
/* OutTypeBase */
6970
INST(OutType, Out, 1, 0)
7071
INST(InOutType, InOut, 1, 0)

source/slang/ir-insts.h

+1
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ struct IRBuilder
559559
IRPtrType* getPtrType(IRType* valueType);
560560
IROutType* getOutType(IRType* valueType);
561561
IRInOutType* getInOutType(IRType* valueType);
562+
IRRefType* getRefType(IRType* valueType);
562563
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
563564

564565
IRArrayTypeBase* getArrayTypeBase(

source/slang/ir.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,11 @@ namespace Slang
13081308
return (IRInOutType*) getPtrType(kIROp_InOutType, valueType);
13091309
}
13101310

1311+
IRRefType* IRBuilder::getRefType(IRType* valueType)
1312+
{
1313+
return (IRRefType*) getPtrType(kIROp_RefType, valueType);
1314+
}
1315+
13111316
IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType)
13121317
{
13131318
IRInst* operands[] = { valueType };

source/slang/ir.h

+1
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,7 @@ struct IRPtrType : IRPtrTypeBase
828828
SIMPLE_IR_PARENT_TYPE(OutTypeBase, PtrTypeBase)
829829
SIMPLE_IR_TYPE(OutType, OutTypeBase)
830830
SIMPLE_IR_TYPE(InOutType, OutTypeBase)
831+
SIMPLE_IR_TYPE(RefType, OutTypeBase)
831832

832833
struct IRFuncType : IRType
833834
{

source/slang/lower-to-ir.cpp

+55-4
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,11 @@ void assign(
922922
LoweredValInfo const& left,
923923
LoweredValInfo const& right);
924924

925+
IRInst* getAddress(
926+
IRGenContext* context,
927+
LoweredValInfo const& inVal,
928+
SourceLoc diagnosticLocation);
929+
925930
void lowerStmt(
926931
IRGenContext* context,
927932
Stmt* stmt);
@@ -1668,7 +1673,24 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
16681673
// make a conscious decision at some point.
16691674
}
16701675

1671-
if (paramDecl->HasModifier<OutModifier>()
1676+
if(paramDecl->HasModifier<RefModifier>())
1677+
{
1678+
// A `ref` qualified parameter must be implemented with by-reference
1679+
// parameter passing, so the argument value should be lowered as
1680+
// an l-value.
1681+
//
1682+
LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr);
1683+
1684+
// According to our "calling convention" we need to
1685+
// pass a pointer into the callee. Unlike the case for
1686+
// `out` and `inout` below, it is never valid to do
1687+
// copy-in/copy-out for a `ref` parameter, so we just
1688+
// pass in the actual pointer.
1689+
//
1690+
IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc);
1691+
(*ioArgs).Add(argPtr);
1692+
}
1693+
else if (paramDecl->HasModifier<OutModifier>()
16721694
|| paramDecl->HasModifier<InOutModifier>())
16731695
{
16741696
// This is a `out` or `inout` parameter, and so
@@ -2930,6 +2952,26 @@ static LoweredValInfo maybeMoveMutableTemp(
29302952
}
29312953
}
29322954

2955+
IRInst* getAddress(
2956+
IRGenContext* context,
2957+
LoweredValInfo const& inVal,
2958+
SourceLoc diagnosticLocation)
2959+
{
2960+
LoweredValInfo val = inVal;
2961+
switch(val.flavor)
2962+
{
2963+
case LoweredValInfo::Flavor::Ptr:
2964+
return val.val;
2965+
2966+
// TODO: are there other cases we need to handle here (e.g.,
2967+
// turning a bound subscript/property into an address)
2968+
2969+
default:
2970+
context->getSink()->diagnose(diagnosticLocation, Diagnostics::invalidLValueForRefParameter);
2971+
return nullptr;
2972+
}
2973+
}
2974+
29332975
void assign(
29342976
IRGenContext* context,
29352977
LoweredValInfo const& inLeft,
@@ -3831,9 +3873,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
38313873
//
38323874
enum ParameterDirection
38333875
{
3834-
kParameterDirection_In,
3835-
kParameterDirection_Out,
3836-
kParameterDirection_InOut,
3876+
kParameterDirection_In, ///< Copy in
3877+
kParameterDirection_Out, ///< Copy out
3878+
kParameterDirection_InOut, ///< Copy in, copy out
3879+
kParameterDirection_Ref, ///< By-reference
38373880
};
38383881
struct ParameterInfo
38393882
{
@@ -3856,6 +3899,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
38563899
//
38573900
ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
38583901
{
3902+
if( paramDecl->HasModifier<RefModifier>() )
3903+
{
3904+
// The AST specified `ref`:
3905+
return kParameterDirection_Ref;
3906+
}
38593907
if( paramDecl->HasModifier<InOutModifier>() )
38603908
{
38613909
// The AST specified `inout`:
@@ -4350,6 +4398,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
43504398
case kParameterDirection_InOut:
43514399
irParamType = subBuilder->getInOutType(irParamType);
43524400
break;
4401+
case kParameterDirection_Ref:
4402+
irParamType = subBuilder->getRefType(irParamType);
4403+
break;
43534404

43544405
default:
43554406
SLANG_UNEXPECTED("unknown parameter direction");

source/slang/modifier-defs.h

+4
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,12 @@ SYNTAX_CLASS(RequiredGLSLVersionModifier, Modifier)
7171
FIELD(Token, versionNumberToken)
7272
END_SYNTAX_CLASS()
7373

74+
7475
SIMPLE_SYNTAX_CLASS(InOutModifier, OutModifier)
7576

77+
// `__ref` modifier for by-reference parameter passing
78+
SIMPLE_SYNTAX_CLASS(RefModifier, Modifier)
79+
7680
// This is a special sentinel modifier that gets added
7781
// to the list when we have multiple variable declarations
7882
// all sharing the same modifiers:

source/slang/parser.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -4399,6 +4399,7 @@ namespace Slang
43994399
MODIFIER(input, InputModifier);
44004400
MODIFIER(out, OutModifier);
44014401
MODIFIER(inout, InOutModifier);
4402+
MODIFIER(__ref, RefModifier);
44024403
MODIFIER(const, ConstModifier);
44034404
MODIFIER(instance, InstanceModifier);
44044405
MODIFIER(__builtin, BuiltinModifier);

source/slang/syntax.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ void Type::accept(IValVisitor* visitor, void* extra)
303303
return getPtrType(valueType, "InOutType").As<InOutType>();
304304
}
305305

306+
RefPtr<RefType> Session::getRefType(RefPtr<Type> valueType)
307+
{
308+
return getPtrType(valueType, "RefType").As<RefType>();
309+
}
310+
306311
RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName)
307312
{
308313
auto genericDecl = findMagicDecl(
@@ -2085,7 +2090,11 @@ void Type::accept(IValVisitor* visitor, void* extra)
20852090
{
20862091
auto paramDecl = paramDeclRef.getDecl();
20872092
auto paramType = GetType(paramDeclRef);
2088-
if( paramDecl->FindModifier<OutModifier>() )
2093+
if( paramDecl->FindModifier<RefModifier>() )
2094+
{
2095+
paramType = session->getRefType(paramType);
2096+
}
2097+
else if( paramDecl->FindModifier<OutModifier>() )
20892098
{
20902099
if(paramDecl->FindModifier<InOutModifier>() || paramDecl->FindModifier<InModifier>())
20912100
{

source/slang/type-defs.h

+4
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,10 @@ END_SYNTAX_CLASS()
355355
SYNTAX_CLASS(InOutType, OutTypeBase)
356356
END_SYNTAX_CLASS()
357357

358+
// The type for an `ref` parameter, e.g., `ref T`
359+
SYNTAX_CLASS(RefType, PtrTypeBase)
360+
END_SYNTAX_CLASS()
361+
358362
// A type alias of some kind (e.g., via `typedef`)
359363
SYNTAX_CLASS(NamedExpressionType, Type)
360364
DECL_FIELD(DeclRef<TypeDefDecl>, declRef)

0 commit comments

Comments
 (0)