Skip to content

Commit 19867ff

Browse files
authored
Simplify implicit cast ctors for vector & matrix. (#6408)
* Simplify implicit cast ctors for vector & matrix. * Fix formatting. * Fix tests. * Fix Falcor test. * Mark __builtin_cast as internal.
1 parent 9580e31 commit 19867ff

20 files changed

+349
-94
lines changed

source/slang/core.meta.slang

+24-69
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//public module core;
1+
public module core;
22

33
// Slang `core` library
44

@@ -2367,49 +2367,6 @@ __generic<T> __extension vector<T, 4>
23672367

23682368
${{{{
23692369

2370-
// The above extensions are generic in the *type* of the vector,
2371-
// but explicit in the *size*. We will now declare an extension
2372-
// for each builtin type that is generic in the size.
2373-
//
2374-
for (int tt = 0; tt < kBaseTypeCount; ++tt)
2375-
{
2376-
if(kBaseTypes[tt].tag == BaseType::Void) continue;
2377-
2378-
sb << "__generic<let N : int> __extension vector<"
2379-
<< kBaseTypes[tt].name << ",N>\n{\n";
2380-
2381-
for (int ff = 0; ff < kBaseTypeCount; ++ff)
2382-
{
2383-
if(kBaseTypes[ff].tag == BaseType::Void) continue;
2384-
2385-
2386-
if( tt != ff )
2387-
{
2388-
auto cost = getBaseTypeConversionCost(
2389-
kBaseTypes[tt],
2390-
kBaseTypes[ff]);
2391-
auto op = getBaseTypeConversionOp(
2392-
kBaseTypes[tt],
2393-
kBaseTypes[ff]);
2394-
2395-
// Implicit conversion from a vector of the same
2396-
// size, but different element type.
2397-
sb << " __implicit_conversion(" << cost << ")\n";
2398-
sb << " __intrinsic_op(" << int(op) << ")\n";
2399-
sb << " __init(vector<" << kBaseTypes[ff].name << ",N> value);\n";
2400-
2401-
// Constructor to make a vector from a scalar of another type.
2402-
if (cost != kConversionCost_Impossible)
2403-
{
2404-
cost += kConversionCost_ScalarToVector;
2405-
sb << " __implicit_conversion(" << cost << ")\n";
2406-
sb << " [__unsafeForceInlineEarly]\n";
2407-
sb << " __init(" << kBaseTypes[ff].name << " value) { this = vector<" << kBaseTypes[tt].name << ",N>( " << kBaseTypes[tt].name << "(value)); }\n";
2408-
}
2409-
}
2410-
}
2411-
sb << "}\n";
2412-
}
24132370

24142371
for( int R = 1; R <= 4; ++R )
24152372
for( int C = 1; C <= 4; ++C )
@@ -2464,38 +2421,36 @@ for( int C = 1; C <= 4; ++C )
24642421
sb << "}\n";
24652422
}
24662423

2467-
for (int tt = 0; tt < kBaseTypeCount; ++tt)
2468-
{
2469-
if(kBaseTypes[tt].tag == BaseType::Void) continue;
2470-
auto toType = kBaseTypes[tt].name;
24712424
}}}}
24722425

2473-
__generic<let R : int, let C : int, let L : int> extension matrix<$(toType),R,C,L>
2426+
//@hidden:
2427+
__intrinsic_op($(kIROp_BuiltinCast))
2428+
internal T __builtin_cast<T, U>(U u);
2429+
2430+
// If T is implicitly convertible to U, then vector<T,N> is implicitly convertible to vector<U,N>.
2431+
__generic<ToType, let N : int> extension vector<ToType,N>
24742432
{
2475-
${{{{
2476-
for (int ff = 0; ff < kBaseTypeCount; ++ff)
2477-
{
2478-
if(kBaseTypes[ff].tag == BaseType::Void) continue;
2479-
if( tt == ff ) continue;
2433+
__implicit_conversion(constraint)
2434+
__intrinsic_op(BuiltinCast)
2435+
__init<FromType>(vector<FromType,N> value) where ToType(FromType) implicit;
24802436

2481-
auto cost = getBaseTypeConversionCost(
2482-
kBaseTypes[tt],
2483-
kBaseTypes[ff]);
2484-
auto fromType = kBaseTypes[ff].name;
2485-
auto op = getBaseTypeConversionOp(
2486-
kBaseTypes[tt],
2487-
kBaseTypes[ff]);
2488-
}}}}
2489-
__implicit_conversion($(cost))
2490-
__intrinsic_op($(op))
2491-
__init(matrix<$(fromType),R,C,L> value);
2492-
${{{{
2437+
__implicit_conversion(constraint+)
2438+
[__unsafeForceInlineEarly]
2439+
[__readNone]
2440+
[TreatAsDifferentiable]
2441+
__init<FromType>(FromType value) where ToType(FromType) implicit
2442+
{
2443+
this = __builtin_cast<vector<ToType,N>>(vector<FromType,N>(value));
24932444
}
2494-
}}}}
24952445
}
2496-
${{{{
2446+
2447+
// If T is implicitly convertible to U, then matrix<T,R,C,L> is implicitly convertible to matrix<U,R,C,L>.
2448+
__generic<ToType, let R : int, let C : int, let L : int> extension matrix<ToType,R,C,L>
2449+
{
2450+
__implicit_conversion(constraint)
2451+
__intrinsic_op(BuiltinCast)
2452+
__init<FromType>(matrix<FromType,R,C,L> value) where ToType(FromType) implicit;
24972453
}
2498-
}}}}
24992454

25002455
//@ hidden:
25012456
__generic<T, U>

source/slang/slang-ast-builder.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,14 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
948948
return witness;
949949
}
950950

951+
TypeCoercionWitness* ASTBuilder::getTypeCoercionWitness(
952+
Type* subType,
953+
Type* superType,
954+
DeclRef<Decl> declRef)
955+
{
956+
return getOrCreate<TypeCoercionWitness>(subType, superType, declRef.declRefBase);
957+
}
958+
951959
DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl)
952960
{
953961
return builder->getMemberDeclRef(parent, decl);

source/slang/slang-ast-builder.h

+5
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,11 @@ class ASTBuilder : public RefObject
636636
SubtypeWitness* subIsLWitness,
637637
SubtypeWitness* subIsRWitness);
638638

