Skip to content

Commit 1a69812

Browse files
author
Tim Foley
authored
Fix atomic operations on RWBuffer (shader-slang#593)
* Fix atomic operations on RWBuffer An earlier change added support for passing true pointers to `__ref` parameters to fix the global `Interlocked*()` functions when applied to `groupshared` variables or `RWStructureBuffer<T>` elements. That change didn't apply to `RWBuffer<T>` or `RWTexture2D<T>`, etc. because those types had so far only declared `get` and `set` accessors, but not any `ref` accessors (which return a pointer). The main fixes here are: * Add `ref` accessors to the subscript oeprations on the `RW*` resource types * Adjust the logic for emitting calls to subscript accessors so that we don't get quite as eager about invoking a `ref` accessor, and instead try to invoke just a `get` or `set` accessor when these will suffice. This is important for Vulkan cross-compilation, where we don't yet support the semantics of our `ref` accessors. * Add a test case for atomics on a `RWBuffer` * Fix up `render-test` so that we can specify a format for a buffer resource, which allows us to use things other than `*StructuredBuffer` and `*ByteAddressBuffer`. The work there is probably not complete; I just did what I could to get the test working. * A bunch of files got whitespace edits thanks to the fact that I'm using editorconfig and others on the project seemingly arent... * fixup: remove ifdefed-out code
1 parent 8b16bbf commit 1a69812

17 files changed

+373
-249
lines changed

source/slang/core.meta.slang

+7-2
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ for( int C = 2; C <= 4; ++C )
326326
sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerState) << ")\n";
327327
sb << "__intrinsic_type(" << kIROp_SamplerStateType << ")\n";
328328
sb << "struct SamplerState {};";
329-
329+
330330
sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerComparisonState) << ")\n";
331331
sb << "__intrinsic_type(" << kIROp_SamplerComparisonStateType << ")\n";
332332
sb << "struct SamplerComparisonState {};";
@@ -688,6 +688,11 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt)
688688

689689
default:
690690
sb << "__target_intrinsic(glsl, \"imageStore($0, " << ivecN << "($1), $2)\") set;\n";
691+
692+
// TODO: really need a way to map access through `ref` accessor (e.g., when
693+
// used with an atomic operation) over to GLSL equivalent.
694+
//
695+
sb << "ref;\n";
691696
break;
692697
}
693698

@@ -911,7 +916,7 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt)
911916
auto componentName = kk.componentName;
912917

913918
EMIT_LINE_DIRECTIVE();
914-
919+
915920
sb << "__target_intrinsic(glsl, \"textureGather($p, $2, " << componentIndex << ")\")\n";
916921
sb << "vector<T, 4> Gather" << componentName << "(SamplerState s, ";
917922
sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n";

source/slang/core.meta.slang.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ for( int C = 2; C <= 4; ++C )
341341
sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerState) << ")\n";
342342
sb << "__intrinsic_type(" << kIROp_SamplerStateType << ")\n";
343343
sb << "struct SamplerState {};";
344-
344+
345345
sb << "__magic_type(SamplerState," << int(SamplerStateFlavor::SamplerComparisonState) << ")\n";
346346
sb << "__intrinsic_type(" << kIROp_SamplerComparisonStateType << ")\n";
347347
sb << "struct SamplerComparisonState {};";
@@ -703,6 +703,11 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt)
703703

704704
default:
705705
sb << "__target_intrinsic(glsl, \"imageStore($0, " << ivecN << "($1), $2)\") set;\n";
706+
707+
// TODO: really need a way to map access through `ref` accessor (e.g., when
708+
// used with an atomic operation) over to GLSL equivalent.
709+
//
710+
sb << "ref;\n";
706711
break;
707712
}
708713

@@ -926,7 +931,7 @@ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt)
926931
auto componentName = kk.componentName;
927932

928933
EMIT_LINE_DIRECTIVE();
929-
934+
930935
sb << "__target_intrinsic(glsl, \"textureGather($p, $2, " << componentIndex << ")\")\n";
931936
sb << "vector<T, 4> Gather" << componentName << "(SamplerState s, ";
932937
sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n";

source/slang/emit.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -2363,15 +2363,18 @@ struct EmitVisitor
23632363
// for temporary variables.
23642364
auto type = inst->getDataType();
23652365

