Skip to content

Commit d1c3852

Browse files
author
Tim Foley
authored
Support for transitive subtype witnesses (shader-slang#331)
* Change stdlib `saturate` to explicitly specialize `clamp` This exposes issue shader-slang#329, and so gives us an easy way to see if transitive subtype witnesses have been implemented correctly. * Fixup: invoke correct `clamp` overloads When switching the `clamp` calls in the stdlib definition of `saturate` I made two big mistakes: 1. I was passing in `<T>` in all cases, instead of, e.g., `<vector<T,N>>` in the vector case 2. Of course, the overloads don't actually take `<vector<T,N>>` for the vector case, because `vector<T,N>` is not a `__BuiltinArithmeticType` (`T` is), so instead it should be `clamp<T,N>(...)`. The issue behind (2) is that we don't support "conditional conformances," which would be a way to say that when `T : __BuiltinArithmeticType` then `vector<T,N> : __BuiltinArithmeticType`. That would be a great long-term wish-list feature, but not something I can see us adding in a hurry. Anyway the fix here is the simple one: change the vector/matrix call sites to invoke the correct overload in each case. * Add a notion of transitive subtype witnesses There are two pieces here: 1. Add the `TransitiveSubtypeWitness` class. This is a witness that `A : C` that works by storing nested subtype witnesses that show that `A : B` and `B : C` for some intermediate type `B`. All the basic `Val` operations are easy enough to define on this. - The one gotcha case is whether we can ever simplify away a `TransitiveSubtypeWitness` as part of substitution. That is, if we end up substituting so that both `A` and `B` end up as the same type, then we really just need the `B : C` sub-part. Stuff like that is left as future work. 2. Make the logic in `check.cpp` that constructs subtype witnesses based on found inheritance and constraint declarations able to build up transitive chains. Most of the required infrastructure was already there (the search process maintains a trail of "breadcrumbs" that represent all the steps getting from `A : B` to `B : C` to `C : D` ...). This change does *not* deal with the required changes in the IR to take advantage of transitive witnesses.
1 parent fab52a1 commit d1c3852

File tree

5 files changed

+156
-24
lines changed

5 files changed

+156
-24
lines changed

source/slang/check.cpp

+57-18
Original file line numberDiff line numberDiff line change
@@ -3678,9 +3678,24 @@ namespace Slang
36783678
struct TypeWitnessBreadcrumb
36793679
{
36803680
TypeWitnessBreadcrumb* prev;
3681+
3682+
RefPtr<Type> sub;
3683+
RefPtr<Type> sup;
36813684
DeclRef<Decl> declRef;
36823685
};
36833686

3687+
// Crete a subtype witness based on the declared relationship
3688+
// found in a single breadcrumb
3689+
RefPtr<SubtypeWitness> createSimplSubtypeWitness(
3690+
TypeWitnessBreadcrumb* breadcrumb)
3691+
{
3692+
RefPtr<DeclaredSubtypeWitness> witness = new DeclaredSubtypeWitness();
3693+
witness->sub = breadcrumb->sub;
3694+
witness->sup = breadcrumb->sup;
3695+
witness->declRef = breadcrumb->declRef;
3696+
return witness;
3697+
}
3698+
36843699
RefPtr<Val> createTypeWitness(
36853700
RefPtr<Type> type,
36863701
DeclRef<InterfaceDecl> interfaceDeclRef,
@@ -3696,29 +3711,46 @@ namespace Slang
36963711
UNREACHABLE_RETURN(nullptr);
36973712
}
36983713

3699-
auto breadcrumbs = inBreadcrumbs;
3714+
// We might have one or more steps in the breadcrumb trail, e.g.:
3715+
//
3716+
// (A : B) (B : C) (C : D)
3717+
//
3718+
// The chain is stored as a reversed linked list, so that
3719+
// the first entry would be the `(C : D)` relationship
3720+
// above.
3721+
//
3722+
// We are going to walk the list and build up a suitable
3723+
// subtype witness.
3724+
auto bb = inBreadcrumbs;
3725+
3726+
// Create a witness for the last step in the chain
3727+
RefPtr<SubtypeWitness> witness = createSimplSubtypeWitness(bb);
3728+
bb = bb->prev;
37003729

3701-
auto bb = breadcrumbs;
3702-
breadcrumbs = breadcrumbs->prev;
3730+
// Now, as long as we have more entries to deal with,
3731+
// we'll be in a situation like:
3732+
//
3733+
// ... (B : C) <witness>
3734+
//
3735+
// and we want to wrap up one more link in our chain.
37033736

3704-
if(breadcrumbs)
3737+
while (bb)
37053738
{
3706-
// There are multiple steps in the proof, so
3707-
// we need a transitive witness to show that
3708-
// because `A : B` and `B : C` then `A : C`
3709-
//
3710-
SLANG_UNEXPECTED("transitive type witness");
3711-
UNREACHABLE_RETURN(nullptr);
3712-
}
3739+
// Create simple witness for the step in the chain
3740+
RefPtr<SubtypeWitness> link = createSimplSubtypeWitness(bb);
37133741

3714-
// Simple case: we have a single declaration
3715-
// that shows that `type` conforms to `interfaceDeclRef`.
3716-
//
3742+
// Now join the link onto the existing chain represented
3743+
// by `witness`.
3744+
RefPtr<TransitiveSubtypeWitness> transitiveWitness = new TransitiveSubtypeWitness();
3745+
transitiveWitness->sub = link->sub;
3746+
transitiveWitness->sup = witness->sup;
3747+
transitiveWitness->subToMid = link;
3748+
transitiveWitness->midToSup = witness;
3749+
3750+
witness = transitiveWitness;
3751+
bb = bb->prev;
3752+
}
37173753

