diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index b7f50cd1bb..267f7b2d47 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1,4 +1,4 @@ -//public module core; +public module core; // Slang `core` library @@ -2367,49 +2367,6 @@ __generic __extension vector ${{{{ -// The above extensions are generic in the *type* of the vector, -// but explicit in the *size*. We will now declare an extension -// for each builtin type that is generic in the size. -// -for (int tt = 0; tt < kBaseTypeCount; ++tt) -{ - if(kBaseTypes[tt].tag == BaseType::Void) continue; - - sb << "__generic __extension vector<" - << kBaseTypes[tt].name << ",N>\n{\n"; - - for (int ff = 0; ff < kBaseTypeCount; ++ff) - { - if(kBaseTypes[ff].tag == BaseType::Void) continue; - - - if( tt != ff ) - { - auto cost = getBaseTypeConversionCost( - kBaseTypes[tt], - kBaseTypes[ff]); - auto op = getBaseTypeConversionOp( - kBaseTypes[tt], - kBaseTypes[ff]); - - // Implicit conversion from a vector of the same - // size, but different element type. - sb << " __implicit_conversion(" << cost << ")\n"; - sb << " __intrinsic_op(" << int(op) << ")\n"; - sb << " __init(vector<" << kBaseTypes[ff].name << ",N> value);\n"; - - // Constructor to make a vector from a scalar of another type. - if (cost != kConversionCost_Impossible) - { - cost += kConversionCost_ScalarToVector; - sb << " __implicit_conversion(" << cost << ")\n"; - sb << " [__unsafeForceInlineEarly]\n"; - sb << " __init(" << kBaseTypes[ff].name << " value) { this = vector<" << kBaseTypes[tt].name << ",N>( " << kBaseTypes[tt].name << "(value)); }\n"; - } - } - } - sb << "}\n"; -} for( int R = 1; R <= 4; ++R ) for( int C = 1; C <= 4; ++C ) @@ -2464,38 +2421,36 @@ for( int C = 1; C <= 4; ++C ) sb << "}\n"; } -for (int tt = 0; tt < kBaseTypeCount; ++tt) -{ - if(kBaseTypes[tt].tag == BaseType::Void) continue; - auto toType = kBaseTypes[tt].name; }}}} -__generic extension matrix<$(toType),R,C,L> +//@hidden: +__intrinsic_op($(kIROp_BuiltinCast)) +internal T __builtin_cast(U u); + +// If T is implicitly convertible to U, then vector is implicitly convertible to vector. +__generic extension vector { -${{{{ - for (int ff = 0; ff < kBaseTypeCount; ++ff) - { - if(kBaseTypes[ff].tag == BaseType::Void) continue; - if( tt == ff ) continue; + __implicit_conversion(constraint) + __intrinsic_op(BuiltinCast) + __init(vector value) where ToType(FromType) implicit; - auto cost = getBaseTypeConversionCost( - kBaseTypes[tt], - kBaseTypes[ff]); - auto fromType = kBaseTypes[ff].name; - auto op = getBaseTypeConversionOp( - kBaseTypes[tt], - kBaseTypes[ff]); -}}}} - __implicit_conversion($(cost)) - __intrinsic_op($(op)) - __init(matrix<$(fromType),R,C,L> value); -${{{{ + __implicit_conversion(constraint+) + [__unsafeForceInlineEarly] + [__readNone] + [TreatAsDifferentiable] + __init(FromType value) where ToType(FromType) implicit + { + this = __builtin_cast>(vector(value)); } -}}}} } -${{{{ + +// If T is implicitly convertible to U, then matrix is implicitly convertible to matrix. +__generic extension matrix +{ + __implicit_conversion(constraint) + __intrinsic_op(BuiltinCast) + __init(matrix value) where ToType(FromType) implicit; } -}}}} //@ hidden: __generic diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 6ffaee7db3..b3afa53100 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -948,6 +948,14 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( return witness; } +TypeCoercionWitness* ASTBuilder::getTypeCoercionWitness( + Type* subType, + Type* superType, + DeclRef declRef) +{ + return getOrCreate(subType, superType, declRef.declRefBase); +} + DeclRef _getMemberDeclRef(ASTBuilder* builder, DeclRef parent, Decl* decl) { return builder->getMemberDeclRef(parent, decl); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index cae380e40c..67dfaaf52e 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -636,6 +636,11 @@ class ASTBuilder : public RefObject SubtypeWitness* subIsLWitness, SubtypeWitness* subIsRWitness); + TypeCoercionWitness* getTypeCoercionWitness( + Type* fromType, + Type* toType, + DeclRef declRef); + /// Helpers to get type info from the SharedASTBuilder const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) { diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index ff8e5684a3..ff55340ac7 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -612,6 +612,15 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl const TypeExp& _getSupOverride() const { return sup; } }; +class TypeCoercionConstraintDecl : public Decl +{ + SLANG_AST_CLASS(TypeCoercionConstraintDecl) + + SourceLoc whereTokenLoc = SourceLoc(); + TypeExp fromType; + TypeExp toType; +}; + class GenericValueParamDecl : public VarDeclBase { SLANG_AST_CLASS(GenericValueParamDecl) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index cc49012361..e4d5ccd095 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1279,10 +1279,10 @@ class ImplicitConversionModifier : public Modifier SLANG_AST_CLASS(ImplicitConversionModifier) // The conversion cost, used to rank conversions - ConversionCost cost; + ConversionCost cost = kConversionCost_None; // A builtin identifier for identifying conversions that need special treatment. - BuiltinConversionKind builtinConversionKind; + BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown; }; class FormatAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index d240077216..b3baee98f4 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -178,6 +178,11 @@ enum : ConversionCost // Additional cost when casting an LValue. kConversionCost_LValueCast = 800, + // The cost of this conversion is defined by the type coercion constraint. + kConversionCost_TypeCoercionConstraint = 1000, + kConversionCost_TypeCoercionConstraintPlusScalarToVector = + kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector, + // Conversion is impossible kConversionCost_Impossible = 0xFFFFFFFF, }; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 9bcfd21bc3..7613dbe807 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -845,6 +845,58 @@ void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) out << ")"; } +void TypeCoercionWitness::_toTextOverride(StringBuilder& out) +{ + out << "TypeCoercionWitness("; + if (getFromType()) + out << getFromType(); + if (getToType()) + out << getToType(); + out << ")"; +} + +Val* TypeCoercionWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) +{ + int diff = 0; + + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substFrom = as(getFromType()->substituteImpl(astBuilder, subst, &diff)); + auto substTo = as(getToType()->substituteImpl(astBuilder, subst, &diff)); + + if (!diff) + return this; + + (*ioDiff)++; + + TypeCoercionWitness* substValue = + astBuilder->getTypeCoercionWitness(substFrom, substTo, substDeclRef); + return substValue; +} + +Val* TypeCoercionWitness::_resolveImplOverride() +{ + Val* resolvedDeclRef = nullptr; + if (getDeclRef()) + resolvedDeclRef = getDeclRef().declRefBase->resolve(); + if (auto resolvedVal = as(resolvedDeclRef)) + return resolvedVal; + + auto newFrom = as(getFromType()->resolve()); + auto newTo = as(getToType()->resolve()); + + auto newDeclRef = as(resolvedDeclRef); + if (!newDeclRef) + newDeclRef = getDeclRef().declRefBase; + if (newFrom != getFromType() || newTo != getToType() || newDeclRef != getDeclRef()) + { + return getCurrentASTBuilder()->getTypeCoercionWitness(newFrom, newTo, newDeclRef); + } + return this; +} + // UNormModifierVal void UNormModifierVal::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 7b33a81112..3a14be17b9 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -621,6 +621,20 @@ class TypeEqualityWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +class TypeCoercionWitness : public Witness +{ + SLANG_AST_CLASS(TypeCoercionWitness) + + Type* getFromType() { return as(getOperand(0)); } + Type* getToType() { return as(getOperand(1)); } + + DeclRef getDeclRef() { return as(getOperand(2)); } + + void _toTextOverride(StringBuilder& out); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); +}; + // A witness that one type is a subtype of another // because some in-scope declaration says so class DeclaredSubtypeWitness : public SubtypeWitness diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 872d2616c7..642a4bf6ae 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -715,13 +715,6 @@ DeclRef SemanticsVisitor::trySolveConstraintSystem( // system as being solved now, as a result of the witness we found. } - // Add a flat cost to all unconstrained generic params. - for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType()) - { - if (!constrainedGenericParams.contains(typeParamDecl)) - outBaseCost += kConversionCost_UnconstraintGenericParam; - } - // Make sure we haven't constructed any spurious constraints // that we aren't able to satisfy: for (auto c : system->constraints) @@ -732,6 +725,58 @@ DeclRef SemanticsVisitor::trySolveConstraintSystem( } } + // Verify that all type coercion constraints can be satisfied. + for (auto constraintDecl : + genericDeclRef.getDecl()->getMembersOfType()) + { + DeclRef constraintDeclRef = + m_astBuilder + ->getGenericAppDeclRef( + genericDeclRef, + args.getArrayView().arrayView, + constraintDecl) + .as(); + auto fromType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->fromType.Ptr()); + auto toType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->toType.Ptr()); + auto conversionCost = getConversionCost(toType, fromType); + if (constraintDecl->findModifier()) + { + if (conversionCost > kConversionCost_GeneralConversion) + { + // The type arguments are not implicitly convertible, return failure. + return DeclRef(); + } + } + else + { + if (conversionCost == kConversionCost_Impossible) + { + // The type arguments are not convertible, return failure. + return DeclRef(); + } + } + if (auto fromDecl = isDeclRefTypeOf(constraintDecl->fromType)) + { + constrainedGenericParams.add(fromDecl.getDecl()); + } + if (auto toDecl = isDeclRefTypeOf(constraintDecl->toType)) + { + constrainedGenericParams.add(toDecl.getDecl()); + } + // If we are to expand the support of type coercion constraint beyond simple builtin core + // module functions, then the witness should be a reference to the conversion function. For + // now, this isn't required, and it is not easy to get it from the coercion logic, so we + // leave it empty. + args.add(m_astBuilder->getTypeCoercionWitness(fromType, toType, DeclRef())); + } + + // Add a flat cost to all unconstrained generic params. + for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType()) + { + if (!constrainedGenericParams.contains(typeParamDecl)) + outBaseCost += kConversionCost_UnconstraintGenericParam; + } + return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView); } diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 6dda9c1eac..a9785a585e 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1045,11 +1045,28 @@ int getTypeBitSize(Type* t) } ConversionCost SemanticsVisitor::getImplicitConversionCostWithKnownArg( - Decl* decl, + DeclRef decl, Type* toType, Expr* arg) { - ConversionCost candidateCost = getImplicitConversionCost(decl); + ConversionCost candidateCost = getImplicitConversionCost(decl.getDecl()); + + if (candidateCost == kConversionCost_TypeCoercionConstraint || + candidateCost == kConversionCost_TypeCoercionConstraintPlusScalarToVector) + { + if (auto genApp = as(decl.declRefBase)) + { + for (auto genArg : genApp->getArgs()) + { + if (auto wit = as(genArg)) + { + candidateCost -= kConversionCost_TypeCoercionConstraint; + candidateCost += getConversionCost(wit->getToType(), wit->getFromType()); + break; + } + } + } + } // Fix up the cost if the operand is a const lit. if (isScalarIntegerType(toType)) @@ -1577,10 +1594,8 @@ bool SemanticsVisitor::_coerce( ImplicitCastMethod method; for (auto candidate : overloadContext.bestCandidates) { - ConversionCost candidateCost = getImplicitConversionCostWithKnownArg( - candidate.item.declRef.getDecl(), - toType, - fromExpr); + ConversionCost candidateCost = + getImplicitConversionCostWithKnownArg(candidate.item.declRef, toType, fromExpr); if (candidateCost < bestCost) { method.conversionFuncOverloadCandidate = candidate; @@ -1632,7 +1647,7 @@ bool SemanticsVisitor::_coerce( // cost associated with the initializer we are invoking. // ConversionCost cost = getImplicitConversionCostWithKnownArg( - overloadContext.bestCandidate->item.declRef.getDecl(), + overloadContext.bestCandidate->item.declRef, toType, fromExpr); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1ef5b1cec9..5b5e05b735 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -147,6 +147,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase, void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl); + void visitTypeCoercionConstraintDecl(TypeCoercionConstraintDecl* decl); + void validateGenericConstraintSubType(GenericTypeConstraintDecl* decl, TypeExp type); void visitGenericDecl(GenericDecl* genericDecl); @@ -2911,6 +2913,16 @@ void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType( } } +void SemanticsDeclHeaderVisitor::visitTypeCoercionConstraintDecl(TypeCoercionConstraintDecl* decl) +{ + CheckConstraintSubType(decl->toType); + + if (!decl->fromType.type) + decl->fromType = TranslateTypeNodeForced(decl->fromType); + if (!decl->toType.type) + decl->toType = TranslateTypeNodeForced(decl->toType); +} + void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) { // TODO: are there any other validations we can do at this point? diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 59290f8ad9..6438a91e37 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1511,7 +1511,10 @@ struct SemanticsVisitor : public SemanticsContext // perform implicit type conversion. ConversionCost getImplicitConversionCost(Decl* decl); - ConversionCost getImplicitConversionCostWithKnownArg(Decl* decl, Type* toType, Expr* arg); + ConversionCost getImplicitConversionCostWithKnownArg( + DeclRef decl, + Type* toType, + Expr* arg); BuiltinConversionKind getImplicitConversionBuiltinKind(Decl* decl); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index b944d2bf4f..b75f95f9a8 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1675,6 +1675,15 @@ int SemanticsVisitor::CompareOverloadCandidates(OverloadCandidate* left, Overloa if (itemDiff) return itemDiff; + // If one candidate is an implicit conversion, and other candidate is not, + // then we should prefer the implicit conversion. + int leftIsImplicitConversion = + left->item.declRef.getDecl()->findModifier() ? 1 : 0; + int rightIsImplicitConversion = + right->item.declRef.getDecl()->findModifier() ? 1 : 0; + if (leftIsImplicitConversion != rightIsImplicitConversion) + return rightIsImplicitConversion - leftIsImplicitConversion; + auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item); if (specificityDiff) return specificityDiff; diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index ff6d643196..620c65d4ef 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -116,6 +116,7 @@ bool opCanBeConstExpr(IROp op) case kIROp_PtrCast: case kIROp_Reinterpret: case kIROp_BitCast: + case kIROp_BuiltinCast: case kIROp_MakeTuple: case kIROp_MakeDifferentialPair: case kIROp_MakeExistential: @@ -178,7 +179,13 @@ bool opCanBeConstExprByBackwardPass(IRInst* value) { if (value->getOp() == kIROp_Param) return isLoopPhi(as(value)); - return opCanBeConstExpr(value->getOp()); + if (opCanBeConstExpr(value->getOp())) + return true; + if (auto callInst = as(value)) + { + return !callInst->mightHaveSideEffects(); + } + return false; } void markConstExpr(PropagateConstExprContext* context, IRInst* value) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 55880eab5d..5a1966d004 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1202,6 +1202,7 @@ INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, HOIST INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) +INST(BuiltinCast, BuiltinCast, 1, 0) INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) INST(Unmodified, unmodified, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index dbefa68c7e..d64820aa6e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4025,7 +4025,7 @@ struct IRBuilder /// the inst. IRInst* emitDefaultConstructRaw(IRType* type); - IRInst* emitCast(IRType* type, IRInst* value); + IRInst* emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinCast = true); IRInst* emitVectorReshape(IRType* type, IRInst* value); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index fc399954bc..e29fdf975d 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -98,6 +98,7 @@ struct PeepholeContext : InstPassBase else if (remainingKeys.getCount() > 0) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto newValue = builder.emitElementExtract(updateInst->getElementValue(), remainingKeys); @@ -112,6 +113,7 @@ struct PeepholeContext : InstPassBase // accessChain!=accessChain2, then we can replace the inst with extract(x, // accessChain2). IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView()); @@ -140,6 +142,8 @@ struct PeepholeContext : InstPassBase if (vectorType->getElementType() != replacement->getFullType()) return false; IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); replacement = builder.emitMakeVectorFromScalar(inst->getFullType(), replacement); @@ -175,6 +179,7 @@ struct PeepholeContext : InstPassBase else if (inst->getOperand(0) == inst->getOperand(1)) { IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); return tryReplace(builder.emitDefaultConstruct(inst->getDataType())); } @@ -280,6 +285,8 @@ struct PeepholeContext : InstPassBase break; IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); IRInst* resultVal = nullptr; if (inst->getOp() == kIROp_AlignOf) @@ -319,6 +326,8 @@ struct PeepholeContext : InstPassBase if (inst->getOperand(0)->getOp() == kIROp_MakeResultError) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + inst->replaceUsesWith(builder.getBoolValue(true)); maybeRemoveOldInst(inst); changed = true; @@ -326,6 +335,8 @@ struct PeepholeContext : InstPassBase else if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + inst->replaceUsesWith(builder.getBoolValue(false)); maybeRemoveOldInst(inst); changed = true; @@ -359,6 +370,8 @@ struct PeepholeContext : InstPassBase if (const auto packType = as(pack->getDataType())) { IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); List args; for (UInt j = 0; j < packType->getOperandCount(); ++j) @@ -443,6 +456,8 @@ struct PeepholeContext : InstPassBase index->getValue() < startIndex + vecSize->getValue()) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newElement = builder.emitElementExtract( element, @@ -517,6 +532,8 @@ struct PeepholeContext : InstPassBase if (args.getCount() == arraySize->getValue()) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeArray = builder.emitMakeArray( arrayType, @@ -573,6 +590,8 @@ struct PeepholeContext : InstPassBase if (isComplete) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeArray = builder.emitMakeArray( arrayType, @@ -618,6 +637,8 @@ struct PeepholeContext : InstPassBase if (isValid) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeStruct = builder.emitMakeStruct( structType, @@ -678,6 +699,8 @@ struct PeepholeContext : InstPassBase // Create a makeStruct inst using args. IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeStruct = builder.emitMakeStruct( structType, @@ -694,6 +717,8 @@ struct PeepholeContext : InstPassBase { auto ptr = inst->getOperand(0); IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto neq = builder.emitNeq(ptr, builder.getNullPtrValue(ptr->getDataType())); inst->replaceUsesWith(neq); @@ -708,6 +733,8 @@ struct PeepholeContext : InstPassBase if (isTypeEqual(actualType, (IRType*)isTypeInst->getTypeOperand())) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto trueVal = builder.getBoolValue(true); inst->replaceUsesWith(trueVal); @@ -770,6 +797,7 @@ struct PeepholeContext : InstPassBase if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto trueVal = builder.getBoolValue(true); inst->replaceUsesWith(trueVal); @@ -779,6 +807,8 @@ struct PeepholeContext : InstPassBase else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto falseVal = builder.getBoolValue(false); inst->replaceUsesWith(falseVal); @@ -841,6 +871,7 @@ struct PeepholeContext : InstPassBase case kIROp_DefaultConstruct: { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); // See if we can replace the default construct inst with concrete values. if (auto newCtor = builder.emitDefaultConstruct(inst->getFullType(), false)) @@ -851,6 +882,21 @@ struct PeepholeContext : InstPassBase } } break; + case kIROp_BuiltinCast: + { + IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); + // See if we can replace the default construct inst with concrete values. + if (auto newCast = + builder.emitCast(inst->getFullType(), inst->getOperand(0), false)) + { + inst->replaceUsesWith(newCast); + maybeRemoveOldInst(inst); + changed = true; + } + } + break; case kIROp_VectorReshape: { auto fromType = as(inst->getOperand(0)->getDataType()); @@ -867,6 +913,7 @@ struct PeepholeContext : InstPassBase break; } IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); UInt index = 0; auto newInst = builder.emitSwizzle(resultType, inst->getOperand(0), 1, &index); @@ -882,6 +929,8 @@ struct PeepholeContext : InstPassBase if (!toCount) break; IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newInst = builder.emitVectorReshape(resultType, inst->getOperand(0)); if (newInst != inst) @@ -911,6 +960,7 @@ struct PeepholeContext : InstPassBase break; List rows; IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto toRowType = builder.getVectorType( resultType->getElementType(), @@ -1035,6 +1085,8 @@ struct PeepholeContext : InstPassBase break; } IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newInst = builder.emitMakeVectorFromScalar(vectorType, inst->getOperand(0)); @@ -1075,6 +1127,8 @@ struct PeepholeContext : InstPassBase else { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newMakeVector = builder.emitMakeVector( swizzle->getDataType(), @@ -1100,6 +1154,8 @@ struct PeepholeContext : InstPassBase if (isConcreteType(left) && isConcreteType(right)) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); bool result = left == right; inst->replaceUsesWith(builder.getBoolValue(result)); @@ -1123,6 +1179,8 @@ struct PeepholeContext : InstPassBase if (!SLANG_SUCCEEDED(res)) break; IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto stride = builder.getIntValue(inst->getDataType(), sizeAlignment.getStride()); @@ -1148,6 +1206,8 @@ struct PeepholeContext : InstPassBase if (isConcreteType(type)) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); bool result = false; switch (inst->getOp()) @@ -1186,6 +1246,8 @@ struct PeepholeContext : InstPassBase if (as(inst)->getPtr()->getOp() == kIROp_undefined) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto undef = builder.emitUndefined(inst->getDataType()); inst->replaceUsesWith(undef); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index cdabb1ac2d..f28f61ffc2 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3995,7 +3995,7 @@ static TypeCastStyle _getTypeStyleId(IRType* type) } } -IRInst* IRBuilder::emitCast(IRType* type, IRInst* value) +IRInst* IRBuilder::emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinCast) { if (isTypeEqual(type, value->getDataType())) return value; @@ -4009,8 +4009,17 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value) SLANG_UNREACHABLE("cast from void type"); } - SLANG_RELEASE_ASSERT(toStyle != TypeCastStyle::Unknown); - SLANG_RELEASE_ASSERT(fromStyle != TypeCastStyle::Unknown); + if (toStyle == TypeCastStyle::Unknown || fromStyle == TypeCastStyle::Unknown) + { + if (fallbackToBuiltinCast) + { + return emitIntrinsicInst(type, kIROp_BuiltinCast, 1, &value); + } + else + { + return nullptr; + } + } struct OpSeq { @@ -4057,7 +4066,18 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value) auto t = type; if (op.op1 != kIROp_Nop) { - t = getUInt64Type(); + if (toStyle == TypeCastStyle::Bool) + t = getIntType(); + else + t = getUInt64Type(); + if (auto vecType = as(type)) + t = getVectorType(t, vecType->getElementCount()); + else if (auto matType = as(type)) + t = getMatrixType( + t, + matType->getRowCount(), + matType->getColumnCount(), + matType->getLayout()); } auto result = emitIntrinsicInst(t, op.op0, 1, &value); if (op.op1 != kIROp_Nop) @@ -8293,6 +8313,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialWitnessTable: case kIROp_WrapExistential: + case kIROp_BuiltinCast: case kIROp_BitCast: case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ed8a52b9e5..fbe6d8a848 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1737,6 +1737,13 @@ struct ValLoweringVisitor : ValVisitorirBuilder->getTypeEqualityWitness(witnessType, subType, supType)); } + LoweredValInfo visitTypeCoercionWitness(TypeCoercionWitness*) + { + // When we fully support type coercion constraints, we should lower the witness into a + // function that does the conversion. + return LoweredValInfo(); + } + LoweredValInfo visitTransitiveSubtypeWitness(TransitiveSubtypeWitness* val) { // The base (subToMid) will turn into a value with diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 82cb8caf3d..aec3b4e908 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1721,6 +1721,20 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP constraint->sup = parser->ParseTypeExp(); AddMember(genericParent, constraint); } + else if (AdvanceIf(parser, TokenType::LParent)) + { + auto constraint = parser->astBuilder->create(); + constraint->whereTokenLoc = whereToken.loc; + parser->FillPosition(constraint); + constraint->toType = subType; + constraint->fromType = parser->ParseTypeExp(); + parser->ReadToken(TokenType::RParent); + if (AdvanceIf(parser, "implicit")) + { + addModifier(constraint, parser->astBuilder->create()); + } + AddMember(genericParent, constraint); + } } } @@ -8910,8 +8924,19 @@ static NodeBase* parseImplicitConversionModifier(Parser* parser, void* /*userDat ConversionCost cost = kConversionCost_Default; if (AdvanceIf(parser, TokenType::LParent)) { - cost = - ConversionCost(stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); + if (AdvanceIf(parser, "constraint")) + { + cost = kConversionCost_TypeCoercionConstraint; + if (AdvanceIf(parser, TokenType::OpAdd)) + { + cost = kConversionCost_TypeCoercionConstraintPlusScalarToVector; + } + } + else + { + cost = ConversionCost( + stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); + } if (AdvanceIf(parser, TokenType::Comma)) { builtinKind = BuiltinConversionKind(