Skip to content

Commit 09a9d67

Browse files
Allow pointers to existential values. (#5793)
* Fix pointer offset logic and add executable tests. * Fix. * Fix test. * Add existential ptr test. * Allow pointers to existential values. * Fix. * Fix. --------- Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
1 parent 051ae8a commit 09a9d67

12 files changed

+215
-43
lines changed

source/slang/slang-check-decl.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -3105,6 +3105,17 @@ Type* unwrapArrayType(Type* type)
31053105
}
31063106
}
31073107

3108+
Type* unwrapModifiedType(Type* type)
3109+
{
3110+
for (;;)
3111+
{
3112+
if (auto modType = as<ModifiedType>(type))
3113+
type = modType->getBase();
3114+
else
3115+
return type;
3116+
}
3117+
}
3118+
31083119
void discoverExtensionDecls(List<ExtensionDecl*>& decls, Decl* parent)
31093120
{
31103121
if (auto extDecl = as<ExtensionDecl>(parent))

source/slang/slang-check-expr.cpp

+35-12
Original file line numberDiff line numberDiff line change
@@ -2307,7 +2307,10 @@ Expr* SemanticsVisitor::CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type*
23072307
Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr)
23082308
{
23092309
bool needDeref = false;
2310-
auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression, needDeref);
2310+
auto baseExpr = checkBaseForMemberExpr(
2311+
subscriptExpr->baseExpression,
2312+
CheckBaseContext::Subscript,
2313+
needDeref);
23112314

23122315
// If the base expression is a type, it means that this is an array declaration,
23132316
// then we should disable short-circuit in case there is logical expression in
@@ -2951,7 +2954,10 @@ Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
29512954
auto operatorName = getName("()");
29522955

29532956
bool needDeref = false;
2954-
expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref);
2957+
expr->functionExpr = maybeInsertImplicitOpForMemberBase(
2958+
expr->functionExpr,
2959+
CheckBaseContext::Member,
2960+
needDeref);
29552961

29562962
LookupResult lookupResult = lookUpMember(
29572963
m_astBuilder,
@@ -4060,27 +4066,36 @@ void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr)
40604066
}
40614067
}
40624068

4063-
Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr)
4069+
Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext)
40644070
{
40654071
Expr* expr = inExpr;
40664072
for (;;)
40674073
{
40684074
auto baseType = expr->type;
4075+
QualType elementType;
40694076
if (auto pointerLikeType = as<PointerLikeType>(baseType))
40704077
{
4071-
auto elementType = QualType(pointerLikeType->getElementType());
4078+
elementType = QualType(pointerLikeType->getElementType());
40724079
elementType.isLeftValue = baseType.isLeftValue;
40734080
elementType.hasReadOnlyOnTarget = baseType.hasReadOnlyOnTarget;
40744081
elementType.isWriteOnly = baseType.isWriteOnly;
4075-
4082+
}
4083+
else if (auto ptrType = as<PtrType>(baseType))
4084+
{
4085+
if (checkBaseContext == CheckBaseContext::Subscript)
4086+
return expr;
4087+
elementType = QualType(ptrType->getValueType());
4088+
elementType.isLeftValue = true;
4089+
}
4090+
if (elementType.type)
4091+
{
40764092
auto derefExpr = m_astBuilder->create<DerefExpr>();
40774093
derefExpr->base = expr;
40784094
derefExpr->type = elementType;
40794095

40804096
expr = derefExpr;
40814097
continue;
40824098
}
4083-
40844099
// Default case: just use the expression as-is
40854100
return expr;
40864101
}
@@ -4751,7 +4766,7 @@ Expr* SemanticsExprVisitor::visitStaticMemberExpr(StaticMemberExpr* expr)
47514766
expr->baseExpression = CheckTerm(expr->baseExpression);
47524767

47534768
// Not sure this is needed -> but guess someone could do
4754-
expr->baseExpression = MaybeDereference(expr->baseExpression);
4769+
expr->baseExpression = maybeDereference(expr->baseExpression, CheckBaseContext::Member);
47554770

47564771
// If the base of the member lookup has an interface type
47574772
// *without* a suitable this-type substitution, then we are
@@ -4779,9 +4794,12 @@ Expr* SemanticsVisitor::lookupMemberResultFailure(
47794794
return expr;
47804795
}
47814796

4782-
Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref)
4797+
Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(
4798+
Expr* baseExpr,
4799+
CheckBaseContext checkBaseContext,
4800+
bool& outNeedDeref)
47834801
{
4784-
auto derefExpr = MaybeDereference(baseExpr);
4802+
auto derefExpr = maybeDereference(baseExpr, checkBaseContext);
47854803

47864804
if (derefExpr != baseExpr)
47874805
outNeedDeref = true;
@@ -4834,11 +4852,15 @@ Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool&
48344852
return baseExpr;
48354853
}
48364854