3718-
RefPtr<DeclaredSubtypeWitness> witness = new DeclaredSubtypeWitness();
3719-
witness->sub = type;
3720-
witness->sup = DeclRefType::Create(getSession(), interfaceDeclRef);
3721-
witness->declRef = bb->declRef;
37223754
return witness;
37233755
}
37243756

@@ -3772,6 +3804,9 @@ namespace Slang
37723804
// the inheritance declaration.
37733805
TypeWitnessBreadcrumb breadcrumb;
37743806
breadcrumb.prev = inBreadcrumbs;
3807+
3808+
breadcrumb.sub = type;
3809+
breadcrumb.sup = inheritedType;
37753810
breadcrumb.declRef = inheritanceDeclRef;
37763811

37773812
if(doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb))
@@ -3786,6 +3821,8 @@ namespace Slang
37863821
auto inheritedType = GetSup(genConstraintDeclRef);
37873822
TypeWitnessBreadcrumb breadcrumb;
37883823
breadcrumb.prev = inBreadcrumbs;
3824+
breadcrumb.sub = type;
3825+
breadcrumb.sup = inheritedType;
37893826
breadcrumb.declRef = genConstraintDeclRef;
37903827
if (doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb))
37913828
{
@@ -3818,6 +3855,8 @@ namespace Slang
38183855

38193856
TypeWitnessBreadcrumb breadcrumb;
38203857
breadcrumb.prev = inBreadcrumbs;
3858+
breadcrumb.sub = sub;
3859+
breadcrumb.sup = sup;
38213860
breadcrumb.declRef = constraintDeclRef;
38223861

38233862
if(doesTypeConformToInterfaceImpl(originalType, sup, interfaceDeclRef, outWitness, &breadcrumb))

source/slang/hlsl.meta.slang

+3-3
Original file line numberDiff line numberDiff line change
@@ -824,14 +824,14 @@ __generic<T : __BuiltinFloatingPointType>
824824
__specialized_for_target(glsl)
825825
T saturate(T x)
826826
{
827-
return clamp(x, T(0), T(1));
827+
return clamp<T>(x, T(0), T(1));
828828
}
829829

830830
__generic<T : __BuiltinFloatingPointType, let N : int>
831831
__specialized_for_target(glsl)
832832
vector<T,N> saturate(vector<T,N> x)
833833
{
834-
return clamp(x,
834+
return clamp<T,N>(x,
835835
vector<T,N>(T(0)),
836836
vector<T,N>(T(1)));
837837
}
@@ -846,7 +846,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
846846
__specialized_for_target(glsl)
847847
matrix<T,N,M> saturate(matrix<T,N,M> x)
848848
{
849-
return clamp(x,
849+
return clamp<T,N,M>(x,
850850
__scalarToMatrix<T,N,M>(T(0)),
851851
__scalarToMatrix<T,N,M>(T(1)));
852852
}

source/slang/hlsl.meta.slang.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -826,14 +826,14 @@ sb << "__generic<T : __BuiltinFloatingPointType>\n";
826826
sb << "__specialized_for_target(glsl)\n";
827827
sb << "T saturate(T x)\n";
828828
sb << "{\n";
829-
sb << " return clamp(x, T(0), T(1));\n";
829+
sb << " return clamp<T>(x, T(0), T(1));\n";
830830
sb << "}\n";
831831
sb << "\n";
832832
sb << "__generic<T : __BuiltinFloatingPointType, let N : int>\n";
833833
sb << "__specialized_for_target(glsl)\n";
834834
sb << "vector<T,N> saturate(vector<T,N> x)\n";
835835
sb << "{\n";
836-
sb << " return clamp(x,\n";
836+
sb << " return clamp<T,N>(x,\n";
837837
sb << " vector<T,N>(T(0)),\n";
838838
sb << " vector<T,N>(T(1)));\n";
839839
sb << "}\n";
@@ -848,7 +848,7 @@ sb << "__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>\n";
848848
sb << "__specialized_for_target(glsl)\n";
849849
sb << "matrix<T,N,M> saturate(matrix<T,N,M> x)\n";
850850
sb << "{\n";
851-
sb << " return clamp(x,\n";
851+
sb << " return clamp<T,N,M>(x,\n";
852852
sb << " __scalarToMatrix<T,N,M>(T(0)),\n";
853853
sb << " __scalarToMatrix<T,N,M>(T(1)));\n";
854854
sb << "}\n";

source/slang/syntax.cpp

+78
Original file line numberDiff line numberDiff line change
@@ -1741,6 +1741,84 @@ void Type::accept(IValVisitor* visitor, void* extra)
17411741
return hash;
17421742
}
17431743

1744+
// TransitiveSubtypeWitness
1745+
1746+
bool TransitiveSubtypeWitness::EqualsVal(Val* val)
1747+
{
1748+
auto otherWitness = dynamic_cast<TransitiveSubtypeWitness*>(val);
1749+
if(!otherWitness)
1750+
return false;
1751+
1752+
return sub->Equals(otherWitness->sub)
1753+
&& sup->Equals(otherWitness->sup)
1754+
&& subToMid->EqualsVal(otherWitness->subToMid)
1755+
&& midToSup->EqualsVal(otherWitness->midToSup);
1756+
}
1757+
1758+
RefPtr<Val> TransitiveSubtypeWitness::SubstituteImpl(Substitutions* subst, int * ioDiff)
1759+
{
1760+
int diff = 0;
1761+
1762+
RefPtr<Type> substSub = sub->SubstituteImpl(subst, &diff).As<Type>();
1763+
RefPtr<Type> substSup = sup->SubstituteImpl(subst, &diff).As<Type>();
1764+
RefPtr<SubtypeWitness> substSubToMid = subToMid->SubstituteImpl(subst, &diff).As<SubtypeWitness>();
1765+
RefPtr<SubtypeWitness> substMidToSup = midToSup->SubstituteImpl(subst, &diff).As<SubtypeWitness>();
1766+
1767+
// If nothing changed, then we can bail out early.
1768+
if (!diff)
1769+
return this;
1770+
1771+
// Something changes, so let the caller know.
1772+
(*ioDiff)++;
1773+
1774+
// TODO: are there cases where we can simplify?
1775+
//
1776+
// In principle, if either `subToMid` or `midToSub` turns into
1777+
// a reflexive subtype witness, then we could drop that side,
1778+
// and just return the other one (this would imply that `sub == mid`
1779+
// or `mid == sup` after substitutions).
1780+
//
1781+
// In the long run, is it also possible that if `sub` gets resolved
1782+
// to a concrete type *and* we decide to flatten out the inheritance
1783+
// graph into a linearized "class precedence list" stored in any
1784+
// aggregate type, then we could potentially just redirect to point
1785+
// to the appropriate inheritance decl in the original type.
1786+
//
1787+
// For now I'm going to ignore those possibilities and hope for the best.
1788+
1789+
// In the simple case, we just construct a new transitive subtype
1790+
// witness, and we move on with life.
1791+
RefPtr<TransitiveSubtypeWitness> result = new TransitiveSubtypeWitness();
1792+
result->sub = substSub;
1793+
result->sup = substSup;
1794+
result->subToMid = substSubToMid;
1795+
result->midToSup = substMidToSup;
1796+
return result;
1797+
}
1798+
1799+
String TransitiveSubtypeWitness::ToString()
1800+
{
1801+
// Note: we only print the constituent
1802+
// witnesses, and rely on them to print
1803+
// the starting and ending types.
1804+
StringBuilder sb;
1805+
sb << "TransitiveSubtypeWitness(";
1806+
sb << this->subToMid->ToString();
1807+
sb << ", ";
1808+
sb << this->midToSup->ToString();
1809+
sb << ")";
1810+
return sb.ProduceString();
1811+
}
1812+
1813+
int TransitiveSubtypeWitness::GetHashCode()
1814+
{
1815+
auto hash = sub->GetHashCode();
1816+
hash = combineHash(hash, sup->GetHashCode());
1817+
hash = combineHash(hash, subToMid->GetHashCode());
1818+
hash = combineHash(hash, midToSup->GetHashCode());
1819+
return hash;
1820+
}
1821+
17441822
// IRProxyVal
17451823

17461824
bool IRProxyVal::EqualsVal(Val* val)

source/slang/val-defs.h

+15
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,21 @@ RAW(
9999
)
100100
END_SYNTAX_CLASS()
101101

102+
// A witness that `sub : sup` because `sub : mid` and `mid : sup`
103+
SYNTAX_CLASS(TransitiveSubtypeWitness, SubtypeWitness)
104+
// Witness that `sub : mid`
105+
FIELD(RefPtr<SubtypeWitness>, subToMid);
106+
107+
// Witness that `mid : sup`
108+
FIELD(RefPtr<SubtypeWitness>, midToSup);
109+
RAW(
110+
virtual bool EqualsVal(Val* val) override;
111+
virtual String ToString() override;
112+
virtual int GetHashCode() override;
113+
virtual RefPtr<Val> SubstituteImpl(Substitutions * subst, int * ioDiff) override;
114+
)
115+
END_SYNTAX_CLASS()
116+
102117
// A value that is used as a proxy when we need to
103118
// put an IR-level value into AST types
104119
SYNTAX_CLASS(IRProxyVal, Val)

0 commit comments

Comments
 (0)