639+
TypeCoercionWitness* getTypeCoercionWitness(
640+
Type* fromType,
641+
Type* toType,
642+
DeclRef<Decl> declRef);
643+
639644
/// Helpers to get type info from the SharedASTBuilder
640645
const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice)
641646
{

source/slang/slang-ast-decl.h

+9
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,15 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl
612612
const TypeExp& _getSupOverride() const { return sup; }
613613
};
614614

615+
class TypeCoercionConstraintDecl : public Decl
616+
{
617+
SLANG_AST_CLASS(TypeCoercionConstraintDecl)
618+
619+
SourceLoc whereTokenLoc = SourceLoc();
620+
TypeExp fromType;
621+
TypeExp toType;
622+
};
623+
615624
class GenericValueParamDecl : public VarDeclBase
616625
{
617626
SLANG_AST_CLASS(GenericValueParamDecl)

source/slang/slang-ast-modifier.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1279,10 +1279,10 @@ class ImplicitConversionModifier : public Modifier
12791279
SLANG_AST_CLASS(ImplicitConversionModifier)
12801280

12811281
// The conversion cost, used to rank conversions
1282-
ConversionCost cost;
1282+
ConversionCost cost = kConversionCost_None;
12831283

12841284
// A builtin identifier for identifying conversions that need special treatment.
1285-
BuiltinConversionKind builtinConversionKind;
1285+
BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown;
12861286
};
12871287

12881288
class FormatAttribute : public Attribute

source/slang/slang-ast-support-types.h

+5
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ enum : ConversionCost
178178
// Additional cost when casting an LValue.
179179
kConversionCost_LValueCast = 800,
180180

181+
// The cost of this conversion is defined by the type coercion constraint.
182+
kConversionCost_TypeCoercionConstraint = 1000,
183+
kConversionCost_TypeCoercionConstraintPlusScalarToVector =
184+
kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector,
185+
181186
// Conversion is impossible
182187
kConversionCost_Impossible = 0xFFFFFFFF,
183188
};

source/slang/slang-ast-val.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,58 @@ void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
845845
out << ")";
846846
}
847847

