Skip to content

Commit 5fa35fc

Browse files
authored
WGSL: Enable load & store from byte-addressible buffers (shader-slang#5252)
1 parent 0ff779b commit 5fa35fc

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

source/slang/hlsl.meta.slang

+5-5
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct AppendStructuredBuffer
107107
/// @category buffer_types
108108
__magic_type(HLSLByteAddressBufferType)
109109
__intrinsic_type($(kIROp_HLSLByteAddressBufferType))
110-
[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer)]
110+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer)]
111111
struct ByteAddressBuffer
112112
{
113113
[__readNone]
@@ -4388,15 +4388,15 @@ uint64_t __asuint64(uint2 i)
43884388
//
43894389

43904390
__intrinsic_op($(kIROp_ByteAddressBufferLoad))
4391-
[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer)]
4391+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer)]
43924392
T __byteAddressBufferLoad<T>(ByteAddressBuffer buffer, int offset, int alignment);
43934393

43944394
__intrinsic_op($(kIROp_ByteAddressBufferLoad))
4395-
[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)]
4395+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)]
43964396
T __byteAddressBufferLoad<T>(RWByteAddressBuffer buffer, int offset, int alignment);
43974397

43984398
__intrinsic_op($(kIROp_ByteAddressBufferLoad))
4399-
[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)]
4399+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)]
44004400
T __byteAddressBufferLoad<T>(RasterizerOrderedByteAddressBuffer buffer, int offset, int alignment);
44014401

44024402
__intrinsic_op($(kIROp_ByteAddressBufferStore))
@@ -4583,7 +4583,7 @@ struct $(item.name)
45834583

45844584
[__NoSideEffect]
45854585
[ForceInline]
4586-
[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)]
4586+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)]
45874587
uint Load(int location)
45884588
{
45894589
__target_switch

source/slang/slang-emit-wgsl.cpp

+38-2
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
348348
}
349349
break;
350350

351+
case kIROp_HLSLByteAddressBufferType:
352+
case kIROp_HLSLRWByteAddressBufferType:
353+
{
354+
m_writer->emit("array<u32>");
355+
}
356+
break;
357+
351358
case kIROp_VoidType:
352359
{
353360
// There is no void type in WGSL.
@@ -590,13 +597,15 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, IRInst* varDecl)
590597
m_writer->emit("<workgroup>");
591598
}
592599
else if (type->getOp() == kIROp_HLSLRWStructuredBufferType ||
593-
type->getOp() == kIROp_HLSLRasterizerOrderedStructuredBufferType)
600+
type->getOp() == kIROp_HLSLRasterizerOrderedStructuredBufferType ||
601+
type->getOp() == kIROp_HLSLRWByteAddressBufferType)
594602
{
595603
m_writer->emit("<");
596604
m_writer->emit("storage, read_write");
597605
m_writer->emit(">");
598606
}
599-
else if (type->getOp() == kIROp_HLSLStructuredBufferType)
607+
else if (type->getOp() == kIROp_HLSLStructuredBufferType ||
608+
type->getOp() == kIROp_HLSLByteAddressBufferType)
600609
{
601610
m_writer->emit("<");
602611
m_writer->emit("storage, read");
@@ -1178,6 +1187,33 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
11781187
}
11791188
}
11801189
break;
1190+
1191+
case kIROp_ByteAddressBufferLoad:
1192+
{
1193+
// Indices in Slang code count bytes, but in WASM they count u32's since
1194+
// byte address buffers translate to array<u32> in WASM, so divide by 4.
1195+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1196+
m_writer->emit("[(");
1197+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
1198+
m_writer->emit(")/4]");
1199+
return true;
1200+
}
1201+
break;
1202+
1203+
case kIROp_ByteAddressBufferStore:
1204+
{
1205+
// Indices in Slang code count bytes, but in WASM they count u32's since
1206+
// byte address buffers translate to array<u32> in WASM, so divide by 4.
1207+
auto base = inst->getOperand(0);
1208+
emitOperand(base, EmitOpInfo());
1209+
m_writer->emit("[(");
1210+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
1211+
m_writer->emit(")/4] = ");
1212+
emitOperand(inst->getOperand(inst->getOperandCount() - 1), getInfo(EmitOp::General));
1213+
return true;
1214+
}
1215+
break;
1216+
11811217
}
11821218

11831219
return false;

tests/compute/byte-address-buffer.slang

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
55
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
66
//TEST(compute):COMPARE_COMPUTE_EX:-d3d12 -compute -shaderobj
7-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -wgpu
87

98
// Confirm cross-compilation of `(RW)ByteAddressBuffer`
109
//

0 commit comments

Comments
 (0)