Skip to content

Commit c763750

Browse files
Handle case where types can be used as their own Differential type. (shader-slang#4057)
* Avoid synthesis for when types can be used as their own differenial + Add test * Add missing files.. * Fix issue with method synthesis for self-differential types + Add a generic test * Fix * Fix issue with out-of-date type resolution cache. Witness tables created during the conformance checking phase not being taken into account during the decl type resolution phase because the epoch is not updated after conformance checking. This leads to certain complex associated-type lookup chains (such as the one in tests/compute/assoctype-nested-lookup) not resolving properly and causing errors. * Delete self-differential-type-synthesis-extension.slang * Quick fix to repopulate stdlib cache for deferred stdlib loading * Update slang-check-decl.cpp
1 parent e5d49cf commit c763750

10 files changed

+274
-27
lines changed

source/slang/slang-check-decl.cpp

+43-1
Original file line numberDiff line numberDiff line change
@@ -2144,7 +2144,40 @@ namespace Slang
21442144
SLANG_RELEASE_ASSERT(aggTypeDecl);
21452145
synth.pushContainerScope(aggTypeDecl);
21462146
}
2147-
else
2147+
2148+
// If we did not find an existing empty struct, we may need to synthesize one.
2149+
// But first, we check if the parent type can be used as its own differential type.
2150+
//
2151+
if (!aggTypeDecl
2152+
&& as<AggTypeDecl>(context->parentDecl)
2153+
&& canStructBeUsedAsSelfDifferentialType(as<AggTypeDecl>(context->parentDecl)))
2154+
{
2155+
// If the parent type can be used as its own differential type, we will create a typealias
2156+
// to itself as the differential type.
2157+
//
2158+
auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
2159+
assocTypeDef->nameAndLoc.name = getName("Differential");
2160+
assocTypeDef->type.type = context->conformingType;
2161+
assocTypeDef->parentDecl = context->parentDecl;
2162+
assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked);
2163+
context->parentDecl->members.add(assocTypeDef);
2164+
2165+
markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType);
2166+
2167+
if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable))
2168+
{
2169+
witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType));
2170+
2171+
// Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types.
2172+
m_astBuilder->incrementEpoch();
2173+
return true;
2174+
}
2175+
2176+
// Something went wrong.
2177+
return false;
2178+
}
2179+
2180+
if (!aggTypeDecl)
21482181
{
21492182
aggTypeDecl = m_astBuilder->create<StructDecl>();
21502183
aggTypeDecl->parentDecl = context->parentDecl;
@@ -5741,6 +5774,15 @@ namespace Slang
57415774
{
57425775
checkConformance(type, inheritanceDecl, decl);
57435776
}
5777+
5778+
// Successful conformance checking may have created new witness tables.
5779+
// Increment epoch to invalidate the cache, so subsequent canonical types are
5780+
// re-calculated.
5781+
//
5782+
// TODO: Is it really necessary to invalidate globally? Maybe there's a way to invalidate only the
5783+
// types that are affected by these interface decls.
5784+
//
5785+
astBuilder->incrementEpoch();
57445786
}
57455787
}
57465788

source/slang/slang-check-expr.cpp

