Skip to content

Commit 134354c

Browse files
committed
Adding support for associated types.
1 parent b623864 commit 134354c

13 files changed

+294
-38
lines changed

source/slang/check.cpp

+59-24
Original file line numberDiff line numberDiff line change
@@ -147,26 +147,64 @@ namespace Slang
147147
{
148148
if (baseExpr)
149149
{
150+
RefPtr<Expr> expr;
151+
150152
if (baseExpr->type->As<TypeType>())
151153
{
152-
auto expr = new StaticMemberExpr();
153-
expr->loc = loc;
154-
expr->BaseExpression = baseExpr;
155-
expr->name = declRef.GetName();
156-
expr->type = GetTypeForDeclRef(declRef);
157-
expr->declRef = declRef;
158-
return expr;
154+
auto sexpr = new StaticMemberExpr();
155+
sexpr->loc = loc;
156+
sexpr->BaseExpression = baseExpr;
157+
sexpr->name = declRef.GetName();
158+
sexpr->type = GetTypeForDeclRef(declRef);
159+
sexpr->declRef = declRef;
160+
expr = sexpr;
159161
}
160162
else
161163
{
162-
auto expr = new MemberExpr();
163-
expr->loc = loc;
164-
expr->BaseExpression = baseExpr;
165-
expr->name = declRef.GetName();
166-
expr->type = GetTypeForDeclRef(declRef);
167-
expr->declRef = declRef;
168-
return expr;
164+
auto sexpr = new MemberExpr();
165+
sexpr->loc = loc;
166+
sexpr->BaseExpression = baseExpr;
167+
sexpr->name = declRef.GetName();
168+
sexpr->type = GetTypeForDeclRef(declRef);
169+
sexpr->declRef = declRef;
170+
expr = sexpr;
171+
}
172+
if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
173+
{
174+
if (auto genConstraintType = baseExpr->type->As<GenericConstraintDeclRefType>())
175+
{
176+
// if this is a reference from a generic parameter, we need to generate a AssocTypeDeclRefType type.
177+
// for example, if we have an expression T.U where T:ISimple, and U is an associated type defined in ISimple.
178+
// then this expression should evaluate to AssocTypeDeclRefType(T, U).
179+
auto assocTypeDeclType = new AssocTypeDeclRefType();
180+
assocTypeDeclType->declRef = assocTypeDeclRef;
181+
assocTypeDeclType->sourceType = genConstraintType->subType;
182+
assocTypeDeclType->setSession(getSession());
183+
expr->type = QualType(getTypeType(assocTypeDeclType));
184+
}
169185
}
186+
else if (auto funcDeclRef = declRef.As<CallableDecl>())
187+
{
188+
if (auto genConstraintType = baseExpr->type->As<GenericConstraintDeclRefType>())
189+
{
190+
// if this is call expression, propagate the source associated type to the result type
191+
auto funcType = expr->type->As<FuncType>();
192+
if (auto assocRsType = funcType->resultType.As<AssocTypeDeclRefType>())
193+
{
194+
RefPtr<FuncType> newFuncType = new FuncType();
195+
newFuncType->paramTypes = funcType->paramTypes;
196+
RefPtr<AssocTypeDeclRefType> newRsType = new AssocTypeDeclRefType();
197+
newRsType->declRef = assocRsType->declRef;
198+
newRsType->sourceType = genConstraintType->subType;
199+
newRsType->setSession(getSession());
200+
newFuncType->resultType = newRsType;
201+
newFuncType->setSession(funcType->getSession());
202+
expr->type = QualType(newFuncType);
203+
}
204+
205+
}
206+
}
207+
return expr;
170208
}
171209
else
172210
{
@@ -1878,7 +1916,6 @@ namespace Slang
18781916
VisitFunctionDeclaration(functionNode);
18791917
// TODO: This should really onlye set "checked header"
18801918
functionNode->SetCheckState(DeclCheckState::Checked);
1881-
18821919
// TODO: should put the checking of the body onto a "work list"
18831920
// to avoid recursion here.
18841921
if (functionNode->Body)
@@ -2309,7 +2346,6 @@ namespace Slang
23092346
{
23102347
if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return;
23112348
functionNode->SetCheckState(DeclCheckState::CheckingHeader);
2312-
23132349
this->function = functionNode;
23142350
auto returnType = CheckProperType(functionNode->ReturnType);
23152351
functionNode->ReturnType = returnType;
@@ -4366,7 +4402,7 @@ namespace Slang
43664402

43674403

43684404
callExpr->FunctionExpr = baseExpr;
4369-
callExpr->type = QualType(candidate.resultType);
4405+
callExpr->type = QualType(candidate.resultType);// QualType(baseExpr->type->As<FuncType>()->resultType);
43704406

43714407
// A call may yield an l-value, and we should take a look at the candidate to be sure
43724408
if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDecl>())
@@ -4557,7 +4593,9 @@ namespace Slang
45574593
OverloadCandidate candidate;
45584594
candidate.flavor = OverloadCandidate::Flavor::Func;
45594595
candidate.item = item;
4560-
candidate.resultType = GetResultType(funcDeclRef);
4596+
auto baseExpr = ConstructLookupResultExpr(
4597+
item, context.baseExpr, context.funcLoc);
4598+
candidate.resultType = baseExpr->type->As<FuncType>()->resultType; // GetResultType(funcDeclRef);
45614599

45624600
AddOverloadCandidate(context, candidate);
45634601
}
@@ -5717,10 +5755,6 @@ namespace Slang
57175755

57185756
RefPtr<Expr> visitInvokeExpr(InvokeExpr *expr)
57195757
{
5720-
if (auto appExpr = expr->FunctionExpr->As<GenericAppExpr>())
5721-
if (auto varExpr = appExpr->FunctionExpr->As<VarExpr>())
5722-
if (varExpr->name->text == "test")
5723-
printf("break");
57245758
// check the base expression first
57255759
expr->FunctionExpr = CheckExpr(expr->FunctionExpr);
57265760

@@ -6453,7 +6487,7 @@ namespace Slang
64536487
// When we access a constraint or an inheritance decl (as a member),
64546488
// we are conceptually performing a "cast" to the given super-type,
64556489
// with the declaration showing that such a cast is legal.
6456-
auto type = GetSup(constraintDeclRef);
6490+
auto type = new GenericConstraintDeclRefType(session, GetSub(constraintDeclRef), GetSup(constraintDeclRef));
64576491
return QualType(type);
64586492
}
64596493
else if (auto funcDeclRef = declRef.As<CallableDecl>())
@@ -6463,7 +6497,8 @@ namespace Slang
64636497
}
64646498
else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
64656499
{
6466-
auto type = DeclRefType::Create(session, assocTypeDeclRef);
6500+
auto type = new AssocTypeDeclRefType(assocTypeDeclRef);
6501+
type->setSession(session);
64676502
*outTypeResult = type;
64686503
return QualType(getTypeType(type));
64696504
}

