Skip to content

Commit 4fb3b10

Browse files
authored
Improve generic type argument inference. (shader-slang#3370)
* Improve generic type argument inference. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 62426e9 commit 4fb3b10

11 files changed

+255
-45
lines changed

source/slang/slang-ast-support-types.h

+3
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ namespace Slang
8080
// No conversion at all
8181
kConversionCost_None = 0,
8282

83+
kConversionCost_GenericParamUpcast = 1,
84+
kConversionCost_UnconstraintGenericParam = 20,
85+
8386
// Convert between matrices of different layout
8487
kConversionCost_MatrixLayout = 5,
8588

source/slang/slang-ast-val.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,11 @@ Val* DeclaredSubtypeWitness::_resolveImplOverride()
286286
return this;
287287
}
288288

289+
ConversionCost DeclaredSubtypeWitness::_getOverloadResolutionCostOverride()
290+
{
291+
return kConversionCost_GenericParamUpcast;
292+
}
293+
289294
Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
290295
{
291296
if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>())
@@ -431,6 +436,11 @@ Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, S
431436
return astBuilder->getTransitiveSubtypeWitness(substSubToMid, substMidToSup);
432437
}
433438

439+
ConversionCost TransitiveSubtypeWitness::_getOverloadResolutionCostOverride()
440+
{
441+
return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost();
442+
}
443+
434444
void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out)
435445
{
436446
// Note: we only print the constituent
@@ -471,6 +481,17 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a
471481
substSub, substSup, substWitness, getIndexInConjunction());
472482
}
473483

484+
ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostOverride()
485+
{
486+
auto witness = as<ConjunctionSubtypeWitness>(getConjunctionWitness());
487+
if (!witness)
488+
return kConversionCost_None;
489+
auto index = getIndexInConjunction();
490+
if (index < witness->getComponentCount())
491+
return witness->getComponentWitness(index)->getOverloadResolutionCost();
492+
return kConversionCost_None;
493+
}
494+
474495
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
475496

476497
void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out)
@@ -541,6 +562,14 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder,
541562
return result;
542563
}
543564

565+
ConversionCost ConjunctionSubtypeWitness::_getOverloadResolutionCostOverride()
566+
{
567+
ConversionCost result = kConversionCost_None;
568+
for (Index i = 0; i < getComponentCount(); i++)
569+
result += getComponentWitness(i)->getOverloadResolutionCost();
570+
return result;
571+
}
572+
544573
void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
545574
{
546575
out << "ExtractFromConjunctionSubtypeWitness(";

source/slang/slang-ast-val.h

+11
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ class SubtypeWitness : public Witness
457457

458458
Type* getSub() { return as<Type>(getOperand(0)); }
459459
Type* getSup() { return as<Type>(getOperand(1)); }
460+
461+
ConversionCost _getOverloadResolutionCostOverride();
462+
ConversionCost getOverloadResolutionCost();
460463
};
461464

462465
class TypeEqualityWitness : public SubtypeWitness
@@ -493,6 +496,8 @@ class DeclaredSubtypeWitness : public SubtypeWitness
493496
{
494497
setOperands(inSub, inSup, inDeclRef);
495498
}
499+
500+
ConversionCost _getOverloadResolutionCostOverride();
496501
};
497502

498503
// A witness that `sub : sup` because `sub : mid` and `mid : sup`
@@ -520,6 +525,8 @@ class TransitiveSubtypeWitness : public SubtypeWitness
520525
{
521526
setOperands(subType, supType, inSubToMid, inMidToSup);
522527
}
528+
529+
ConversionCost _getOverloadResolutionCostOverride();
523530
};
524531

525532
// A witness that `sub : sup` because `sub` was wrapped into
@@ -580,6 +587,8 @@ class ConjunctionSubtypeWitness : public SubtypeWitness
580587

581588
void _toTextOverride(StringBuilder& out);
582589
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
590+
591+
ConversionCost _getOverloadResolutionCostOverride();
583592
};
584593

585594
/// A witness that `T <: L` or `T <: R` because `T <: L&R`
@@ -609,6 +618,8 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness
609618

610619
void _toTextOverride(StringBuilder& out);
611620
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
621+
622+
ConversionCost _getOverloadResolutionCostOverride();
612623
};
613624

614625
/// A value that represents a modifier attached to some other value

source/slang/slang-check-constraint.cpp

+57-8
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,11 @@ namespace Slang
261261
DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
262262
ConstraintSystem* system,
263263
DeclRef<GenericDecl> genericDeclRef,
264-
ArrayView<Val*> knownGenericArgs)
264+
ArrayView<Val*> knownGenericArgs,
265+
ConversionCost& outBaseCost)
265266
{
267+
outBaseCost = kConversionCost_None;
268+
266269
// For now the "solver" is going to be ridiculously simplistic.
267270

268271
// The generic itself will have some constraints, and for now we add these
@@ -340,6 +343,8 @@ namespace Slang
340343
}
341344

