@@ -598,37 +598,60 @@ namespace Slang
598
598
{
599
599
case BuiltinRequirementKind::DifferentialType:
600
600
{
601
- auto structDecl = m_astBuilder->create <StructDecl>();
602
- auto conformanceDecl = m_astBuilder->create <InheritanceDecl>();
603
- conformanceDecl->base .type = m_astBuilder->getDiffInterfaceType ();
604
- conformanceDecl->parentDecl = structDecl;
605
- structDecl->members .add (conformanceDecl);
606
- structDecl->parentDecl = parent;
607
-
608
- synthesizedDecl = structDecl;
609
- auto typeDef = m_astBuilder->create <TypeAliasDecl>();
610
- typeDef->nameAndLoc .name = getName (" Differential" );
611
- typeDef->parentDecl = structDecl;
612
-
613
- auto synthDeclRef = createDefaultSubstitutionsIfNeeded (m_astBuilder, this , makeDeclRef (structDecl));
614
-
615
- typeDef->type .type = DeclRefType::create (m_astBuilder, synthDeclRef);
616
- structDecl->members .add (typeDef);
601
+ if (!canStructBeUsedAsSelfDifferentialType (parent))
602
+ {
603
+ // Need to create a new struct type for the differential.
604
+ //
605
+ auto structDecl = m_astBuilder->create <StructDecl>();
606
+ auto conformanceDecl = m_astBuilder->create <InheritanceDecl>();
607
+ conformanceDecl->base .type = m_astBuilder->getDiffInterfaceType ();
608
+ conformanceDecl->parentDecl = structDecl;
609
+ structDecl->members .add (conformanceDecl);
610
+ structDecl->parentDecl = parent;
611
+
612
+ synthesizedDecl = structDecl;
613
+ auto typeDef = m_astBuilder->create <TypeAliasDecl>();
614
+ typeDef->nameAndLoc .name = getName (" Differential" );
615
+ typeDef->parentDecl = structDecl;
616
+
617
+ auto synthDeclRef = createDefaultSubstitutionsIfNeeded (m_astBuilder, this , makeDeclRef (structDecl));
618
+
619
+ typeDef->type .type = DeclRefType::create (m_astBuilder, synthDeclRef);
620
+ structDecl->members .add (typeDef);
621
+
622
+ synthesizedDecl->parentDecl = parent;
623
+ synthesizedDecl->nameAndLoc .name = item.declRef .getName ();
624
+ synthesizedDecl->loc = parent->loc ;
625
+ parent->members .add (synthesizedDecl);
626
+ parent->invalidateMemberDictionary ();
627
+
628
+ // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
629
+ // from user-provided definitions, and proceed to fill in its definition.
630
+ auto toBeSynthesized = m_astBuilder->create <ToBeSynthesizedModifier>();
631
+ addModifier (synthesizedDecl, toBeSynthesized);
632
+ }
633
+ else
634
+ {
635
+ // There's no need for a new struct decl.
636
+ // We can simply add a typealias to the existing concrete type.
637
+ //
638
+ auto typeDef = m_astBuilder->create <TypeAliasDecl>();
639
+ typeDef->nameAndLoc .name = item.declRef .getName ();
640
+ typeDef->parentDecl = parent;
641
+ typeDef->type .type = subType;
642
+
643
+ synthesizedDecl = parent;
644
+
645
+ parent->members .add (typeDef);
646
+ parent->invalidateMemberDictionary ();
647
+
648
+ markSelfDifferentialMembersOfType (parent, subType);
649
+ }
617
650
}
618
651
break ;
619
652
default :
620
653
return nullptr ;
621
654
}
622
- synthesizedDecl->parentDecl = parent;
623
- synthesizedDecl->nameAndLoc .name = item.declRef .getName ();
624
- synthesizedDecl->loc = parent->loc ;
625
- parent->members .add (synthesizedDecl);
626
- parent->invalidateMemberDictionary ();
627
-
628
- // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
629
- // from user-provided definitions, and proceed to fill in its definition.
630
- auto toBeSynthesized = m_astBuilder->create <ToBeSynthesizedModifier>();
631
- addModifier (synthesizedDecl, toBeSynthesized);
632
655
633
656
auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef (subType->getDeclRef (), synthesizedDecl);
634
657
return ConstructDeclRefExpr (
@@ -1145,6 +1168,51 @@ namespace Slang
1145
1168
return nullptr ;
1146
1169
}
1147
1170
1171
+ bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType (AggTypeDecl *aggTypeDecl)
1172
+ {
1173
+ // A struct can be used as its own differential type if all its members are differentiable
1174
+ // and their differential types are the same as the original types.
1175
+ //
1176
+ bool canBeUsed = true ;
1177
+ for (auto member : aggTypeDecl->members )
1178
+ {
1179
+ if (auto varDecl = as<VarDecl>(member))
1180
+ {
1181
+ // Try to get the differential type of the member.
1182
+ Type* diffType = tryGetDifferentialType (getASTBuilder (), varDecl->getType ());
1183
+ if (!diffType || !diffType->equals (varDecl->getType ()))
1184
+ {
1185
+ canBeUsed = false ;
1186
+ break ;
1187
+ }
1188
+ }
1189
+ }
1190
+ return canBeUsed;
1191
+ }
1192
+
1193
+ void SemanticsVisitor::markSelfDifferentialMembersOfType (AggTypeDecl *parent, Type* type)
1194
+ {
1195
+ // TODO: Handle extensions.
1196
+ // Add derivative member attributes to all the fields pointing to themselves.
1197
+ for (auto member : parent->getMembersOfType <VarDeclBase>())
1198
+ {
1199
+ auto derivativeMemberModifier = m_astBuilder->create <DerivativeMemberAttribute>();
1200
+ auto fieldLookupExpr = m_astBuilder->create <StaticMemberExpr>();
1201
+ fieldLookupExpr->type .type = member->getType ();
1202
+
1203
+ auto baseTypeExpr = m_astBuilder->create <SharedTypeExpr>();
1204
+ baseTypeExpr->base .type = type;
1205
+ auto baseTypeType = m_astBuilder->getOrCreate <TypeType>(type);
1206
+ baseTypeExpr->type .type = baseTypeType;
1207
+ fieldLookupExpr->baseExpression = baseTypeExpr;
1208
+
1209
+ fieldLookupExpr->declRef = makeDeclRef (member);
1210
+
1211
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
1212
+ addModifier (member, derivativeMemberModifier);
1213
+ }
1214
+ }
1215
+
1148
1216
Type* SemanticsVisitor::getDifferentialType (ASTBuilder* builder, Type* type, SourceLoc loc)
1149
1217
{
1150
1218
auto result = tryGetDifferentialType (builder, type);
0 commit comments