source/slang/emit.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,11 @@ struct EmitVisitor
10541054

10551055
void visitAssocTypeDeclRefType(AssocTypeDeclRefType* /*type*/, TypeEmitArg const& /*arg*/)
10561056
{
1057-
SLANG_UNREACHABLE("visitAssocTypeDeclRefType in EmitVisitor");
1057+
//SLANG_UNREACHABLE("visitAssocTypeDeclRefType in EmitVisitor");
1058+
}
1059+
void visitGenericConstraintDeclRefType(GenericConstraintDeclRefType* /*type*/, TypeEmitArg const& /*arg*/)
1060+
{
1061+
//SLANG_UNREACHABLE("visitGenericConstraintDeclRefType in EmitVisitor");
10581062
}
10591063

10601064
void visitBasicExpressionType(BasicExpressionType* basicType, TypeEmitArg const& arg)

source/slang/ir-insts.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -346,10 +346,9 @@ struct IRBuilder
346346
IRValue* getBoolValue(bool value);
347347
IRValue* getIntValue(IRType* type, IRIntegerValue value);
348348
IRValue* getFloatValue(IRType* type, IRFloatingPointValue value);
349-
350349
IRValue* getDeclRefVal(
351350
DeclRefBase const& declRef);
352-
351+
IRValue* getTypeVal(IRType* type); // create an IR value that represents a type
353352
IRValue* emitSpecializeInst(
354353
IRType* type,
355354
IRValue* genericVal,

source/slang/ir.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,16 @@ namespace Slang
562562
return irValue;
563563
}
564564

