Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify implicit cast ctors for vector & matrix. #6408

Merged
merged 6 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 24 additions & 69 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//public module core;
public module core;

// Slang `core` library

Expand Down Expand Up @@ -2367,49 +2367,6 @@ __generic<T> __extension vector<T, 4>

${{{{

// The above extensions are generic in the *type* of the vector,
// but explicit in the *size*. We will now declare an extension
// for each builtin type that is generic in the size.
//
for (int tt = 0; tt < kBaseTypeCount; ++tt)
{
if(kBaseTypes[tt].tag == BaseType::Void) continue;

sb << "__generic<let N : int> __extension vector<"
<< kBaseTypes[tt].name << ",N>\n{\n";

for (int ff = 0; ff < kBaseTypeCount; ++ff)
{
if(kBaseTypes[ff].tag == BaseType::Void) continue;


if( tt != ff )
{
auto cost = getBaseTypeConversionCost(
kBaseTypes[tt],
kBaseTypes[ff]);
auto op = getBaseTypeConversionOp(
kBaseTypes[tt],
kBaseTypes[ff]);

// Implicit conversion from a vector of the same
// size, but different element type.
sb << " __implicit_conversion(" << cost << ")\n";
sb << " __intrinsic_op(" << int(op) << ")\n";
sb << " __init(vector<" << kBaseTypes[ff].name << ",N> value);\n";

// Constructor to make a vector from a scalar of another type.
if (cost != kConversionCost_Impossible)
{
cost += kConversionCost_ScalarToVector;
sb << " __implicit_conversion(" << cost << ")\n";
sb << " [__unsafeForceInlineEarly]\n";
sb << " __init(" << kBaseTypes[ff].name << " value) { this = vector<" << kBaseTypes[tt].name << ",N>( " << kBaseTypes[tt].name << "(value)); }\n";
}
}
}
sb << "}\n";
}

for( int R = 1; R <= 4; ++R )
for( int C = 1; C <= 4; ++C )
Expand Down Expand Up @@ -2464,38 +2421,36 @@ for( int C = 1; C <= 4; ++C )
sb << "}\n";
}

for (int tt = 0; tt < kBaseTypeCount; ++tt)
{
if(kBaseTypes[tt].tag == BaseType::Void) continue;
auto toType = kBaseTypes[tt].name;
}}}}

__generic<let R : int, let C : int, let L : int> extension matrix<$(toType),R,C,L>
//@hidden:
__intrinsic_op($(kIROp_BuiltinCast))
internal T __builtin_cast<T, U>(U u);

// If T is implicitly convertible to U, then vector<T,N> is implicitly convertible to vector<U,N>.
__generic<ToType, let N : int> extension vector<ToType,N>
{
${{{{
for (int ff = 0; ff < kBaseTypeCount; ++ff)
{
if(kBaseTypes[ff].tag == BaseType::Void) continue;
if( tt == ff ) continue;
__implicit_conversion(constraint)
__intrinsic_op(BuiltinCast)
__init<FromType>(vector<FromType,N> value) where ToType(FromType) implicit;

auto cost = getBaseTypeConversionCost(
kBaseTypes[tt],
kBaseTypes[ff]);
auto fromType = kBaseTypes[ff].name;
auto op = getBaseTypeConversionOp(
kBaseTypes[tt],
kBaseTypes[ff]);
}}}}
__implicit_conversion($(cost))
__intrinsic_op($(op))
__init(matrix<$(fromType),R,C,L> value);
${{{{
__implicit_conversion(constraint+)
[__unsafeForceInlineEarly]
[__readNone]
[TreatAsDifferentiable]
__init<FromType>(FromType value) where ToType(FromType) implicit
{
this = __builtin_cast<vector<ToType,N>>(vector<FromType,N>(value));
}
}}}}
}
${{{{

// If T is implicitly convertible to U, then matrix<T,R,C,L> is implicitly convertible to matrix<U,R,C,L>.
__generic<ToType, let R : int, let C : int, let L : int> extension matrix<ToType,R,C,L>
{
__implicit_conversion(constraint)
__intrinsic_op(BuiltinCast)
__init<FromType>(matrix<FromType,R,C,L> value) where ToType(FromType) implicit;
}
}}}}

//@ hidden:
__generic<T, U>
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang-ast-builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,14 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
return witness;
}

TypeCoercionWitness* ASTBuilder::getTypeCoercionWitness(
Type* subType,
Type* superType,
DeclRef<Decl> declRef)
{
return getOrCreate<TypeCoercionWitness>(subType, superType, declRef.declRefBase);
}

DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl)
{
return builder->getMemberDeclRef(parent, decl);
Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-ast-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,11 @@ class ASTBuilder : public RefObject
SubtypeWitness* subIsLWitness,
SubtypeWitness* subIsRWitness);

TypeCoercionWitness* getTypeCoercionWitness(
Type* fromType,
Type* toType,
DeclRef<Decl> declRef);

/// Helpers to get type info from the SharedASTBuilder
const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice)
{
Expand Down
9 changes: 9 additions & 0 deletions source/slang/slang-ast-decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,15 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl
const TypeExp& _getSupOverride() const { return sup; }
};

class TypeCoercionConstraintDecl : public Decl
{
SLANG_AST_CLASS(TypeCoercionConstraintDecl)

SourceLoc whereTokenLoc = SourceLoc();
TypeExp fromType;
TypeExp toType;
};

class GenericValueParamDecl : public VarDeclBase
{
SLANG_AST_CLASS(GenericValueParamDecl)
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-ast-modifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -1279,10 +1279,10 @@ class ImplicitConversionModifier : public Modifier
SLANG_AST_CLASS(ImplicitConversionModifier)

// The conversion cost, used to rank conversions
ConversionCost cost;
ConversionCost cost = kConversionCost_None;

// A builtin identifier for identifying conversions that need special treatment.
BuiltinConversionKind builtinConversionKind;
BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown;
};

class FormatAttribute : public Attribute
Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-ast-support-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ enum : ConversionCost
// Additional cost when casting an LValue.
kConversionCost_LValueCast = 800,

// The cost of this conversion is defined by the type coercion constraint.
kConversionCost_TypeCoercionConstraint = 1000,
kConversionCost_TypeCoercionConstraintPlusScalarToVector =
kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector,

// Conversion is impossible
kConversionCost_Impossible = 0xFFFFFFFF,
};
Expand Down
52 changes: 52 additions & 0 deletions source/slang/slang-ast-val.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,58 @@ void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
out << ")";
}

void TypeCoercionWitness::_toTextOverride(StringBuilder& out)
{
out << "TypeCoercionWitness(";
if (getFromType())
out << getFromType();
if (getToType())
out << getToType();
out << ")";
}

Val* TypeCoercionWitness::_substituteImplOverride(
ASTBuilder* astBuilder,
SubstitutionSet subst,
int* ioDiff)
{
int diff = 0;

auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
auto substFrom = as<Type>(getFromType()->substituteImpl(astBuilder, subst, &diff));
auto substTo = as<Type>(getToType()->substituteImpl(astBuilder, subst, &diff));

if (!diff)
return this;

(*ioDiff)++;

TypeCoercionWitness* substValue =
astBuilder->getTypeCoercionWitness(substFrom, substTo, substDeclRef);
return substValue;
}

Val* TypeCoercionWitness::_resolveImplOverride()
{
Val* resolvedDeclRef = nullptr;
if (getDeclRef())
resolvedDeclRef = getDeclRef().declRefBase->resolve();
if (auto resolvedVal = as<Witness>(resolvedDeclRef))
return resolvedVal;

auto newFrom = as<Type>(getFromType()->resolve());
auto newTo = as<Type>(getToType()->resolve());

auto newDeclRef = as<DeclRefBase>(resolvedDeclRef);
if (!newDeclRef)
newDeclRef = getDeclRef().declRefBase;
if (newFrom != getFromType() || newTo != getToType() || newDeclRef != getDeclRef())
{
return getCurrentASTBuilder()->getTypeCoercionWitness(newFrom, newTo, newDeclRef);
}
return this;
}

// UNormModifierVal

void UNormModifierVal::_toTextOverride(StringBuilder& out)
Expand Down
14 changes: 14 additions & 0 deletions source/slang/slang-ast-val.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,20 @@ class TypeEqualityWitness : public SubtypeWitness
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};

class TypeCoercionWitness : public Witness
{
SLANG_AST_CLASS(TypeCoercionWitness)

Type* getFromType() { return as<Type>(getOperand(0)); }
Type* getToType() { return as<Type>(getOperand(1)); }

DeclRef<Decl> getDeclRef() { return as<DeclRefBase>(getOperand(2)); }

void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
};

// A witness that one type is a subtype of another
// because some in-scope declaration says so
class DeclaredSubtypeWitness : public SubtypeWitness
Expand Down
59 changes: 52 additions & 7 deletions source/slang/slang-check-constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,13 +715,6 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
// system as being solved now, as a result of the witness we found.
}

