Skip to content

Commit 31e7f84

Browse files
committed
Passing both assoctype-simple and assoctype-complex test cases.
1 parent d803bf7 commit 31e7f84

12 files changed

+157
-108
lines changed

source/slang/check.cpp

+24-60
Original file line numberDiff line numberDiff line change
@@ -140,46 +140,6 @@ namespace Slang
140140
return result;
141141
}
142142

143-
void insertSubstAtBottom(DeclRefBase & declRef, RefPtr<Substitutions> substToInsert)
144-
{
145-
RefPtr<Substitutions> lastSubst;
146-
auto subst = declRef.substitutions;
147-
while (subst)
148-
{
149-
if (subst)
150-
lastSubst = subst;
151-
subst = subst->outer;
152-
}
153-
if (lastSubst)
154-
lastSubst->outer = substToInsert;
155-
else
156-
declRef.substitutions = substToInsert;
157-
}
158-
159-
RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry)
160-
{
161-
RefPtr<ThisTypeSubstitution> thisSubst;
162-
auto subst = declRef.substitutions;
163-
while (subst)
164-
{
165-
if (auto s = subst.As<ThisTypeSubstitution>())
166-
{
167-
thisSubst = s;
168-
break;
169-
}
170-
subst = subst->outer;
171-
}
172-
if (!thisSubst)
173-
{
174-
thisSubst = new ThisTypeSubstitution();
175-
if (insertSubstEntry)
176-
{
177-
insertSubstAtBottom(declRef, thisSubst);
178-
}
179-
}
180-
return thisSubst;
181-
}
182-
183143
RefPtr<DeclRefType> getExprDeclRefType(Expr * expr)
184144
{
185145
if (auto typetype = expr->type->As<TypeType>())
@@ -221,27 +181,35 @@ namespace Slang
221181

222182
RefPtr<ThisTypeSubstitution> baseThisTypeSubst;
223183
if (auto baseDeclRefExpr = baseExpr->As<DeclRefExpr>())
184+
{
224185
baseThisTypeSubst = getThisTypeSubst(baseDeclRefExpr->declRef, false);
225-
186+
if (auto baseAssocType = baseDeclRefExpr->declRef.As<AssocTypeDecl>())
187+
{
188+
baseThisTypeSubst = new ThisTypeSubstitution();
189+
baseThisTypeSubst->sourceType = baseDeclRefExpr->type.type;
190+
if (auto typetype = baseThisTypeSubst->sourceType.As<TypeType>())
191+
baseThisTypeSubst->sourceType = typetype->type;
192+
}
193+
}
226194
if (auto assocTypeDecl = declRef.As<AssocTypeDecl>())
227195
{
228-
if (!baseThisTypeSubst)
229-
baseThisTypeSubst = new ThisTypeSubstitution();
230-
expr->type = GetTypeForDeclRef(DeclRef<AssocTypeDecl>(assocTypeDecl.getDecl(), baseThisTypeSubst));
231-
232-
RefPtr<ThisTypeSubstitution> outerSubst = getThisTypeSubst(*declRefOut, true);
233-
outerSubst->sourceType = expr->type.type;
234-
if (auto outerTypeType = outerSubst->sourceType.As<TypeType>())
235-
outerSubst->sourceType = outerTypeType->type;
236-
declRefOut->substitutions = outerSubst;
196+
auto newThisTypeSubst = new ThisTypeSubstitution();
197+
if (baseThisTypeSubst)
198+
newThisTypeSubst->sourceType = baseThisTypeSubst->sourceType;
199+
expr->type = GetTypeForDeclRef(DeclRef<AssocTypeDecl>(assocTypeDecl.getDecl(), newThisTypeSubst));
200+
auto declOutThisTypeSubst = getNewThisTypeSubst(*declRefOut);
201+
if (baseThisTypeSubst)
202+
declOutThisTypeSubst->sourceType = baseThisTypeSubst->sourceType;
237203
return expr;
238204
}
239205

240206
// propagate "this-type" substitutions
241207
if (baseThisTypeSubst)
242208
{
243209
if (auto declRefExpr = expr.As<DeclRefExpr>())
244-
insertSubstAtBottom(declRefExpr->declRef, baseThisTypeSubst);
210+
{
211+
getNewThisTypeSubst(declRefExpr->declRef)->sourceType = baseThisTypeSubst->sourceType;
212+
}
245213
}
246214
expr->type = GetTypeForDeclRef(declRef);
247215
return expr;
@@ -2000,8 +1968,6 @@ namespace Slang
20001968

20011969
void visitFuncDecl(FuncDecl *functionNode)
20021970
{
2003-
if (functionNode->nameAndLoc.name->text == "test")
2004-
printf("break");
20051971
if (functionNode->IsChecked(DeclCheckState::Checked))
20061972
return;
20071973

@@ -5861,7 +5827,6 @@ namespace Slang
58615827

58625828
RefPtr<Expr> CheckInvokeExprWithCheckedOperands(InvokeExpr *expr)
58635829
{
5864-
58655830
auto rs = ResolveInvoke(expr);
58665831
if (auto invoke = dynamic_cast<InvokeExpr*>(rs.Ptr()))
58675832
{
@@ -5894,9 +5859,6 @@ namespace Slang
58945859

58955860
RefPtr<Expr> visitInvokeExpr(InvokeExpr *expr)
58965861
{
5897-
if (auto mbrExpr = expr->FunctionExpr->As<MemberExpr>())
5898-
if (mbrExpr->name->text == "add")
5899-
printf("break");
59005862
// check the base expression first
59015863
expr->FunctionExpr = CheckExpr(expr->FunctionExpr);
59025864

@@ -6631,9 +6593,11 @@ namespace Slang
66316593
}
66326594
else if (auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>())
66336595
{
6634-
auto type = DeclRefType::Create(session, constraintDeclRef);
6635-
*outTypeResult = type;
6636-
return QualType(getTypeType(type));
6596+
// When we access a constraint or an inheritance decl (as a member),
6597+
// we are conceptually performing a "cast" to the given super-type,
6598+
// with the declaration showing that such a cast is legal.
6599+
auto type = GetSup(constraintDeclRef);
6600+
return QualType(type);
66376601
}
66386602
if( sink )
66396603
{

source/slang/emit.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,7 @@ struct EmitVisitor
395395
void emitRawTextSpan(char const* textBegin, char const* textEnd)
396396
{
397397
// TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things...
398-
auto len = int(textEnd - textBegin);
399-
398+
auto len = textEnd - textBegin;
400399
context->shared->sb.Append(textBegin, len);
401400
}
402401

source/slang/lookup.cpp

+22-21
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,27 @@ void lookUpMemberImpl(
410410
if (auto declRefType = type->As<DeclRefType>())
411411
{
412412
auto declRef = declRefType->declRef;
413-
if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>())
413+
if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
414+
{
415+
for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef))
416+
{
417+
// The super-type in the constraint (e.g., `Foo` in `T : Foo`)
418+
// will tell us a type we should use for lookup.
419+
auto bound = GetSup(constraintDeclRef);
420+
421+
// Go ahead and use the target type, with an appropriate breadcrumb
422+
// to indicate that we indirected through a type constraint.
423+
424+
BreadcrumbInfo breadcrumb;
425+
breadcrumb.prev = inBreadcrumbs;
426+
breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint;
427+
breadcrumb.declRef = constraintDeclRef;
428+
429+
// TODO: Need to consider case where this might recurse infinitely.
430+
lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb);
431+
}
432+
}
433+
else if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>())
414434
{
415435
LookupRequest request;
416436
request.semantics = semantics;
@@ -452,26 +472,7 @@ void lookUpMemberImpl(
452472
lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb);
453473
}
454474
}
455-
else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
456-
{
457-
for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef))
458-
{
459-
// The super-type in the constraint (e.g., `Foo` in `T : Foo`)
460-
// will tell us a type we should use for lookup.
461-
auto bound = GetSup(constraintDeclRef);
462-
463-
// Go ahead and use the target type, with an appropriate breadcrumb
464-
// to indicate that we indirected through a type constraint.
465-
466-
BreadcrumbInfo breadcrumb;
467-
breadcrumb.prev = inBreadcrumbs;
468-
breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint;
469-
breadcrumb.declRef = constraintDeclRef;
470-
471-
// TODO: Need to consider case where this might recurse infinitely.
472-
lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb);
473-
}
474-
}
475+
475476
}
476477