565+
IRValue * IRBuilder::getTypeVal(IRType * type)
566+
{
567+
auto irValue = createValue<IRDeclRef>(
568+
this,
569+
kIROp_TypeType,
570+
nullptr);
571+
irValue->type = type;
572+
return irValue;
573+
}
574+
565575
IRValue* IRBuilder::emitSpecializeInst(
566576
Type* type,
567577
IRValue* genericVal,
@@ -3061,7 +3071,12 @@ namespace Slang
30613071
return builder->getDeclRefVal(declRef);
30623072
}
30633073
break;
3064-
3074+
case kIROp_TypeType:
3075+
{
3076+
IRValue* od = (IRValue*)originalValue;
3077+
return builder->getTypeVal(od->type);
3078+
}
3079+
break;
30653080
default:
30663081
SLANG_UNEXPECTED("no value registered for IR value");
30673082
return nullptr;

source/slang/lower-to-ir.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -2062,6 +2062,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
20622062
SLANG_UNIMPLEMENTED_X("decl catch-all");
20632063
}
20642064

2065+
LoweredValInfo visitTypeDefDecl(TypeDefDecl * decl)
2066+
{
2067+
return LoweredValInfo::simple(context->irBuilder->getTypeVal(decl->type.type));
2068+
}
2069+
20652070
LoweredValInfo visitGenericTypeParamDecl(GenericTypeParamDecl* decl)
20662071
{
20672072
return LoweredValInfo();
@@ -2307,7 +2312,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
23072312

23082313
// For now, we don't have an IR-level representation
23092314
// for the type itself.
2310-
return LoweredValInfo();
2315+
return LoweredValInfo::simple(context->irBuilder->getTypeVal(DeclRefType::Create(context->getSession(),
2316+
DeclRef<Decl>(decl, nullptr))));
23112317
}
23122318

23132319

source/slang/lower.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,13 @@ struct LoweringVisitor
779779
translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>());
780780
}
781781

782+
RefPtr<Type> visitGenericConstraintDeclRefType(GenericConstraintDeclRefType* type)
783+
{
784+
// not supported by lowering
785+
SLANG_UNREACHABLE("visitGenericConstraintDeclRefType in LowerVisitor");
786+
return nullptr;
787+
}
788+
782789
RefPtr<Type> visitAssocTypeDeclRefType(AssocTypeDeclRefType* type)
783790
{
784791
// not supported by lowering

source/slang/mangle.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ namespace Slang
117117
{
118118
emitQualifiedName(context, declRefType->declRef);
119119
}
120+
else if (auto assocTypeDeclRefType = dynamic_cast<AssocTypeDeclRefType*>(type))
121+
{
122+
emitQualifiedName(context, assocTypeDeclRefType->declRef);
123+
}
120124
else
121125
{
122126
SLANG_UNEXPECTED("unimplemented case in mangling");

source/slang/syntax.cpp

+113-5
Original file line numberDiff line numberDiff line change
@@ -959,15 +959,60 @@ void Type::accept(IValVisitor* visitor, void* extra)
959959

960960
RefPtr<Val> AssocTypeDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff)
961961
{
962-
auto parentType = this->GetDeclRef().GetParent().SubstituteImpl(subst, ioDiff);
963-
if (auto aggDeclRef = parentType.As<AggTypeDecl>())
962+
if (!sourceType)
963+
return this;
964+
if (auto parentDeclRefType = sourceType->As<DeclRefType>())
964965
{
965-
Decl* targetTypeDecl = nullptr;
966-
if (aggDeclRef.getDecl()->memberDictionary.TryGetValue(this->GetDeclRef().decl->getName(), targetTypeDecl))
966+
auto parentDeclRef = parentDeclRefType->declRef;
967+
DeclRef<AggTypeDecl> newParentDeclRef = parentDeclRef.As<AggTypeDecl>();
968+
// search for a substitution that might apply to us
969+
for (auto s = subst; s; s = s->outer.Ptr())
967970
{
968-
return DeclRefType::Create(this->session, DeclRef<Decl>(targetTypeDecl, parentType.substitutions));
971+
// the generic decl associated with the substitution list must be
972+
// the generic decl that declared this parameter
973+
auto genericDecl = s->genericDecl;
974+
if (genericDecl != parentDeclRef.getDecl()->ParentDecl)
975+
continue;
976+
int index = 0;
977+
for (auto m : genericDecl->Members)
978+
{
979+
if (m.Ptr() == parentDeclRef.getDecl())
980+
{
981+
// We've found it, so return the corresponding specialization argument
982+
(*ioDiff)++;
983+
if (auto declRef = s->args[index].As<DeclRefType>())
984+
{
985+
newParentDeclRef = (*declRef).declRef.As<AggTypeDecl>();
986+
goto searchEnd;
987+
}
988+
}
989+
else if (auto typeParam = m.As<GenericTypeParamDecl>())
990+
{
991+
index++;
992+
}
993+
else if (auto valParam = m.As<GenericValueParamDecl>())
994+
{
995+
index++;
996+
}
997+
else
998+
{
999+
}
1000+
}
1001+
}
1002+
searchEnd:
1003+
if (newParentDeclRef)
1004+
{
1005+
Decl* targetTypeDecl = nullptr;
1006+
if (newParentDeclRef.getDecl()->memberDictionary.TryGetValue(this->GetDeclRef().decl->getName(), targetTypeDecl))
1007+
{
1008+
if (auto typeDefDecl = targetTypeDecl->As<TypeDefDecl>())
1009+
return GetType(DeclRef<TypeDefDecl>(typeDefDecl, subst));
1010+
else
1011+
return DeclRefType::Create(this->getSession(), DeclRef<Decl>(targetTypeDecl, subst));
1012+
}
9691013
}
9701014
}
1015+
9711016
return this;
9721017
}
9731018

@@ -981,6 +1026,69 @@ void Type::accept(IValVisitor* visitor, void* extra)
9811026
return this;
9821027
}
9831028