+94-26
Original file line numberDiff line numberDiff line change
@@ -598,37 +598,60 @@ namespace Slang
598598
{
599599
case BuiltinRequirementKind::DifferentialType:
600600
{
601-
auto structDecl = m_astBuilder->create<StructDecl>();
602-
auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
603-
conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
604-
conformanceDecl->parentDecl = structDecl;
605-
structDecl->members.add(conformanceDecl);
606-
structDecl->parentDecl = parent;
607-
608-
synthesizedDecl = structDecl;
609-
auto typeDef = m_astBuilder->create<TypeAliasDecl>();
610-
typeDef->nameAndLoc.name = getName("Differential");
611-
typeDef->parentDecl = structDecl;
612-
613-
auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
614-
615-
typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
616-
structDecl->members.add(typeDef);
601+
if (!canStructBeUsedAsSelfDifferentialType(parent))
602+
{
603+
// Need to create a new struct type for the differential.
604+
//
605+
auto structDecl = m_astBuilder->create<StructDecl>();
606+
auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
607+
conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
608+
conformanceDecl->parentDecl = structDecl;
609+
structDecl->members.add(conformanceDecl);
610+
structDecl->parentDecl = parent;
611+
612+
synthesizedDecl = structDecl;
613+
auto typeDef = m_astBuilder->create<TypeAliasDecl>();
614+
typeDef->nameAndLoc.name = getName("Differential");
615+
typeDef->parentDecl = structDecl;
616+
617+
auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
618+
619+
typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
620+
structDecl->members.add(typeDef);
621+
622+
synthesizedDecl->parentDecl = parent;
623+
synthesizedDecl->nameAndLoc.name = item.declRef.getName();
624+
synthesizedDecl->loc = parent->loc;
625+
parent->members.add(synthesizedDecl);
626+
parent->invalidateMemberDictionary();
627+
628+
// Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
629+
// from user-provided definitions, and proceed to fill in its definition.
630+
auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
631+
addModifier(synthesizedDecl, toBeSynthesized);
632+
}
633+
else
634+
{
635+
// There's no need for a new struct decl.
636+
// We can simply add a typealias to the existing concrete type.
637+
//
638+
auto typeDef = m_astBuilder->create<TypeAliasDecl>();
639+
typeDef->nameAndLoc.name = item.declRef.getName();
640+
typeDef->parentDecl = parent;
641+
typeDef->type.type = subType;
642+
643+
synthesizedDecl = parent;
644+
645+
parent->members.add(typeDef);
646+
parent->invalidateMemberDictionary();
647+
648+
markSelfDifferentialMembersOfType(parent, subType);
649+
}
617650
}
618651
break;
619652
default:
620653
return nullptr;
621654
}
622-
synthesizedDecl->parentDecl = parent;
623-
synthesizedDecl->nameAndLoc.name = item.declRef.getName();
624-
synthesizedDecl->loc = parent->loc;
625-
parent->members.add(synthesizedDecl);
626-
parent->invalidateMemberDictionary();
627-
628-
// Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
629-
// from user-provided definitions, and proceed to fill in its definition.
630-
auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
631-
addModifier(synthesizedDecl, toBeSynthesized);
632655

633656
auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl);
634657
return ConstructDeclRefExpr(
@@ -1145,6 +1168,51 @@ namespace Slang
11451168
return nullptr;
11461169
}
11471170

1171+
bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl)
1172+
{
1173+
// A struct can be used as its own differential type if all its members are differentiable
1174+
// and their differential types are the same as the original types.
1175+
//
1176+
bool canBeUsed = true;
1177+
for (auto member : aggTypeDecl->members)
1178+
{
1179+
if (auto varDecl = as<VarDecl>(member))
1180+
{
1181+
// Try to get the differential type of the member.
1182+
Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType());
1183+
if (!diffType || !diffType->equals(varDecl->getType()))
1184+
{
1185+
canBeUsed = false;
1186+
break;
1187+
}
1188+
}
1189+
}
1190+
return canBeUsed;
1191+
}
1192+
1193+
void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type)
1194+
{
1195+
// TODO: Handle extensions.
1196+
// Add derivative member attributes to all the fields pointing to themselves.
1197+
for (auto member : parent->getMembersOfType<VarDeclBase>())
1198+
{
1199+
auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
1200+
auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
1201+
fieldLookupExpr->type.type = member->getType();
1202+
1203+
auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
1204+
baseTypeExpr->base.type = type;
1205+
auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type);
1206+
baseTypeExpr->type.type = baseTypeType;
1207+
fieldLookupExpr->baseExpression = baseTypeExpr;
1208+
1209+
fieldLookupExpr->declRef = makeDeclRef(member);
1210+
1211+
derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
1212+
addModifier(member, derivativeMemberModifier);
1213+
}
1214+
}
1215+
11481216
Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc)
11491217
{
11501218
auto result = tryGetDifferentialType(builder, type);

source/slang/slang-check-impl.h

+3
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,9 @@ namespace Slang
13321332
Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc);
13331333
Type* tryGetDifferentialType(ASTBuilder* builder, Type* type);
13341334

