Skip to content

Commit 3e84726

Browse files
authored
Fix spirv codegen for pointer to empty structs. (shader-slang#5355)
1 parent 20fa42e commit 3e84726

9 files changed

+129
-26
lines changed

source/slang/slang-compiler-tu.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ namespace Slang
104104
applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet);
105105
applySettingsToDiagnosticSink(&sink, &sink, m_optionSet);
106106

107-
TargetRequest* targetReq = new TargetRequest(linkage, targetEnum);
107+
RefPtr<TargetRequest> targetReq = new TargetRequest(linkage, targetEnum);
108108

109109
List<RefPtr<ComponentType>> allComponentTypes;
110110
allComponentTypes.add(this); // Add Module as a component type
@@ -206,8 +206,8 @@ namespace Slang
206206
}
207207
}
208208

209-
ISlangBlob* blob;
210-
outArtifact->loadBlob(ArtifactKeep::Yes, &blob);
209+
ComPtr<ISlangBlob> blob;
210+
outArtifact->loadBlob(ArtifactKeep::Yes, blob.writeRef());
211211

212212
// Add the precompiled blob to the module
213213
builder.setInsertInto(module);

source/slang/slang-ir-insts.h

+4
Original file line numberDiff line numberDiff line change
@@ -3624,6 +3624,10 @@ struct IRBuilder
36243624
IRGenericKind* getGenericKind();
36253625

36263626
IRPtrType* getPtrType(IRType* valueType);
3627+
3628+
// Form a ptr type to `valueType` using the same opcode and address space as `ptrWithAddrSpace`.
3629+
IRPtrTypeBase* getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace);
3630+
36273631
IROutType* getOutType(IRType* valueType);
36283632
IRInOutType* getInOutType(IRType* valueType);
36293633
IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);

source/slang/slang-ir-legalize-types.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,8 @@ static LegalVal legalizeStore(
934934

935935
case LegalVal::Flavor::simple:
936936
{
937+
if (legalVal.flavor == LegalVal::Flavor::none)
938+
return LegalVal();
937939
context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple());
938940
return legalVal;
939941
}
@@ -2248,7 +2250,7 @@ static LegalVal legalizeLocalVar(
22482250
// Easy case: the type is usable as-is, and we
22492251
// should just do that.
22502252
auto type = maybeSimpleType.getSimple();
2251-
type = context->builder->getPtrType(type);
2253+
type = context->builder->getPtrTypeWithAddressSpace(type, irLocalVar->getDataType());
22522254
if( originalRate )
22532255
{
22542256
type = context->builder->getRateQualifiedType(
@@ -3669,7 +3671,7 @@ static LegalVal legalizeGlobalVar(
36693671
auto legalValueType = legalizeType(
36703672
context,
36713673
originalValueType);
3672-
3674+
auto varPtrType = as<IRPtrTypeBase>(irGlobalVar->getDataType());
36733675
switch (legalValueType.flavor)
36743676
{
36753677
case LegalType::Flavor::simple:
@@ -3678,7 +3680,8 @@ static LegalVal legalizeGlobalVar(
36783680
context->builder->setDataType(
36793681
irGlobalVar,
36803682
context->builder->getPtrType(
3681-
legalValueType.getSimple()));
3683+
legalValueType.getSimple(),
3684+
varPtrType ? varPtrType->getAddressSpace():AddressSpace::Global));
36823685
return LegalVal::simple(irGlobalVar);
36833686

36843687
default:

source/slang/slang-ir-lower-buffer-element-type.cpp

+26-5
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,26 @@ namespace Slang
671671
if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType()))
672672
{
673673
builder.setInsertBefore(ptrVal);
674-
auto newArrayPtrVal = builder.emitGetOffsetPtr(fieldAddr->getBase(), builder.getIntValue(builder.getIntType(), 1));
674+
auto newArrayPtrVal = fieldAddr->getBase();
675+
// Is base a pointer to an empty struct? If so, don't offset it.
676+
// For example, if the user has written:
677+
// ```
678+
// struct S {int arr[]};
679+
// uniform S* p;
680+
// void test() { p->arr[1]; }
681+
// ```
682+
// Then `S` will become an empty struct after we remove `arr[]`.
683+
// And `p` will be come a `void*`.
684+
// We don't want to offset `p` to `p+1` to get the starting address of the array in this case.
685+
IRSizeAndAlignment parentStructSize = {};
686+
getNaturalSizeAndAlignment(
687+
target->getOptionSet(),
688+
tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
689+
&parentStructSize);
690+
if (parentStructSize.size != 0)
691+
{
692+
newArrayPtrVal = builder.emitGetOffsetPtr(fieldAddr->getBase(), builder.getIntValue(builder.getIntType(), 1));
693+
}
675694
auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);
676695

677696
IRSizeAndAlignment arrayElementSizeAlignment;
@@ -685,12 +704,14 @@ namespace Slang
685704
&baseSizeAlignment);
686705

687706
// Convert pointer to uint64 and adjust offset.
688-
auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
689707
IRIntegerValue offset = baseSizeAlignment.size;
690708
offset = align(offset, arrayElementSizeAlignment.alignment);
691-
newArrayPtrVal = builder.emitAdd(rawPtr->getFullType(), rawPtr,
692-
builder.getIntValue(builder.getUInt64Type(), offset));
693-
709+
if (offset != 0)
710+
{
711+
auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
712+
newArrayPtrVal = builder.emitAdd(rawPtr->getFullType(), rawPtr,
713+
builder.getIntValue(builder.getUInt64Type(), offset));
714+
}
694715
newArrayPtrVal = builder.emitBitCast(
695716
builder.getPtrType(loweredInnerType.loweredType,
696717
ptrType->getAddressSpace()), newArrayPtrVal);

source/slang/slang-ir-spirv-legalize.cpp

+11-5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "slang-ir-loop-unroll.h"
2424
#include "slang-ir-lower-buffer-element-type.h"
2525
#include "slang-ir-specialize-address-space.h"
26+
#include "slang-legalize-types.h"
2627

2728
namespace Slang
2829
{
@@ -37,6 +38,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
3738

3839
IRModule* m_module;
3940

41+
DiagnosticSink* m_sink;
42+
4043
struct LoweredStructuredBufferTypeInfo
4144
{
4245
IRType* structType;
@@ -173,8 +176,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
173176
}
174177
}
175178

176-
SPIRVLegalizationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module)
177-
: m_sharedContext(sharedContext), m_module(module)
179+
SPIRVLegalizationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink)
180+
: m_sharedContext(sharedContext), m_module(module), m_sink(sink)
178181
{
179182
}
180183