1029+
// GenericConstraintDeclRefType
1030+
1031+
String GenericConstraintDeclRefType::ToString()
1032+
{
1033+
// TODO: what is appropriate here?
1034+
return "<GenericConstraintType>";
1035+
}
1036+
1037+
bool GenericConstraintDeclRefType::EqualsImpl(Type * type)
1038+
{
1039+
if (auto other = type->As<GenericConstraintDeclRefType>())
1040+
{
1041+
return supType->Equals(other->supType) && subType->Equals(other->subType);
1042+
}
1043+
return false;
1044+
}
1045+
1046+
RefPtr<Val> GenericConstraintDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff)
1047+
{
1048+
auto genParamDecl = subType.As<DeclRefType>()->declRef.As<GenericTypeParamDecl>();
1049+
// search for a substitution that might apply to us
1050+
for (auto s = subst; s; s = s->outer.Ptr())
1051+
{
1052+
// the generic decl associated with the substitution list must be
1053+
// the generic decl that declared this parameter
1054+
auto genericDecl = s->genericDecl;
1055+
if (genericDecl != genParamDecl.getDecl()->ParentDecl)
1056+
continue;
1057+
int index = 0;
1058+
for (auto m : genericDecl->Members)
1059+
{
1060+
if (m.Ptr() == genParamDecl.getDecl())
1061+
{
1062+
// We've found it, so return the corresponding specialization argument
1063+
(*ioDiff)++;
1064+
return s->args[index];
1065+
}
1066+
else if (auto typeParam = m.As<GenericTypeParamDecl>())
1067+
{
1068+
index++;
1069+
}
1070+
else if (auto valParam = m.As<GenericValueParamDecl>())
1071+
{
1072+
index++;
1073+
}
1074+
else
1075+
{
1076+
}
1077+
}
1078+
}
1079+
return this;
1080+
}
1081+
1082+
int GenericConstraintDeclRefType::GetHashCode()
1083+
{
1084+
return combineHash(subType.GetHashCode(), supType.GetHashCode());
1085+
}
1086+
1087+
Type* GenericConstraintDeclRefType::CreateCanonicalType()
1088+
{
1089+
return this;
1090+
}
1091+
9841092
// ArithmeticExpressionType
9851093

9861094
// VectorExpressionType

0 commit comments

Comments
 (0)