848+
void TypeCoercionWitness::_toTextOverride(StringBuilder& out)
849+
{
850+
out << "TypeCoercionWitness(";
851+
if (getFromType())
852+
out << getFromType();
853+
if (getToType())
854+
out << getToType();
855+
out << ")";
856+
}
857+
858+
Val* TypeCoercionWitness::_substituteImplOverride(
859+
ASTBuilder* astBuilder,
860+
SubstitutionSet subst,
861+
int* ioDiff)
862+
{
863+
int diff = 0;
864+
865+
auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
866+
auto substFrom = as<Type>(getFromType()->substituteImpl(astBuilder, subst, &diff));
867+
auto substTo = as<Type>(getToType()->substituteImpl(astBuilder, subst, &diff));
868+
869+
if (!diff)
870+
return this;
871+
872+
(*ioDiff)++;
873+
874+
TypeCoercionWitness* substValue =
875+
astBuilder->getTypeCoercionWitness(substFrom, substTo, substDeclRef);
876+
return substValue;
877+
}
878+
879+
Val* TypeCoercionWitness::_resolveImplOverride()
880+
{
881+
Val* resolvedDeclRef = nullptr;
882+
if (getDeclRef())
883+
resolvedDeclRef = getDeclRef().declRefBase->resolve();
884+
if (auto resolvedVal = as<Witness>(resolvedDeclRef))
885+
return resolvedVal;
886+
887+
auto newFrom = as<Type>(getFromType()->resolve());
888+
auto newTo = as<Type>(getToType()->resolve());
889+
890+
auto newDeclRef = as<DeclRefBase>(resolvedDeclRef);
891+
if (!newDeclRef)
892+
newDeclRef = getDeclRef().declRefBase;
893+
if (newFrom != getFromType() || newTo != getToType() || newDeclRef != getDeclRef())
894+
{
895+
return getCurrentASTBuilder()->getTypeCoercionWitness(newFrom, newTo, newDeclRef);
896+
}
897+
return this;
898+
}
899+
848900
// UNormModifierVal
849901

850902
void UNormModifierVal::_toTextOverride(StringBuilder& out)

source/slang/slang-ast-val.h

+14
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,20 @@ class TypeEqualityWitness : public SubtypeWitness
621621
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
622622
};
623623

624+
class TypeCoercionWitness : public Witness
625+
{
626+
SLANG_AST_CLASS(TypeCoercionWitness)
627+
628+
Type* getFromType() { return as<Type>(getOperand(0)); }
629+
Type* getToType() { return as<Type>(getOperand(1)); }
630+
631+
DeclRef<Decl> getDeclRef() { return as<DeclRefBase>(getOperand(2)); }
632+
633+
void _toTextOverride(StringBuilder& out);
634+
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
635+
Val* _resolveImplOverride();
636+
};
637+
624638
// A witness that one type is a subtype of another
625639
// because some in-scope declaration says so
626640
class DeclaredSubtypeWitness : public SubtypeWitness

source/slang/slang-check-constraint.cpp

+52-7
Original file line numberDiff line numberDiff line change
@@ -715,13 +715,6 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
715715
// system as being solved now, as a result of the witness we found.
716716
}
717717

718-
// Add a flat cost to all unconstrained generic params.
719-
for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
720-
{
721-
if (!constrainedGenericParams.contains(typeParamDecl))
722-
outBaseCost += kConversionCost_UnconstraintGenericParam;
723-
}
724-
725718
// Make sure we haven't constructed any spurious constraints
726719
// that we aren't able to satisfy:
727720
for (auto c : system->constraints)
@@ -732,6 +725,58 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
732725
}
733726
}
734727