4837-
Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref)
4855+
Expr* SemanticsVisitor::checkBaseForMemberExpr(
4856+
Expr* inBaseExpr,
4857+
CheckBaseContext checkBaseContext,
4858+
bool& outNeedDeref)
48384859
{
48394860
auto baseExpr = inBaseExpr;
48404861
baseExpr = CheckTerm(baseExpr);
4841-
return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref);
4862+
4863+
return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
48424864
}
48434865

48444866
Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)
@@ -4861,7 +4883,8 @@ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* bas
48614883
Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr* expr)
48624884
{
48634885
bool needDeref = false;
4864-
expr->baseExpression = checkBaseForMemberExpr(expr->baseExpression, needDeref);
4886+
expr->baseExpression =
4887+
checkBaseForMemberExpr(expr->baseExpression, CheckBaseContext::Member, needDeref);
48654888

48664889
if (!needDeref && as<DerefMemberExpr>(expr) && !as<PtrType>(expr->baseExpression->type))
48674890
{

source/slang/slang-check-impl.h

+15-4
Original file line numberDiff line numberDiff line change
@@ -2654,8 +2654,6 @@ struct SemanticsVisitor : public SemanticsContext
26542654
//
26552655
//
26562656

2657-
Expr* MaybeDereference(Expr* inExpr);
2658-
26592657
Expr* CheckMatrixSwizzleExpr(
26602658
MemberExpr* memberRefExpr,
26612659
Type* baseElementType,
@@ -2696,11 +2694,24 @@ struct SemanticsVisitor : public SemanticsContext
26962694

26972695
/// Perform checking operations required for the "base" expression of a member-reference like
26982696
/// `base.someField`
2699-
Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref);
2697+
enum class CheckBaseContext
2698+
{
2699+
Member,
2700+
Subscript,
2701+
};
2702+
Expr* checkBaseForMemberExpr(
2703+
Expr* baseExpr,
2704+
CheckBaseContext checkBaseContext,
2705+
bool& outNeedDeref);
2706+
2707+
Expr* maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext);
27002708

27012709
/// Prepare baseExpr for use as the base of a member expr.
27022710
/// This include inserting implicit open-existential operations as needed.
2703-
Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref);
2711+
Expr* maybeInsertImplicitOpForMemberBase(
2712+
Expr* baseExpr,
2713+
CheckBaseContext checkBaseContext,
2714+
bool& outNeedDeref);
27042715

