@@ -261,8 +261,11 @@ namespace Slang
261
261
DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem (
262
262
ConstraintSystem* system,
263
263
DeclRef<GenericDecl> genericDeclRef,
264
- ArrayView<Val*> knownGenericArgs)
264
+ ArrayView<Val*> knownGenericArgs,
265
+ ConversionCost& outBaseCost)
265
266
{
267
+ outBaseCost = kConversionCost_None ;
268
+
266
269
// For now the "solver" is going to be ridiculously simplistic.
267
270
268
271
// The generic itself will have some constraints, and for now we add these
@@ -340,6 +343,8 @@ namespace Slang
340
343
}
341
344
342
345
QualType type;
346
+ bool typeConstraintOptional = true ;
347
+
343
348
for (auto & c : system ->constraints )
344
349
{
345
350
if (c.decl != typeParam.getDecl ())
@@ -348,11 +353,12 @@ namespace Slang
348
353
auto cType = QualType (as<Type>(c.val ), c.isUsedAsLValue );
349
354
SLANG_RELEASE_ASSERT (cType);
350
355
351
- if (!type)
356
+ if (!type || (typeConstraintOptional && !c. isOptional ) )
352
357
{
353
358
type = cType;
359
+ typeConstraintOptional = c.isOptional ;
354
360
}
355
- else
361
+ else if (!typeConstraintOptional)
356
362
{
357
363
auto joinType = TryJoinTypes (type, cType);
358
364
if (!joinType)
@@ -397,6 +403,7 @@ namespace Slang
397
403
// TODO(tfoley): figure out how this needs to interact with
398
404
// compile-time integers that aren't just constants...
399
405
IntVal* val = nullptr ;
406
+ bool valOptional = true ;
400
407
for (auto & c : system ->constraints )
401
408
{
402
409
if (c.decl != valParam.getDecl ())
@@ -405,13 +412,14 @@ namespace Slang
405
412
auto cVal = as<IntVal>(c.val );
406
413
SLANG_RELEASE_ASSERT (cVal);
407
414
408
- if (!val)
415
+ if (!val || (valOptional && !c. isOptional ) )
409
416
{
410
417
val = cVal;
418
+ valOptional = c.isOptional ;
411
419
}
412
420
else
413
421
{
414
- if (!val->equals (cVal))
422
+ if (!valOptional && ! val->equals (cVal))
415
423
{
416
424
// failure!
417
425
return DeclRef<Decl>();
@@ -450,6 +458,8 @@ namespace Slang
450
458
// search for a conformance `Robin : ISidekick`, which involved
451
459
// apply the substitutions we already know...
452
460
461
+ HashSet<Decl*> constrainedGenericParams;
462
+
453
463
for ( auto constraintDecl : genericDeclRef.getDecl ()->getMembersOfType <GenericTypeConstraintDecl>() )
454
464
{
455
465
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef (
@@ -458,6 +468,10 @@ namespace Slang
458
468
// Extract the (substituted) sub- and super-type from the constraint.
459
469
auto sub = getSub (m_astBuilder, constraintDeclRef);
460
470
auto sup = getSup (m_astBuilder, constraintDeclRef);
471
+
472
+ // Mark sub type as constrained.
473
+ if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl ()->sub .type ))
474
+ constrainedGenericParams.add (subDeclRefType->getDeclRef ().getDecl ());
461
475
462
476
if (sub->equals (sup))
463
477
{
@@ -475,6 +489,7 @@ namespace Slang
475
489
{
476
490
// We found a witness, so it will become an (implicit) argument.
477
491
args.add (subTypeWitness);
492
+ outBaseCost += subTypeWitness->getOverloadResolutionCost ();
478
493
}
479
494
else
480
495
{
@@ -489,6 +504,13 @@ namespace Slang
489
504
// system as being solved now, as a result of the witness we found.
490
505
}
491
506
507
+ // Add a flat cost to all unconstrained generic params.
508
+ for (auto typeParamDecl : genericDeclRef.getDecl ()->getMembersOfType <GenericTypeParamDecl>())
509
+ {
510
+ if (!constrainedGenericParams.contains (typeParamDecl))
511
+ outBaseCost += kConversionCost_UnconstraintGenericParam ;
512
+ }
513
+
492
514
// Make sure we haven't constructed any spurious constraints
493
515
// that we aren't able to satisfy:
494
516
for (auto c : system ->constraints )
@@ -810,6 +832,29 @@ namespace Slang
810
832
return false ;
811
833
}
812
834
835
+ void SemanticsVisitor::maybeUnifyUnconstraintIntParam (ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal)
836
+ {
837
+ // If `param` is an unconstrained integer val param, and `arg` is a const int val,
838
+ // we add a constraint to the system that `param` must be equal to `arg`.
839
+ // If `param` is already constrained, ignore and do nothing.
840
+ if (auto typeCastParam = as<TypeCastIntVal>(param))
841
+ {
842
+ param = as<IntVal>(typeCastParam->getBase ());
843
+ }
844
+ auto intParam = as<GenericParamIntVal>(param);
845
+ if (!intParam)
846
+ return ;
847
+ for (auto c : constraints.constraints )
848
+ if (c.decl == intParam->getDeclRef ().getDecl ())
849
+ return ;
850
+ Constraint c;
851
+ c.decl = intParam->getDeclRef ().getDecl ();
852
+ c.isUsedAsLValue = paramIsLVal;
853
+ c.val = arg;
854
+ c.isOptional = true ;
855
+ constraints.constraints .add (c);
856
+ }
857
+
813
858
bool SemanticsVisitor::TryUnifyTypes (
814
859
ConstraintSystem& constraints,
815
860
QualType fst,
@@ -880,6 +925,12 @@ namespace Slang
880
925
{
881
926
if (auto sndScalarType = as<BasicExpressionType>(snd))
882
927
{
928
+ // Try unify the vector count param. In case the vector count is defined by a generic value
929
+ // parameter, we want to be able to infer that parameter should be 1.
930
+ // However, we don't want a failed unification to fail the entire generic argument inference,
931
+ // because a scalar can still be casted into a vector of any length.
932
+
933
+ maybeUnifyUnconstraintIntParam (constraints, fstVectorType->getElementCount (), m_astBuilder->getIntVal (m_astBuilder->getIntType (), 1 ), fst.isLeftValue );
883
934
return TryUnifyTypes (
884
935
constraints,
885
936
QualType (fstVectorType->getElementType (), fst.isLeftValue ),
@@ -891,15 +942,13 @@ namespace Slang
891
942
{
892
943
if (auto sndVectorType = as<VectorExpressionType>(snd))
893
944
{
945
+ maybeUnifyUnconstraintIntParam (constraints, sndVectorType->getElementCount (), m_astBuilder->getIntVal (m_astBuilder->getIntType (), 1 ), snd.isLeftValue );
894
946
return TryUnifyTypes (
895
947
constraints,
896
948
QualType (fstScalarType, fst.isLeftValue ),
897
949
QualType (sndVectorType->getElementType (), snd.isLeftValue ));
898
950
}
899
951
}
900
-
901
- // TODO: the same thing for vectors...
902
-
903
952
return false ;
904
953
}
905
954
0 commit comments