1335+
// Helper function to check if a struct can be used as its own differential type.
1336+
bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl);
1337+
void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type);
13351338

13361339
public:
13371340

source/slang/slang-options.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1707,6 +1707,10 @@ SlangResult OptionsParser::_parse(
17071707
ScopedAllocation contents;
17081708
SLANG_RETURN_ON_FAIL(File::readAllBytes(fileName.value, contents));
17091709
SLANG_RETURN_ON_FAIL(m_session->loadStdLib(contents.getData(), contents.getSizeInBytes()));
1710+
1711+
// Ensure that the linkage's AST builder is up-to-date.
1712+
linkage->getASTBuilder()->m_cachedNodes = asInternal(m_session)->getGlobalASTBuilder()->m_cachedNodes;
1713+
17101714
break;
17111715
}
17121716
case OptionKind::CompileStdLib: m_compileStdLib = true; break;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type)
8+
// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors.
9+
//
10+
11+
12+
struct Ray<let N: int> : IDifferentiable {
13+
float a;
14+
vector<float, N> dir, o;
15+
}
16+
17+
[numthreads(1, 1, 1)]
18+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
19+
{
20+
Ray<4> ray = Ray<4>();
21+
Ray<4>.Differential ray2;
22+
23+
ray.a = 1.f;
24+
ray.o = float4(3.f, 4.f, 2.5f, 1.f);
25+
26+
ray2 = ray;
27+
28+
float t = 0.f;
29+
float.Differential dt = 0.f;
30+
31+
t = dt;
32+
33+
outputBuffer[0] = t;
34+
outputBuffer[1] = ray2.o.y;
35+
outputBuffer[2] = Ray<4>.dadd(ray2, ray2).o.w;
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
type: float
2+
0.000000
3+
4.000000
4+
2.000000
5+
0.000000
6+
0.000000
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3+
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
4+
5+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
6+
RWStructuredBuffer<float> outputBuffer;
7+
8+
// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type)
9+
// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors.
10+
// 1
11+
12+
struct Ray : IDifferentiable {
13+
float a;
14+
float3 dir, o;
15+
}
16+
17+
[numthreads(1, 1, 1)]
18+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
19+
{
20+
Ray ray = Ray();
21+
Ray.Differential ray2;
22+
23+
ray.a = 1.f;
24+
ray.o = float3(3.f, 4.f, 2.5f);
25+
26+
ray2 = ray;
27+
28+
float t = 0.f;
29+
float.Differential dt = 0.f;
30+
31+
t = dt;
32+
33+
outputBuffer[0] = t;
34+
outputBuffer[1] = ray2.o.y;
35+
outputBuffer[2] = Ray.dadd(ray2, ray2).a;
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
type: float
2+
0.000000
3+
4.000000
4+
2.000000
5+
0.000000
6+
0.000000
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
4+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
5+
6+
//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
7+
RWStructuredBuffer<float> outputBuffer;
8+
9+
interface IFoo
10+
{
11+
associatedtype Bar : IFoo;
12+
};
13+
14+
15+
struct FooPair<T : IFoo> : IFoo
16+
{
17+
T a;
18+
T.Bar b;
19+
20+
typealias Bar = FooPair<T.Bar>;
21+
};
22+
23+
24+
struct ConcreteFoo : IFoo
25+
{
26+
typealias Bar = ConcreteFoo;
27+
28+
float x;
29+
};
30+
31+
void test(FooPair<ConcreteFoo>.Bar pair)
32+
{
33+
pair.a.x = 1.0;
34+
pair.b.x = 2.0;
35+
36+
outputBuffer[0] = pair.a.x + pair.b.x;
37+
}
38+
39+
[numthreads(1, 1, 1)]
40+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
41+
{
42+
FooPair<ConcreteFoo>.Bar pair;
43+
test(pair);
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
type: float
2+
3.000000

0 commit comments

Comments
 (0)