@@ -2108,6 +2111,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
21082111
// safely lower the pointer load stores early together with other buffer types.
21092112
lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true);
21102113

2114+
// The above step may produce empty struct types, so we need to lower them out of existence.
2115+
legalizeEmptyTypes(m_sharedContext->m_targetProgram, m_module, m_sink);
2116+
21112117
// Specalize address space for all pointers.
21122118
SpirvAddressSpaceAssigner addressSpaceAssigner;
21132119
specializeAddressSpace(m_module, &addressSpaceAssigner);
@@ -2184,9 +2190,9 @@ SpvSnippet* SPIRVEmitSharedContext::getParsedSpvSnippet(IRTargetIntrinsicDecorat
21842190
return snippet;
21852191
}
21862192

2187-
void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module)
2193+
void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink)
21882194
{
2189-
SPIRVLegalizationContext context(sharedContext, module);
2195+
SPIRVLegalizationContext context(sharedContext, module, sink);
21902196
context.processModule();
21912197
}
21922198

@@ -2326,7 +2332,7 @@ void legalizeIRForSPIRV(
23262332
CodeGenContext* codeGenContext)
23272333
{
23282334
SLANG_UNUSED(entryPoints);
2329-
legalizeSPIRV(context, module);
2335+
legalizeSPIRV(context, module, codeGenContext->getSink());
23302336
simplifyIRForSpirvLegalization(context->m_targetProgram, codeGenContext->getSink(), module);
23312337
buildEntryPointReferenceGraph(context->m_referencingEntryPoints, module);
23322338
insertFragmentShaderInterlock(context, module);

source/slang/slang-ir.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -2881,6 +2881,13 @@ namespace Slang
28812881
operands);
28822882
}
28832883

2884+
IRPtrTypeBase* IRBuilder::getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace)
2885+
{
2886+
if (ptrWithAddrSpace->hasAddressSpace())
2887+
return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType, ptrWithAddrSpace->getAddressSpace());
2888+
return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType);
2889+
}
2890+
28842891
IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace)
28852892
{
28862893
return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(addressSpace)));

source/slang/slang-legalize-types.cpp