477478
}

source/slang/lower-to-ir.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -2652,8 +2652,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
26522652

26532653
LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
26542654
{
2655-
if (decl->getName()->text == "test")
2656-
printf("break");
26572655
// Collect the parameter lists we will use for our new function.
26582656
ParameterLists parameterLists;
26592657
collectParameterLists(decl, &parameterLists, kParameterListCollectMode_Default);
@@ -3086,7 +3084,7 @@ LoweredValInfo emitDeclRef(
30863084

30873085
// If this declaration reference doesn't involve any specializations,
30883086
// then we are done at this point.
3089-
if(!declRef.substitutions)
3087+
if(!hasGenericSubstitutions(declRef.substitutions))
30903088
return loweredDecl;
30913089

30923090
auto val = getSimpleVal(context, loweredDecl);

source/slang/slang.vcxproj

+3
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,9 @@
296296
<AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release|x64'">$(OutDir)slang-generate.exe</AdditionalInputs>
297297
</CustomBuild>
298298
</ItemGroup>
299+
<ItemGroup>
300+
<None Include="slang.natstepfilter" />
301+
</ItemGroup>
299302
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
300303
<ImportGroup Label="ExtensionTargets">
301304
</ImportGroup>

source/slang/slang.vcxproj.filters

+3
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,7 @@
7777
<CustomBuild Include="glsl.meta.slang" />
7878
<CustomBuild Include="hlsl.meta.slang" />
7979
</ItemGroup>
80+
<ItemGroup>
81+
<None Include="slang.natstepfilter" />
82+
</ItemGroup>
8083
</Project>

source/slang/syntax.cpp

+83-1
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,14 @@ void Type::accept(IValVisitor* visitor, void* extra)
530530
// we want to replace it with the actual associated type
531531
else if (auto assocTypeDecl = dynamic_cast<AssocTypeDecl*>(declRef.getDecl()))
532532
{
533+
auto thisSubst = getThisTypeSubst(declRef, false);
534+
auto oldSubstSrc = thisSubst ? thisSubst->sourceType : nullptr;
535+
bool restore = false;
536+
if (thisSubst && thisSubst->sourceType.Ptr() == dynamic_cast<Val*>(this))
537+
thisSubst->sourceType = nullptr;
533538
auto newSubst = declRef.substitutions->SubstituteImpl(subst, ioDiff);
539+
if (restore)
540+
thisSubst->sourceType = oldSubstSrc;
534541
if (auto thisTypeSubst = newSubst.As<ThisTypeSubstitution>())
535542
{
536543
if (thisTypeSubst->sourceType)
@@ -1258,6 +1265,8 @@ void Type::accept(IValVisitor* visitor, void* extra)
12581265

12591266
bool ThisTypeSubstitution::Equals(Substitutions* subst)
12601267
{
1268+
if (!subst)
1269+
return true;
12611270
if (subst && dynamic_cast<ThisTypeSubstitution*>(subst))
12621271
return true;
12631272
return false;
@@ -1323,7 +1332,7 @@ void Type::accept(IValVisitor* visitor, void* extra)
13231332
if (decl != declRef.decl)
13241333
return false;
13251334
if (!substitutions)
1326-
return !declRef.substitutions;
1335+
return !declRef.substitutions || declRef.substitutions.As<ThisTypeSubstitution>();
13271336
if (!substitutions->Equals(declRef.substitutions.Ptr()))
13281337
return false;
13291338

@@ -1634,4 +1643,77 @@ void Type::accept(IValVisitor* visitor, void* extra)
16341643

16351644

16361645

1646+
void insertSubstAtTop(DeclRefBase & declRef, RefPtr<Substitutions> substToInsert)
1647+
{
1648+
substToInsert->outer = declRef.substitutions;
1649+
declRef.substitutions = substToInsert;
1650+
}
1651+
1652+
RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry)
1653+
{
1654+
RefPtr<ThisTypeSubstitution> thisSubst;
1655+
auto subst = declRef.substitutions;
1656+
while (subst)
1657+
{
1658+
if (auto s = subst.As<ThisTypeSubstitution>())
1659+
{
1660+
thisSubst = s;
1661+
break;
1662+
}
1663+
subst = subst->outer;
1664+
}
1665+
if (!thisSubst)
1666+
{
1667+
thisSubst = new ThisTypeSubstitution();
1668+
if (insertSubstEntry)
1669+
{
1670+
insertSubstAtTop(declRef, thisSubst);
1671+
}
1672+
}
1673+
return thisSubst;
1674+
}
1675+
1676+
RefPtr<ThisTypeSubstitution> getNewThisTypeSubst(DeclRefBase & declRef)
1677+
{
1678+
auto oldSubst = getThisTypeSubst(declRef, false);
1679+
if (oldSubst)
1680+
removeSubstitution(declRef, oldSubst);
1681+
return getThisTypeSubst(declRef, true);
1682+
}
1683+
1684+
void removeSubstitution(DeclRefBase & declRef, RefPtr<Substitutions> toRemove)
1685+
{
1686+
if (!declRef.substitutions)
1687+
return;
1688+
if (toRemove == declRef.substitutions)
1689+
{
1690+
declRef.substitutions = declRef.substitutions->outer;
1691+
return;
1692+
}
1693+
auto prev = declRef.substitutions;
1694+
auto subst = prev->outer;
1695+
while (subst)
1696+
{
1697+
if (subst == toRemove)
1698+
{
1699+
prev->outer = subst->outer;
1700+
break;
1701+
}
1702+
prev = subst;
1703+
subst = subst->outer;
1704+
}
1705+
}
1706+
1707+
bool hasGenericSubstitutions(RefPtr<Substitutions> subst)
1708+
{
1709+
auto p = subst.Ptr();
1710+
while (p)
1711+
{
1712+
if (dynamic_cast<GenericSubstitution*>(p))
1713+
return true;
1714+
p = p->outer.Ptr();
1715+
}
1716+
return false;
1717+
}
1718+
16371719
}

source/slang/syntax.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,12 @@ namespace Slang
11551155
RefPtr<Substitutions> createDefaultSubstitutions(
11561156
Session* session,
11571157
Decl* decl);
1158-
1158+
1159+
RefPtr<ThisTypeSubstitution> getNewThisTypeSubst(DeclRefBase & declRef);
1160+
RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry);
1161+
void removeSubstitution(DeclRefBase & declRef, RefPtr<Substitutions> subst);
1162+
bool hasGenericSubstitutions(RefPtr<Substitutions> subst);
1163+
11591164
} // namespace Slang
11601165

11611166
#endif

tests/compute/assoctype-complex.slang

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//TEST(smoke, compute):COMPARE_COMPUTE:-xslang -use-ir
22
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
33

4-
RWStructuredBuffer<float> outputBuffer;
4+
RWStructuredBuffer<int> outputBuffer;
55
interface IBase
66
{
77
associatedtype V;
@@ -16,9 +16,10 @@ interface ISimple
1616
struct Val : IBase
1717
{
1818
typedef int V;
19+
int base;
1920
V sub(V a0, V a1)
2021
{
21-
return a0-a1;
22+
return a0 - a1 + base;
2223
}
2324
};
2425

@@ -42,6 +43,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
4243
{
4344
Simple s;
4445
Val v0, v1;
45-
float outVal = test(s, v0, v1); // == 1.0
46+
v0.base = 1;
47+
v1.base = 2;
48+
int outVal = test<Simple>(s, v0, v1); // == 4.0
4649
outputBuffer[dispatchThreadID.x] = outVal;
4750
}

0 commit comments

Comments
 (0)