728+
// Verify that all type coercion constraints can be satisfied.
729+
for (auto constraintDecl :
730+
genericDeclRef.getDecl()->getMembersOfType<TypeCoercionConstraintDecl>())
731+
{
732+
DeclRef<TypeCoercionConstraintDecl> constraintDeclRef =
733+
m_astBuilder
734+
->getGenericAppDeclRef(
735+
genericDeclRef,
736+
args.getArrayView().arrayView,
737+
constraintDecl)
738+
.as<TypeCoercionConstraintDecl>();
739+
auto fromType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->fromType.Ptr());
740+
auto toType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->toType.Ptr());
741+
auto conversionCost = getConversionCost(toType, fromType);
742+
if (constraintDecl->findModifier<ImplicitConversionModifier>())
743+
{
744+
if (conversionCost > kConversionCost_GeneralConversion)
745+
{
746+
// The type arguments are not implicitly convertible, return failure.
747+
return DeclRef<Decl>();
748+
}
749+
}
750+
else
751+
{
752+
if (conversionCost == kConversionCost_Impossible)
753+
{
754+
// The type arguments are not convertible, return failure.
755+
return DeclRef<Decl>();
756+
}
757+
}
758+
if (auto fromDecl = isDeclRefTypeOf<Decl>(constraintDecl->fromType))
759+
{
760+
constrainedGenericParams.add(fromDecl.getDecl());
761+
}
762+
if (auto toDecl = isDeclRefTypeOf<Decl>(constraintDecl->toType))
763+
{
764+
constrainedGenericParams.add(toDecl.getDecl());
765+
}
766+
// If we are to expand the support of type coercion constraint beyond simple builtin core
767+
// module functions, then the witness should be a reference to the conversion function. For
768+
// now, this isn't required, and it is not easy to get it from the coercion logic, so we
769+
// leave it empty.
770+
args.add(m_astBuilder->getTypeCoercionWitness(fromType, toType, DeclRef<Decl>()));
771+
}
772+
773+
// Add a flat cost to all unconstrained generic params.
774+
for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
775+
{
776+
if (!constrainedGenericParams.contains(typeParamDecl))
777+
outBaseCost += kConversionCost_UnconstraintGenericParam;
778+
}
779+
735780
return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView);
736781
}
737782

source/slang/slang-check-conversion.cpp

+22-7
Original file line numberDiff line numberDiff line change
@@ -1045,11 +1045,28 @@ int getTypeBitSize(Type* t)
10451045
}
10461046

10471047
ConversionCost SemanticsVisitor::getImplicitConversionCostWithKnownArg(
1048-
Decl* decl,
1048+
DeclRef<Decl> decl,
10491049
Type* toType,
10501050
Expr* arg)
10511051
{
1052-
ConversionCost candidateCost = getImplicitConversionCost(decl);
1052+
ConversionCost candidateCost = getImplicitConversionCost(decl.getDecl());
1053+
1054+
if (candidateCost == kConversionCost_TypeCoercionConstraint ||
1055+
candidateCost == kConversionCost_TypeCoercionConstraintPlusScalarToVector)
1056+
{
1057+
if (auto genApp = as<GenericAppDeclRef>(decl.declRefBase))
1058+
{
1059+
for (auto genArg : genApp->getArgs())
1060+
{
1061+
if (auto wit = as<TypeCoercionWitness>(genArg))
1062+
{
1063+
candidateCost -= kConversionCost_TypeCoercionConstraint;
1064+
candidateCost += getConversionCost(wit->getToType(), wit->getFromType());
1065+
break;
1066+
}
1067+
}
1068+
}
1069+
}
10531070

10541071
// Fix up the cost if the operand is a const lit.
10551072
if (isScalarIntegerType(toType))
@@ -1577,10 +1594,8 @@ bool SemanticsVisitor::_coerce(
15771594
ImplicitCastMethod method;
15781595
for (auto candidate : overloadContext.bestCandidates)
15791596
{
1580-
ConversionCost candidateCost = getImplicitConversionCostWithKnownArg(
1581-
candidate.item.declRef.getDecl(),
1582-
toType,
1583-
fromExpr);
1597+
ConversionCost candidateCost =
1598+
getImplicitConversionCostWithKnownArg(candidate.item.declRef, toType, fromExpr);
15841599
if (candidateCost < bestCost)
15851600
{
15861601
method.conversionFuncOverloadCandidate = candidate;
@@ -1632,7 +1647,7 @@ bool SemanticsVisitor::_coerce(
16321647
// cost associated with the initializer we are invoking.
16331648
//
16341649
ConversionCost cost = getImplicitConversionCostWithKnownArg(
1635-
overloadContext.bestCandidate->item.declRef.getDecl(),
1650+
overloadContext.bestCandidate->item.declRef,
16361651
toType,
16371652
fromExpr);
16381653

0 commit comments

Comments
 (0)