2366-
// First we unwrap any layers of pointer-ness and array-ness
2367-
// from the types to get at the underlying data type.
2368-
while (auto ptrType = as<IRPtrTypeBase>(type))
2366+
// Unwrap any layers of array-ness from the type, so that
2367+
// we can look at the underlying data type, in case we
2368+
// should *never* expose a value of that type
2369+
while (auto arrayType = as<IRArrayTypeBase>(type))
23692370
{
2370-
type = ptrType->getValueType();
2371+
type = arrayType->getElementType();
23712372
}
2372-
while (auto ptrType = as<IRArrayTypeBase>(type))
2373+
2374+
// Don't allow temporaries of pointer types to be created.
2375+
if(as<IRPtrTypeBase>(type))
23732376
{
2374-
type = ptrType->getElementType();
2377+
return true;
23752378
}
23762379

23772380
// First we check for uniform parameter groups,

source/slang/hlsl.meta.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1144,7 +1144,7 @@ for (int aa = 0; aa < kBaseBufferAccessLevelCount; ++aa)
11441144

11451145
if (kBaseBufferAccessLevels[aa].access != SLANG_RESOURCE_ACCESS_READ)
11461146
{
1147-
sb << "set;\n";
1147+
sb << "ref;\n";
11481148
}
11491149

11501150
sb << "}\n";

source/slang/hlsl.meta.slang.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1177,7 +1177,7 @@ for (int aa = 0; aa < kBaseBufferAccessLevelCount; ++aa)
11771177

11781178
if (kBaseBufferAccessLevels[aa].access != SLANG_RESOURCE_ACCESS_READ)
11791179
{
1180-
sb << "set;\n";
1180+
sb << "ref;\n";
11811181
}
11821182

11831183
sb << "}\n";

source/slang/lower-to-ir.cpp