342345
QualType type;
346+
bool typeConstraintOptional = true;
347+
343348
for (auto& c : system->constraints)
344349
{
345350
if (c.decl != typeParam.getDecl())
@@ -348,11 +353,12 @@ namespace Slang
348353
auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
349354
SLANG_RELEASE_ASSERT(cType);
350355

351-
if (!type)
356+
if (!type || (typeConstraintOptional && !c.isOptional))
352357
{
353358
type = cType;
359+
typeConstraintOptional = c.isOptional;
354360
}
355-
else
361+
else if (!typeConstraintOptional)
356362
{
357363
auto joinType = TryJoinTypes(type, cType);
358364
if (!joinType)
@@ -397,6 +403,7 @@ namespace Slang
397403
// TODO(tfoley): figure out how this needs to interact with
398404
// compile-time integers that aren't just constants...
399405
IntVal* val = nullptr;
406+
bool valOptional = true;
400407
for (auto& c : system->constraints)
401408
{
402409
if (c.decl != valParam.getDecl())
@@ -405,13 +412,14 @@ namespace Slang
405412
auto cVal = as<IntVal>(c.val);
406413
SLANG_RELEASE_ASSERT(cVal);
407414

408-
if (!val)
415+
if (!val || (valOptional && !c.isOptional))
409416
{
410417
val = cVal;
418+
valOptional = c.isOptional;
411419
}
412420
else
413421
{
414-
if(!val->equals(cVal))
422+
if(!valOptional && !val->equals(cVal))
415423
{
416424
// failure!
417425
return DeclRef<Decl>();
@@ -450,6 +458,8 @@ namespace Slang
450458
// search for a conformance `Robin : ISidekick`, which involved
451459
// apply the substitutions we already know...
452460

461+
HashSet<Decl*> constrainedGenericParams;
462+
453463
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
454464
{
455465
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
@@ -458,6 +468,10 @@ namespace Slang
458468
// Extract the (substituted) sub- and super-type from the constraint.
459469
auto sub = getSub(m_astBuilder, constraintDeclRef);
460470
auto sup = getSup(m_astBuilder, constraintDeclRef);
471+
472+
// Mark sub type as constrained.
473+
if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type))
474+
constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl());
461475

462476
if (sub->equals(sup))
463477
{
@@ -475,6 +489,7 @@ namespace Slang
475489
{
476490
// We found a witness, so it will become an (implicit) argument.
477491
args.add(subTypeWitness);
492+
outBaseCost += subTypeWitness->getOverloadResolutionCost();
478493
}
479494
else
480495
{
@@ -489,6 +504,13 @@ namespace Slang
489504
// system as being solved now, as a result of the witness we found.
490505
}
491506

507+
// Add a flat cost to all unconstrained generic params.
508+
for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
509+
{
510+
if (!constrainedGenericParams.contains(typeParamDecl))
511+
outBaseCost += kConversionCost_UnconstraintGenericParam;
512+
}
513+
492514
// Make sure we haven't constructed any spurious constraints
493515
// that we aren't able to satisfy:
494516
for (auto c : system->constraints)
@@ -810,6 +832,29 @@ namespace Slang
810832
return false;
811833
}
812834

835+
void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal)
836+
{
837+
// If `param` is an unconstrained integer val param, and `arg` is a const int val,
838+
// we add a constraint to the system that `param` must be equal to `arg`.
839+
// If `param` is already constrained, ignore and do nothing.
840+
if (auto typeCastParam = as<TypeCastIntVal>(param))
841+
{
842+
param = as<IntVal>(typeCastParam->getBase());
843+
}
844+
auto intParam = as<GenericParamIntVal>(param);
845+
if (!intParam)
846+
return;
847+
for (auto c : constraints.constraints)
848+
if (c.decl == intParam->getDeclRef().getDecl())
849+
return;
850+
Constraint c;
851+
c.decl = intParam->getDeclRef().getDecl();
852+
c.isUsedAsLValue = paramIsLVal;
853+
c.val = arg;
854+
c.isOptional = true;
855+
constraints.constraints.add(c);
856+
}
857+
813858
bool SemanticsVisitor::TryUnifyTypes(
814859
ConstraintSystem& constraints,
815860
QualType fst,
@@ -880,6 +925,12 @@ namespace Slang
880925
{
881926
if(auto sndScalarType = as<BasicExpressionType>(snd))
882927
{
928+
// Try unify the vector count param. In case the vector count is defined by a generic value
929+
// parameter, we want to be able to infer that parameter should be 1.
930+
// However, we don't want a failed unification to fail the entire generic argument inference,
931+
// because a scalar can still be casted into a vector of any length.
932+
933+
maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue);
883934
return TryUnifyTypes(
884935
constraints,
885936
QualType(fstVectorType->getElementType(), fst.isLeftValue),
@@ -891,15 +942,13 @@ namespace Slang
891942
{
892943
if(auto sndVectorType = as<VectorExpressionType>(snd))
893944
{
945+
maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue);
894946
return TryUnifyTypes(
895947
constraints,
896948
QualType(fstScalarType, fst.isLeftValue),
897949
QualType(sndVectorType->getElementType(), snd.isLeftValue));
898950
}
899951
}
900-
901-
// TODO: the same thing for vectors...
902-
903952
return false;
904953
}
905954

source/slang/slang-check-decl.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -6630,7 +6630,9 @@ namespace Slang
66306630

66316631
if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type))
66326632
return DeclRef<ExtensionDecl>();
6633-
auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>());
6633+
6634+
ConversionCost baseCost;
6635+
auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>(), baseCost);
66346636
if (!solvedDeclRef)
66356637
{
66366638
return DeclRef<ExtensionDecl>();

0 commit comments

Comments
 (0)