// Add a flat cost to all unconstrained generic params.
for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
{
if (!constrainedGenericParams.contains(typeParamDecl))
outBaseCost += kConversionCost_UnconstraintGenericParam;
}

// Make sure we haven't constructed any spurious constraints
// that we aren't able to satisfy:
for (auto c : system->constraints)
Expand All @@ -732,6 +725,58 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
}
}

// Verify that all type coercion constraints can be satisfied.
for (auto constraintDecl :
genericDeclRef.getDecl()->getMembersOfType<TypeCoercionConstraintDecl>())
{
DeclRef<TypeCoercionConstraintDecl> constraintDeclRef =
m_astBuilder
->getGenericAppDeclRef(
genericDeclRef,
args.getArrayView().arrayView,
constraintDecl)
.as<TypeCoercionConstraintDecl>();
auto fromType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->fromType.Ptr());
auto toType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->toType.Ptr());
auto conversionCost = getConversionCost(toType, fromType);
if (constraintDecl->findModifier<ImplicitConversionModifier>())
{
if (conversionCost > kConversionCost_GeneralConversion)
{
// The type arguments are not implicitly convertible, return failure.
return DeclRef<Decl>();
}
}
else
{
if (conversionCost == kConversionCost_Impossible)
{
// The type arguments are not convertible, return failure.
return DeclRef<Decl>();
}
}
if (auto fromDecl = isDeclRefTypeOf<Decl>(constraintDecl->fromType))
{
constrainedGenericParams.add(fromDecl.getDecl());
}
if (auto toDecl = isDeclRefTypeOf<Decl>(constraintDecl->toType))
{
constrainedGenericParams.add(toDecl.getDecl());
}
// If we are to expand the support of type coercion constraint beyond simple builtin core
// module functions, then the witness should be a reference to the conversion function. For
// now, this isn't required, and it is not easy to get it from the coercion logic, so we
// leave it empty.
args.add(m_astBuilder->getTypeCoercionWitness(fromType, toType, DeclRef<Decl>()));
}

// Add a flat cost to all unconstrained generic params.
for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
{
if (!constrainedGenericParams.contains(typeParamDecl))
outBaseCost += kConversionCost_UnconstraintGenericParam;
}

return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView);
}

Expand Down
29 changes: 22 additions & 7 deletions source/slang/slang-check-conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1045,11 +1045,28 @@ int getTypeBitSize(Type* t)
}

ConversionCost SemanticsVisitor::getImplicitConversionCostWithKnownArg(
Decl* decl,
DeclRef<Decl> decl,
Type* toType,
Expr* arg)
{
ConversionCost candidateCost = getImplicitConversionCost(decl);
ConversionCost candidateCost = getImplicitConversionCost(decl.getDecl());

if (candidateCost == kConversionCost_TypeCoercionConstraint ||
candidateCost == kConversionCost_TypeCoercionConstraintPlusScalarToVector)
{
if (auto genApp = as<GenericAppDeclRef>(decl.declRefBase))
{
for (auto genArg : genApp->getArgs())
{
if (auto wit = as<TypeCoercionWitness>(genArg))
{
candidateCost -= kConversionCost_TypeCoercionConstraint;
candidateCost += getConversionCost(wit->getToType(), wit->getFromType());
break;
}
}
}
}

// Fix up the cost if the operand is a const lit.
if (isScalarIntegerType(toType))
Expand Down Expand Up @@ -1577,10 +1594,8 @@ bool SemanticsVisitor::_coerce(
ImplicitCastMethod method;
for (auto candidate : overloadContext.bestCandidates)
{
ConversionCost candidateCost = getImplicitConversionCostWithKnownArg(
candidate.item.declRef.getDecl(),
toType,
fromExpr);
ConversionCost candidateCost =
getImplicitConversionCostWithKnownArg(candidate.item.declRef, toType, fromExpr);
if (candidateCost < bestCost)
{
method.conversionFuncOverloadCandidate = candidate;
Expand Down Expand Up @@ -1632,7 +1647,7 @@ bool SemanticsVisitor::_coerce(
// cost associated with the initializer we are invoking.
//
ConversionCost cost = getImplicitConversionCostWithKnownArg(
overloadContext.bestCandidate->item.declRef.getDecl(),
overloadContext.bestCandidate->item.declRef,
toType,
fromExpr);

Expand Down
Loading
Loading