+32-10
Original file line numberDiff line numberDiff line change
@@ -896,24 +896,46 @@ static LegalType createLegalUniformBufferType(
896896
// Create a pointer type with a given legalized value type.
897897
static LegalType createLegalPtrType(
898898
TypeLegalizationContext* context,
899-
IROp op,
899+
IRInst* originalPtrType,
900900
LegalType legalValueType)
901901
{
902902
switch (legalValueType.flavor)
903903
{
904904
case LegalType::Flavor::none:
905+
if (auto ptrType = as<IRPtrType>(originalPtrType))
906+
{
907+
switch (ptrType->getAddressSpace())
908+
{
909+
case AddressSpace::UserPointer:
910+
case AddressSpace::Global:
911+
// If this is a physical pointer, we need to create an untyped pointer if
912+
// the element type is nothing.
913+
return LegalType::simple(
914+
context->getBuilder()->getPtrTypeWithAddressSpace(
915+
context->getBuilder()->getVoidType(),
916+
ptrType));
917+
}
918+
}
905919
return LegalType();
906920

907921
case LegalType::Flavor::simple:
908922
{
909-
// Easy case: we just have a simple element type,
910-
// so we want to create a uniform buffer that wraps it.
923+
// Easy case: we just have a simple element type.
924+
if (auto ptrTypeBase = as<IRPtrTypeBase>(originalPtrType))
925+
{
926+
if (ptrTypeBase->hasAddressSpace())
927+
{
928+
return LegalType::simple(
929+
context->getBuilder()->getPtrTypeWithAddressSpace(
930+
legalValueType.getSimple(),
931+
ptrTypeBase));
932+
}
933+
}
911934
return LegalType::simple(createBuiltinGenericType(
912935
context,
913-
op,
936+
originalPtrType->getOp(),
914937
legalValueType.getSimple()));
915938
}
916-
break;
917939

918940
case LegalType::Flavor::implicitDeref:
919941
{
@@ -936,7 +958,7 @@ static LegalType createLegalPtrType(
936958
// will matter.
937959
return LegalType::implicitDeref(createLegalPtrType(
938960
context,
939-
op,
961+
originalPtrType,
940962
legalValueType.getImplicitDeref()->valueType));
941963
}
942964
break;
@@ -948,11 +970,11 @@ static LegalType createLegalPtrType(
948970

949971
auto ordinaryType = createLegalPtrType(
950972
context,
951-
op,
973+
originalPtrType,
952974
pairType->ordinaryType);
953975
auto specialType = createLegalPtrType(
954976
context,
955-
op,
977+
originalPtrType,
956978
pairType->specialType);
957979

958980
return LegalType::pair(ordinaryType, specialType, pairType->pairInfo);
@@ -974,7 +996,7 @@ static LegalType createLegalPtrType(
974996
newElement.key = ee.key;
975997
newElement.type = createLegalPtrType(
976998
context,
977-
op,
999+
originalPtrType,
9781000
ee.type);
9791001

9801002
ptrPseudoTupleType->elements.add(newElement);
@@ -1310,7 +1332,7 @@ LegalType legalizeTypeImpl(
13101332
if (legalValueType.flavor == LegalType::Flavor::simple &&
13111333
legalValueType.getSimple() == ptrType->getValueType())
13121334
return LegalType::simple(ptrType);
1313-
return createLegalPtrType(context, ptrType->getOp(), legalValueType);
1335+
return createLegalPtrType(context, ptrType, legalValueType);
13141336
}
13151337
else if(auto structType = as<IRStructType>(type))
13161338
{

tests/spirv/ptr-empty-struct.slang

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target spirv
2+
3+
// CHECK: OpPtrAccessChain
4+
5+
struct EmptyStruct {
6+
};
7+
8+
[vk::push_constant] EmptyStruct* pc;
9+
10+
RWStructuredBuffer<int> outputBuffer;
11+
12+
[numthreads(64)]
13+
void ComputeMain(uint tid: SV_DispatchThreadID) {
14+
outputBuffer[tid] = ((int*)(pc))[0];
15+
}

tests/spirv/ptr-unsized-array-2.slang

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target spirv
2+
3+
// CHECK-DAG: %[[cbuffer__t:[A-Za-z0-9_]+]] = OpTypeStruct %_ptr_PhysicalStorageBuffer_uint
4+
// CHECK-DAG: %light_buffer = OpVariable %_ptr_PushConstant_[[cbuffer__t]] PushConstant
5+
6+
// CHECK: OpAccessChain %_ptr_PushConstant
7+
// CHECK-NEXT: OpLoad
8+
// CHECK-NEXT: OpBitcast %_ptr_PhysicalStorageBuffer
9+
10+
struct LightBuffer {
11+
uint8_t lights[];
12+
}
13+
14+
[vk::push_constant]
15+
LightBuffer* light_buffer;
16+
17+
[shader("vertex")]
18+
float4 vertMain() : SV_Position {
19+
return float4(light_buffer.lights[0]);
20+
}
21+
22+
[shader("fragment")]
23+
float4 fragMain() : COLOR0 {
24+
return float4(1.0);
25+
}

0 commit comments

Comments
 (0)