+96-34
Original file line numberDiff line numberDiff line change
@@ -573,33 +573,8 @@ LoweredValInfo emitCallToDeclRef(
573573
bool justAGetter = true;
574574
for (auto accessorDeclRef : getMembersOfType<AccessorDecl>(subscriptDeclRef))
575575
{
576-
// If the subscript declares a `ref` accessor, then we can just
577-
// invoke that directly to get an l-value we can use.
578-
if(auto refAccessorDeclRef = accessorDeclRef.As<RefAccessorDecl>())
579-
{
580-
// The `ref` accessor will return a pointer to the value, so
581-
// we need to reflect that in the type of our `call` instruction.
582-
IRType* ptrType = context->irBuilder->getPtrType(type);
583-
584-
// Rather than call `emitCallToVal` here, we make a recursive call
585-
// to `emitCallToDeclRef` so that it can handle things like intrinsic-op
586-
// modifiers attached to the acecssor.
587-
LoweredValInfo callVal = emitCallToDeclRef(
588-
context,
589-
ptrType,
590-
refAccessorDeclRef,
591-
funcType,
592-
argCount,
593-
args);
594-
595-
// The result from the call needs to be implicitly dereferenced,
596-
// so that it can work as an l-value of the desired result type.
597-
return LoweredValInfo::ptr(getSimpleVal(context, callVal));
598-
}
599-
600-
// If we don't find a `ref` accessor, then we want to track whether
601-
// this subscript has any accessors other than `get` (assuming
602-
// that everything except `get` can be used for setting...).
576+
// We want to track whether this subscript has any accessors other than
577+
// `get` (assuming that everything except `get` can be used for setting...).
603578

604579
if (auto foundGetterDeclRef = accessorDeclRef.As<GetterDecl>())
605580
{
@@ -821,6 +796,15 @@ LoweredValInfo materialize(
821796
{
822797
auto boundSubscriptInfo = lowered.getBoundSubscriptInfo();
823798

799+
// We are being asked to extract a value from a subscript call
800+
// (e.g., `base[index]`). We will first check if the subscript
801+
// declared a getter and use that if possible, and then fall
802+
// back to a `ref` accessor if one is defined.
803+
//
804+
// (Picking the `get` over the `ref` accessor simplifies things
805+
// in case the `get` operation has a natural translation for
806+
// a target, while the general `ref` case does not...)
807+
824808
auto getters = getMembersOfType<GetterDecl>(boundSubscriptInfo->declRef);
825809
if (getters.Count())
826810
{
@@ -833,6 +817,27 @@ LoweredValInfo materialize(
833817
goto top;
834818
}
835819

820+
auto refAccessors = getMembersOfType<RefAccessorDecl>(boundSubscriptInfo->declRef);
821+
if(refAccessors.Count())
822+
{
823+
// The `ref` accessor will return a pointer to the value, so
824+
// we need to reflect that in the type of our `call` instruction.
825+
IRType* ptrType = context->irBuilder->getPtrType(boundSubscriptInfo->type);
826+
827+
LoweredValInfo refVal = emitCallToDeclRef(
828+
context,
829+
ptrType,
830+
*refAccessors.begin(),
831+
nullptr,
832+
boundSubscriptInfo->args);
833+
834+
// The result from the call needs to be implicitly dereferenced,
835+
// so that it can work as an l-value of the desired result type.
836+
lowered = LoweredValInfo::ptr(getSimpleVal(context, refVal));
837+
838+
goto top;
839+
}
840+
836841
SLANG_UNEXPECTED("subscript had no getter");
837842
UNREACHABLE_RETURN(LoweredValInfo());
838843
}
@@ -1308,7 +1313,7 @@ static void addNameHint(
13081313
{
13091314
Name* name = getNameForNameHint(context, decl);
13101315
if(!name)
1311-
return;
1316+
return;
13121317
context->irBuilder->addDecoration<IRNameHintDecoration>(inst)->name = name;
13131318
}
13141319

@@ -2958,18 +2963,51 @@ IRInst* getAddress(
29582963
SourceLoc diagnosticLocation)
29592964
{
29602965
LoweredValInfo val = inVal;
2966+
29612967
switch(val.flavor)
29622968
{
29632969
case LoweredValInfo::Flavor::Ptr:
29642970
return val.val;
29652971

2966-
// TODO: are there other cases we need to handle here (e.g.,
2967-
// turning a bound subscript/property into an address)
2972+
case LoweredValInfo::Flavor::BoundSubscript:
2973+
{
2974+
// If we are are trying to turn a subscript operation like `buffer[index]`
2975+
// into a pointer, then we need to find a `ref` accessor declared
2976+
// as part of the subscript operation being referenced.
2977+
//
2978+
auto subscriptInfo = val.getBoundSubscriptInfo();
2979+
auto refAccessors = getMembersOfType<RefAccessorDecl>(subscriptInfo->declRef);
2980+
if(refAccessors.Count())
2981+
{
2982+
// The `ref` accessor will return a pointer to the value, so
2983+
// we need to reflect that in the type of our `call` instruction.
2984+
IRType* ptrType = context->irBuilder->getPtrType(subscriptInfo->type);
2985+
2986+
LoweredValInfo refVal = emitCallToDeclRef(
2987+
context,
2988+
ptrType,
2989+
*refAccessors.begin(),
2990+
nullptr,
2991+
subscriptInfo->args);
2992+
2993+
// The result from the call should be a pointer, and it
2994+
// is the address that we wanted in the first place.
2995+
return getSimpleVal(context, refVal);
2996+
}
2997+
2998+
// Otherwise, there was no `ref` accessor, and so it is not possible
2999+
// to materialize this location into a pointer for whatever purpose
3000+
// we have in mind (e.g., passing it to an atomic operation).
3001+
}
3002+
3003+
// TODO: are there other cases we need to handled here?
29683004

29693005
default:
2970-
context->getSink()->diagnose(diagnosticLocation, Diagnostics::invalidLValueForRefParameter);
2971-
return nullptr;
3006+
break;
29723007
}
3008+
3009+
context->getSink()->diagnose(diagnosticLocation, Diagnostics::invalidLValueForRefParameter);
3010+
return nullptr;
29733011
}
29743012

29753013
void assign(
@@ -3090,15 +3128,17 @@ void assign(
30903128
// `someStructuredBuffer[index]`.
30913129
//
30923130
// When storing to such a value, we need to emit a call
3093-
// to the appropriate builtin "setter" accessor.
3131+
// to the appropriate builtin "setter" accessor, if there
3132+
// is one, and then fall back to a `ref` accessor if
3133+
// there is no setter.
3134+
//
30943135
auto subscriptInfo = left.getBoundSubscriptInfo();
30953136

30963137
// Search for an appropriate "setter" declaration
30973138
auto setters = getMembersOfType<SetterDecl>(subscriptInfo->declRef);
30983139
if (setters.Count())
30993140
{
31003141
auto allArgs = subscriptInfo->args;
3101-
31023142
addArgs(context, &allArgs, right);
31033143

31043144
emitCallToDeclRef(
@@ -3110,6 +3150,28 @@ void assign(
31103150
return;
31113151
}
31123152

3153+
auto refAccessors = getMembersOfType<RefAccessorDecl>(subscriptInfo->declRef);
3154+
if(refAccessors.Count())
3155+
{
3156+
// The `ref` accessor will return a pointer to the value, so
3157+
// we need to reflect that in the type of our `call` instruction.
3158+
IRType* ptrType = context->irBuilder->getPtrType(subscriptInfo->type);
3159+
3160+
LoweredValInfo refVal = emitCallToDeclRef(
3161+
context,
3162+
ptrType,
3163+
*refAccessors.begin(),
3164+
nullptr,
3165+
subscriptInfo->args);
3166+
3167+
// The result from the call needs to be implicitly dereferenced,
3168+
// so that it can work as an l-value of the desired result type.
3169+
left = LoweredValInfo::ptr(getSimpleVal(context, refVal));
3170+
3171+
// Tail-recursively attempt assignment again on the new l-value.
3172+
goto top;
3173+
}
3174+
31133175
// No setter found? Then we have an error!
31143176
SLANG_UNEXPECTED("no setter found");
31153177
break;

tests/compute/atomics-buffer.slang

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// atomics-buffer.slang
2+
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
4+
5+
// Note: not enabling D3D12 test yet because change
6+
// was developed on a machine that can run D3D12
7+
//
8+
//TEST_DISABLED(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12
9+
10+
//TEST_INPUT:ubuffer(format=R_UInt32, data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]):dxbinding(0),glbinding(0),out
11+
12+
RWBuffer<uint> outputBuffer;
13+
14+
void test(uint val)
15+
{
16+
uint originalValue;
17+
18+
InterlockedAdd(outputBuffer[val], val, originalValue);
19+
InterlockedAdd(outputBuffer[val ^ 1], val*16, originalValue);
20+
InterlockedAdd(outputBuffer[val ^ 2], val*16*16, originalValue);
21+
}
22+
23+
[numthreads(4, 1, 1)]
24+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
25+
{
26+
uint tid = dispatchThreadID.x;
27+
test(tid);
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
210
2+
301
3+
32
4+
123

tools/render-test/d3d-util.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ using namespace Slang;
3131
case Format::RG_Float32: return DXGI_FORMAT_R32G32_FLOAT;
3232
case Format::R_Float32: return DXGI_FORMAT_R32_FLOAT;
3333
case Format::RGBA_Unorm_UInt8: return DXGI_FORMAT_R8G8B8A8_UNORM;
34+
case Format::R_UInt32: return DXGI_FORMAT_R32_UINT;
3435

3536
case Format::D_Float32: return DXGI_FORMAT_D32_FLOAT;
3637
case Format::D_Unorm24_S8: return DXGI_FORMAT_D24_UNORM_S8_UINT;
@@ -47,7 +48,8 @@ using namespace Slang;
4748
switch (format)
4849
{
4950
case DXGI_FORMAT_R32_FLOAT: /* fallthru */
50-
case DXGI_FORMAT_D32_FLOAT:
51+
case DXGI_FORMAT_R32_UINT:
52+
case DXGI_FORMAT_D32_FLOAT:
5153
{
5254
return DXGI_FORMAT_R32_TYPELESS;
5355
}
@@ -73,7 +75,7 @@ using namespace Slang;
7375
switch (format)
7476
{
7577
case DXGI_FORMAT_D32_FLOAT: /* fallthru */
76-
case DXGI_FORMAT_R32_TYPELESS:
78+
case DXGI_FORMAT_R32_TYPELESS:
7779
{
7880
return DXGI_FORMAT_D32_FLOAT;
7981
}
@@ -88,7 +90,7 @@ using namespace Slang;
8890
switch (format)
8991
{
9092
case DXGI_FORMAT_D32_FLOAT: /* fallthru */
91-
case DXGI_FORMAT_D24_UNORM_S8_UINT:
93+
case DXGI_FORMAT_D24_UNORM_S8_UINT:
9294
{
9395
return DXGI_FORMAT_UNKNOWN;
9496
}
@@ -102,7 +104,7 @@ using namespace Slang;
102104
switch (format)
103105
{
104106
case DXGI_FORMAT_D32_FLOAT: /* fallthru */
105-
case DXGI_FORMAT_R32_TYPELESS:
107+
case DXGI_FORMAT_R32_TYPELESS:
106108
{
107109
return DXGI_FORMAT_R32_FLOAT;
108110
}

0 commit comments

Comments
 (0)