Skip to content

Commit

Permalink
Address low hanging fruit comments and format
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhangNV committed Jan 17, 2025
1 parent adec731 commit db3063a
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 37 deletions.
2 changes: 1 addition & 1 deletion docs/proposals/004-initialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ MyType x = MyType(y); // equivalent to `x = y`.
The compiler will attempt to resolve all type casts using type coercion rules, if that failed, will fall back to resolve it as a constructor call.

### Inheritance Initialization
For derived struct, slang will synthesized the constructor by bring the parameters from the base struct's constructor if the base struct also has a synthesized constructor. For example:
For derived structs, slang will synthesized the constructor by bringing the parameters from the base struct's constructor if the base struct also has a synthesized constructor. For example:
```csharp
struct Base
{
Expand Down
14 changes: 7 additions & 7 deletions source/slang/slang-ast-decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,20 +378,20 @@ class ConstructorDecl : public FunctionDeclBase
{
SLANG_AST_CLASS(ConstructorDecl)

enum class ConstructorTags : int
enum class ConstructorFlavor : int
{
None = 0x00,
UserDefined = 0x00,
// Indicates whether the declaration was synthesized by
// Slang and not explicitly provided by the user
Synthesized = 0x01,
SynthesizedDefault = 0x01,
// Member initialize constructor is a synthesized ctor,
// but it takes parameters.
MemberInitCtor = 0x02
SynthesizedMemberInit = 0x02
};

int m_tags = (int)ConstructorTags::None;
void addTag(ConstructorTags tag) { m_tags |= (int)tag; }
bool containsTag(ConstructorTags tag) { return m_tags & (int)tag; }
int m_flavor = (int)ConstructorFlavor::UserDefined;
void addFlavor(ConstructorFlavor flavor) { m_flavor |= (int)flavor; }
bool containsFlavor(ConstructorFlavor flavor) { return m_flavor & (int)flavor; }
};

// A subscript operation used to index instances of a type
Expand Down
29 changes: 19 additions & 10 deletions source/slang/slang-check-conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ DeclRef<StructDecl> findBaseStructDeclRef(

ConstructorDecl* SemanticsVisitor::_getSynthesizedConstructor(
StructDecl* structDecl,
ConstructorDecl::ConstructorTags tags)
ConstructorDecl::ConstructorFlavor flavor)
{
ConstructorDecl* synthesizedCtor = nullptr;
if (structDecl->m_synthesizedCtorMap.tryGetValue((int)tags, synthesizedCtor))
if (structDecl->m_synthesizedCtorMap.tryGetValue((int)flavor, synthesizedCtor))
{
return synthesizedCtor;
}
Expand All @@ -224,11 +224,8 @@ static StructDecl* _getStructDecl(Type* type)
return nullptr;
}

if (auto declRefType = as<DeclRefType>(type))
{
auto structDecl = as<StructDecl>(declRefType->getDeclRef());
if (auto structDecl = isDeclRefTypeOf<StructDecl>(type))
return structDecl.getDecl();
}

return nullptr;
}
Expand Down Expand Up @@ -302,11 +299,23 @@ bool SemanticsVisitor::isCStyleStruct(StructDecl* structDecl)
// if the member is an array, check if the element is legacy C-style rule.
if (auto arrayType = as<ArrayExpressionType>(varDecl->getType()))
{
if (arrayType->isUnsized())
{
getShared()->cacheCStyleStruct(structDecl, false);
return false;
}
auto* elementType = arrayType->getElementType();
for (;;)
{
if (auto nextType = as<ArrayExpressionType>(elementType))
{
if (arrayType->isUnsized())
{
getShared()->cacheCStyleStruct(structDecl, false);
return false;
}
elementType = nextType->getElementType();
}
else
break;
}
Expand Down Expand Up @@ -363,7 +372,7 @@ Expr* SemanticsVisitor::_createCtorInvokeExpr(
}

// translation from initializer list to constructor invocation if the struct has constructor.
bool SemanticsVisitor::_invokeExprForExplicitCtor(
bool SemanticsVisitor::createInvokeExprForExplicitCtor(
Type* toType,
InitializerListExpr* fromInitializerListExpr,
Expr** outExpr)
Expand Down Expand Up @@ -407,7 +416,7 @@ bool SemanticsVisitor::_invokeExprForExplicitCtor(
return false;
}

bool SemanticsVisitor::_invokeExprForSynthesizedCtor(
bool SemanticsVisitor::createInvokeExprForSynthesizedCtor(
Type* toType,
InitializerListExpr* fromInitializerListExpr,
Expr** outExpr)
Expand Down Expand Up @@ -846,13 +855,13 @@ bool SemanticsVisitor::_coerceInitializerList(
// Try to invoke the user-defined constructor if it exists. This call will
// report error diagnostics if the used-defined constructor exists but does not
// match the initialize list.
if (_invokeExprForExplicitCtor(toType, fromInitializerListExpr, outToExpr))
if (createInvokeExprForExplicitCtor(toType, fromInitializerListExpr, outToExpr))
{
return true;
}

// Try to invoke the synthesized constructor if it exists
if (_invokeExprForSynthesizedCtor(toType, fromInitializerListExpr, outToExpr))
if (createInvokeExprForSynthesizedCtor(toType, fromInitializerListExpr, outToExpr))
{
return true;
}
Expand Down
28 changes: 15 additions & 13 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2077,7 +2077,7 @@ static ConstructorDecl* _createCtor(
body->closingSourceLoc = ctor->closingSourceLoc;
ctor->body = body;
body->body = m_astBuilder->create<SeqStmt>();
ctor->addTag(ConstructorDecl::ConstructorTags::Synthesized);
ctor->addFlavor(ConstructorDecl::ConstructorFlavor::SynthesizedDefault);
decl->addMember(ctor);
addAutoDiffModifiersToFunc(visitor, m_astBuilder, ctor);
addVisibilityModifier(m_astBuilder, ctor, ctorVisibility);
Expand Down Expand Up @@ -2171,7 +2171,7 @@ static void checkSynthesizedConstructorWithoutDiagnostic(VisitorType& subVisitor
structDecl->invalidateMemberDictionary();
structDecl->buildMemberDictionary();
structDecl->m_synthesizedCtorMap.remove(
(int)ConstructorDecl::ConstructorTags::MemberInitCtor);
(int)ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit);
}
return;
}
Expand Down Expand Up @@ -8144,7 +8144,8 @@ void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
// When checking the synthesized constructor, it's possible to hit error, but we don't want
// to report this error, because this function is not created by user. Instead, when we
// detect this error, we will remove this synthesized constructor from the struct.
if (constructorDecl->containsTag(ConstructorDecl::ConstructorTags::MemberInitCtor) &&
if (constructorDecl->containsFlavor(
ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit) &&
!m_checkForSynthesizedCtor)
{
DiagnosticSink tempSink;
Expand All @@ -8159,14 +8160,14 @@ void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
auto structDecl = as<StructDecl>(constructorDecl->parentDecl);
ConstructorDecl* defaultCtor = nullptr;
if (structDecl->m_synthesizedCtorMap.tryGetValue(
(int)ConstructorDecl::ConstructorTags::Synthesized,
(int)ConstructorDecl::ConstructorFlavor::SynthesizedDefault,
defaultCtor))
{
structDecl->members.remove(defaultCtor);
structDecl->invalidateMemberDictionary();
structDecl->buildMemberDictionary();
structDecl->m_synthesizedCtorMap.remove(
(int)ConstructorDecl::ConstructorTags::Synthesized);
(int)ConstructorDecl::ConstructorFlavor::SynthesizedDefault);
}
}
return;
Expand Down Expand Up @@ -9239,7 +9240,7 @@ void SemanticsDeclBodyVisitor::synthesizeCtorBodyForBases(
// base's member initialize ctor. e.g. base->init(...);
baseCtor = _getSynthesizedConstructor(
declInfo.parent,
ConstructorDecl::ConstructorTags::MemberInitCtor);
ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit);
if (baseCtor)
{
Index idx = 0;
Expand Down Expand Up @@ -9375,7 +9376,8 @@ void SemanticsDeclBodyVisitor::synthesizeCtorBody(
// We treat the ctor with parameters and all parameters have default value as default ctor
// as well, but the method to synthesize them are totally different, therefore, we need to
// differentiate them here.
bool isMemberInitCtor = ctor->containsTag(ConstructorDecl::ConstructorTags::MemberInitCtor);
bool isMemberInitCtor =
ctor->containsFlavor(ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit);

// When we synthesize the member initialize constructor, we need to use the parameters in
// the function body, so this inout parameter is used to keep track of the index of the
Expand Down Expand Up @@ -9464,7 +9466,7 @@ void SemanticsDeclBodyVisitor::visitAggTypeDecl(AggTypeDecl* aggTypeDecl)
structDecl->invalidateMemberDictionary();
structDecl->buildMemberDictionary();
structDecl->m_synthesizedCtorMap.remove(
(int)ConstructorDecl::ConstructorTags::Synthesized);
(int)ConstructorDecl::ConstructorFlavor::SynthesizedDefault);
}
}
}
Expand Down Expand Up @@ -10019,7 +10021,7 @@ void SemanticsDeclHeaderVisitor::visitConstructorDecl(ConstructorDecl* decl)
// When checking the synthesized constructor, it's possible to hit error, but we don't want to
// report this error, because this function is not created by user. Instead, when we detect this
// error, we will remove this synthesized constructor from the struct.
if (decl->containsTag(ConstructorDecl::ConstructorTags::MemberInitCtor) &&
if (decl->containsFlavor(ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit) &&
!m_checkForSynthesizedCtor)
{
DiagnosticSink tempSink;
Expand Down Expand Up @@ -12262,7 +12264,7 @@ bool SemanticsDeclAttributesVisitor::_searchMembersWithHigherVisibility(
// constructor has parameters
ConstructorDecl* ctor = _getSynthesizedConstructor(
baseTypeDeclRef.getDecl(),
ConstructorDecl::ConstructorTags::MemberInitCtor);
ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit);

// The constructor has to have higher or equal visibility level than the struct itself,
// otherwise, it's not accessible so we will not pick up.
Expand Down Expand Up @@ -12333,9 +12335,9 @@ void SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* struct
// synthesize the constructor signature:
// 1. The constructor's name is always `$init`, we create one without parameters now.
ConstructorDecl* ctor = _createCtor(this, getASTBuilder(), structDecl, ctorVisibility);
ctor->addTag(ConstructorDecl::ConstructorTags::MemberInitCtor);
ctor->addFlavor(ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit);
structDecl->m_synthesizedCtorMap.addIfNotExists(
(int)ConstructorDecl::ConstructorTags::MemberInitCtor,
(int)ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit,
ctor);

ctor->members.reserve(resultMembers.getCount());
Expand Down Expand Up @@ -12386,7 +12388,7 @@ void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl)
DeclVisibility ctorVisibility = getDeclVisibility(structDecl);
auto ctor = _createCtor(this, m_astBuilder, structDecl, ctorVisibility);
structDecl->m_synthesizedCtorMap.addIfNotExists(
(int)ConstructorDecl::ConstructorTags::Synthesized,
(int)ConstructorDecl::ConstructorFlavor::SynthesizedDefault,
ctor);
}

Expand Down
6 changes: 3 additions & 3 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2799,12 +2799,12 @@ struct SemanticsVisitor : public SemanticsContext
CompletionSuggestions::ScopeKind scopeKind,
LookupResult const& lookupResult);

bool _invokeExprForExplicitCtor(
bool createInvokeExprForExplicitCtor(
Type* toType,
InitializerListExpr* fromInitializerListExpr,
Expr** outExpr);

bool _invokeExprForSynthesizedCtor(
bool createInvokeExprForSynthesizedCtor(
Type* toType,
InitializerListExpr* fromInitializerListExpr,
Expr** outExpr);
Expand All @@ -2813,7 +2813,7 @@ struct SemanticsVisitor : public SemanticsContext
bool _hasExplicitConstructor(StructDecl* structDecl, bool checkBaseType);
ConstructorDecl* _getSynthesizedConstructor(
StructDecl* structDecl,
ConstructorDecl::ConstructorTags tags);
ConstructorDecl::ConstructorFlavor flavor);
bool isCStyleStruct(StructDecl* structDecl);
bool _cStyleStructBasicCheck(Decl* decl);
};
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-check-differentiability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ struct CheckDifferentiabilityPassContext : public InstPassBase

bool isSynthesizeConstructor = false;

if(auto constructor = outerFuncInst->findDecoration<IRConstructorDecorartion>())
if (auto constructor = outerFuncInst->findDecoration<IRConstructorDecorartion>())
isSynthesizeConstructor = constructor->getSynthesizedStatus();

// This is a kernel function, we don't allow using TorchTensor type here.
Expand Down
3 changes: 2 additions & 1 deletion source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10238,7 +10238,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Used for diagnostics
getBuilder()->addConstructorDecoration(
irFunc,
constructorDecl->containsTag(ConstructorDecl::ConstructorTags::Synthesized));
constructorDecl->containsFlavor(
ConstructorDecl::ConstructorFlavor::SynthesizedDefault));
}

// We lower whatever statement was stored on the declaration
Expand Down
2 changes: 1 addition & 1 deletion tests/autodiff/differential-type-constructor.slang
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)

outputBuffer[0] = dp.d.a;

var dp2 = diffPair<MyStruct>(MyStruct(0.0, 0), MyStruct.Differential(1.f));;
var dp2 = diffPair<MyStruct>(MyStruct(0.0, 0), MyStruct.Differential(1.f));

outputBuffer[1] = dp2.d.a;
}
Expand Down

0 comments on commit db3063a

Please sign in to comment.