27052716
Expr* lookupMemberResultFailure(
27062717
DeclRefExpr* expr,

source/slang/slang-check-type.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,10 @@ bool isManagedType(Type* type)
216216
{
217217
if (auto declRefValueType = as<DeclRefType>(type))
218218
{
219-
if (as<ClassDecl>(declRefValueType->getDeclRef().getDecl()))
219+
auto decl = declRefValueType->getDeclRef().getDecl();
220+
if (as<ClassDecl>(decl))
220221
return true;
221-
if (as<InterfaceDecl>(declRefValueType->getDeclRef().getDecl()))
222+
if (as<InterfaceDecl>(decl) && decl->findModifier<ComInterfaceAttribute>())
222223
return true;
223224
}
224225
return false;

source/slang/slang-check.h

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ bool isFromCoreModule(Decl* decl);
2424
void registerBuiltinDecls(Session* session, Decl* decl);
2525

2626
Type* unwrapArrayType(Type* type);
27+
Type* unwrapModifiedType(Type* type);
2728

2829
OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericConstraints(
2930
ASTBuilder* builder,

source/slang/slang-ir-lower-buffer-element-type.cpp

-22
Original file line numberDiff line numberDiff line change
@@ -901,28 +901,6 @@ struct LoweredElementTypeContext
901901
{
902902
builder.setInsertBefore(ptrVal);
903903
auto newArrayPtrVal = fieldAddr->getBase();
904-
// Is base a pointer to an empty struct? If so, don't offset it.
905-
// For example, if the user has written:
906-
// ```
907-
// struct S {int arr[]};
908-
// uniform S* p;
909-
// void test() { p->arr[1]; }
910-
// ```
911-
// Then `S` will become an empty struct after we remove `arr[]`.
912-
// And `p` will be come a `void*`.
913-
// We don't want to offset `p` to `p+1` to get the starting address of
914-
// the array in this case.
915-
IRSizeAndAlignment parentStructSize = {};
916-
getNaturalSizeAndAlignment(
917-
target->getOptionSet(),
918-
tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
919-
&parentStructSize);
920-
if (parentStructSize.size != 0)
921-
{
922-
newArrayPtrVal = builder.emitGetOffsetPtr(
923-
fieldAddr->getBase(),
924-
builder.getIntValue(builder.getIntType(), 1));
925-
}
926904
auto loweredInnerType =
927905
getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);
928906

tests/bugs/gh-3825.slang

-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ float4 fragment(): SV_Target
2121
}
2222

2323
// CHECK: OpDecorate %_ptr_PhysicalStorageBuffer_Descriptors_natural ArrayStride 4
24-
// CHECK: %{{.*}} = OpPtrAccessChain %_ptr_PhysicalStorageBuffer_Descriptors_natural %{{.*}} %int_1
2524
// CHECK: OpBitcast %ulong
2625
// CHECK: OpIAdd %ulong %{{.*}} %ulong_4
2726
// CHECK: OpBitcast %_ptr_PhysicalStorageBuffer

tests/spirv/existential-ptr.slang

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly -output-using-type
2+
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
3+
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
4+
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
5+
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
6+
7+
interface IFoo
8+
{
9+
int getVal();
10+
}
11+
12+
struct Foo : IFoo
13+
{
14+
int val;
15+
int getVal() { return val; }
16+
}
17+
18+
struct Bar : IFoo
19+
{
20+
float val;
21+
int getVal() { return (int)val + 1; }
22+
}
23+
24+
//TEST_INPUT: set pFoo = ubuffer(data=[0 0 2 0 2.0f], stride=4);
25+
//TEST_INPUT: type_conformance Foo:IFoo = 1;
26+
//TEST_INPUT: type_conformance Bar:IFoo = 2;
27+
uniform IFoo* pFoo;
28+
29+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4);
30+
RWStructuredBuffer<float> outputBuffer;
31+
32+
[numthreads(1,1,1)]
33+
void computeMain()
34+
{
35+
// CHECK: 3.0
36+
outputBuffer[0] = pFoo->getVal();
37+
}

tests/spirv/ptr-member-func.slang

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
2+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
3+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
4+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
5+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
6+
7+
struct Obj
8+
{
9+
int val;
10+
11+
[mutating]
12+
void addOne() { val++; }
13+
14+
int getValPlusOne() { return val + 1; }
15+
}
16+
17+
//TEST_INPUT: set pObj = ubuffer(data=[2 0 0 0], stride=4);
18+
uniform Obj* pObj;
19+
20+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
21+
uniform RWStructuredBuffer<uint> outputBuffer;
22+
23+
[numthreads(1,1,1)]
24+
void computeMain()
25+
{
26+
pObj->addOne();
27+
// CHECK: 4
28+
outputBuffer[0] = pObj->getValPlusOne();
29+
}

tests/spirv/ptr-unsized-array-3.slang

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
2+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
3+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
4+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
5+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
6+
7+
// Test a pointer to a struct with a trailing unsized array.
8+
9+
struct MeshStorage {
10+
int foo;
11+
uint64_t QuadData[];
12+
};
13+
14+
//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4);
15+
uniform MeshStorage* pStorage;
16+
17+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
18+
uniform RWStructuredBuffer<uint> outputBuffer;
19+
20+
[numthreads(1,1,1)]
21+
void computeMain()
22+
{
23+
// CHECK: 5
24+
// CHECK: 6
25+
// CHECK: 1
26+
outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF);
27+
outputBuffer[1] = (int)(pStorage.QuadData[1]>>32);
28+
outputBuffer[2] = pStorage.foo;
29+
}

tests/spirv/ptr-unsized-array-4.slang

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
2+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
3+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
4+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
5+
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
6+
7+
// Test a pointer to a struct that has only one field and is an unsized array.
8+
struct MeshStorage {
9+
uint64_t QuadData[];
10+
};
11+
12+
//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4);
13+
uniform MeshStorage* pStorage;
14+
15+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
16+
uniform RWStructuredBuffer<uint> outputBuffer;
17+
18+
[numthreads(1,1,1)]
19+
void computeMain()
20+
{
21+
// CHECK: 3
22+
// CHECK: 4
23+
outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF);
24+
outputBuffer[1] = (int)(pStorage.QuadData[1]>>32);
25+
}

0 commit comments

Comments
 (0)