@@ -884,7 +884,8 @@ namespace Slang
884
884
885
885
auto funcDeclRef = getDeclRef (m_astBuilder, funcDeclRefExpr);
886
886
auto intrinsicMod = funcDeclRef.getDecl ()->findModifier <IntrinsicOpModifier>();
887
- if (!intrinsicMod)
887
+ auto implicitCast = funcDeclRef.getDecl ()->findModifier <ImplicitConversionModifier>();
888
+ if (!intrinsicMod && !implicitCast)
888
889
{
889
890
// We can't constant fold anything that doesn't map to a builtin
890
891
// operation right now.
@@ -942,54 +943,91 @@ namespace Slang
942
943
943
944
// At this point, all the operands had simple integer values, so we are golden.
944
945
IntegerLiteralValue resultValue = 0 ;
945
- auto opName = funcDeclRef.getName ();
946
-
947
- // handle binary operators
948
- if (opName == getName (" -" ))
946
+ // If this is an implicit cast, we can try to fold.
947
+ if (implicitCast)
949
948
{
950
- if (argCount == 1 )
949
+ auto targetBasicType = as<BasicExpressionType>(invokeExpr.getExpr ()->type .type );
950
+ if (!targetBasicType)
951
+ return nullptr ;
952
+ switch (targetBasicType->baseType )
951
953
{
952
- resultValue = -constArgVals[0 ];
954
+ case BaseType::Bool:
955
+ resultValue = constArgVals[0 ] != 0 ;
956
+ break ;
957
+ case BaseType::Int:
958
+ case BaseType::UInt:
959
+ case BaseType::UInt16 :
960
+ case BaseType::Int16:
961
+ case BaseType::UInt8 :
962
+ case BaseType::Int8:
963
+ resultValue = constArgVals[0 ];
964
+ break ;
965
+ default :
966
+ return nullptr ;
953
967
}
954
- else if (argCount == 2 )
968
+ }
969
+ else
970
+ {
971
+ auto opName = funcDeclRef.getName ();
972
+
973
+ // handle binary operators
974
+ if (opName == getName (" -" ))
955
975
{
956
- resultValue = constArgVals[0 ] - constArgVals[1 ];
976
+ if (argCount == 1 )
977
+ {
978
+ resultValue = -constArgVals[0 ];
979
+ }
980
+ else if (argCount == 2 )
981
+ {
982
+ resultValue = constArgVals[0 ] - constArgVals[1 ];
983
+ }
984
+ }
985
+ else if (opName == getName (" !" ))
986
+ {
987
+ resultValue = constArgVals[0 ] != 0 ;
988
+ }
989
+ else if (opName == getName (" ~" ))
990
+ {
991
+ resultValue = ~constArgVals[0 ];
957
992
}
958
- }
959
993
960
- // simple binary operators
994
+ // simple binary operators
961
995
#define CASE (OP ) \
962
- else if (opName == getName (#OP)) do { \
963
- if (argCount != 2 ) return nullptr ; \
964
- resultValue = constArgVals[0 ] OP constArgVals[1 ]; \
965
- } while (0 )
966
-
967
- CASE (+); // TODO: this can also be unary...
968
- CASE (*);
969
- CASE (<<);
970
- CASE (>>);
971
- CASE (&);
972
- CASE (|);
973
- CASE (^);
996
+ else if (opName == getName (#OP)) do { \
997
+ if (argCount != 2 ) return nullptr ; \
998
+ resultValue = constArgVals[0 ] OP constArgVals[1 ]; \
999
+ } while (0 )
1000
+
1001
+ CASE (+); // TODO: this can also be unary...
1002
+ CASE (*);
1003
+ CASE (<<);
1004
+ CASE (>>);
1005
+ CASE (&);
1006
+ CASE (|);
1007
+ CASE (^);
1008
+ CASE (!=);
1009
+ CASE (==);
1010
+ CASE (>=);
1011
+ CASE (<=);
1012
+ CASE (<);
1013
+ CASE (>);
974
1014
#undef CASE
975
-
976
- // binary operators with chance of divide-by-zero
977
- // TODO: issue a suitable error in that case
1015
+ // binary operators with chance of divide-by-zero
1016
+ // TODO: issue a suitable error in that case
978
1017
#define CASE (OP ) \
979
- else if (opName == getName (#OP)) do { \
980
- if (argCount != 2 ) return nullptr ; \
981
- if (!constArgVals[1 ]) return nullptr ; \
982
- resultValue = constArgVals[0 ] OP constArgVals[1 ]; \
983
- } while (0 )
984
-
985
- CASE (/);
986
- CASE (%);
1018
+ else if (opName == getName (#OP)) do { \
1019
+ if (argCount != 2 ) return nullptr ; \
1020
+ if (!constArgVals[1 ]) return nullptr ; \
1021
+ resultValue = constArgVals[0 ] OP constArgVals[1 ]; \
1022
+ } while (0 )
1023
+ CASE (/);
1024
+ CASE (%);
987
1025
#undef CASE
988
-
989
- // TODO(tfoley): more cases
990
- else
991
- {
992
- return nullptr ;
1026
+ // TODO(tfoley): more cases
1027
+ else
1028
+ {
1029
+ return nullptr ;
1030
+ }
993
1031
}
994
1032
995
1033
IntVal* result = m_astBuilder->create <ConstantIntVal>(resultValue);
@@ -1126,13 +1164,27 @@ namespace Slang
1126
1164
return tryConstantFoldExpr (expr, circularityInfo);
1127
1165
}
1128
1166
1129
- IntVal* SemanticsVisitor::CheckIntegerConstantExpression (Expr* inExpr, DiagnosticSink* sink)
1167
+ IntVal* SemanticsVisitor::CheckIntegerConstantExpression (Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, DiagnosticSink* sink)
1130
1168
{
1131
1169
// No need to issue further errors if the expression didn't even type-check.
1132
1170
if (IsErrorExpr (inExpr)) return nullptr ;
1133
1171
1134
1172
// First coerce the expression to the expected type
1135
- auto expr = coerce (m_astBuilder->getIntType (),inExpr);
1173
+ Expr* expr = nullptr ;
1174
+ switch (coercionType)
1175
+ {
1176
+ case IntegerConstantExpressionCoercionType::SpecificType:
1177
+ expr = coerce (expectedType, inExpr);
1178
+ break ;
1179
+ case IntegerConstantExpressionCoercionType::AnyInteger:
1180
+ if (isScalarIntegerType (inExpr->type ))
1181
+ expr = inExpr;
1182
+ else
1183
+ expr = coerce (m_astBuilder->getIntType (), inExpr);
1184
+ break ;
1185
+ default :
1186
+ break ;
1187
+ }
1136
1188
1137
1189
// No need to issue further errors if the type coercion failed.
1138
1190
if (IsErrorExpr (expr)) return nullptr ;
@@ -1145,9 +1197,9 @@ namespace Slang
1145
1197
return result;
1146
1198
}
1147
1199
1148
- IntVal* SemanticsVisitor::CheckIntegerConstantExpression (Expr* inExpr)
1200
+ IntVal* SemanticsVisitor::CheckIntegerConstantExpression (Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType )
1149
1201
{
1150
- return CheckIntegerConstantExpression (inExpr, getSink ());
1202
+ return CheckIntegerConstantExpression (inExpr, coercionType, expectedType, getSink ());
1151
1203
}
1152
1204
1153
1205
IntVal* SemanticsVisitor::CheckEnumConstantExpression (Expr* expr)
@@ -1218,7 +1270,7 @@ namespace Slang
1218
1270
IntVal* elementCount = nullptr ;
1219
1271
if (indexExpr)
1220
1272
{
1221
- elementCount = CheckIntegerConstantExpression (indexExpr);
1273
+ elementCount = CheckIntegerConstantExpression (indexExpr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr );
1222
1274
}
1223
1275
1224
1276
auto elementType = CoerceToUsableType (TypeExp (baseExpr, baseTypeType->type ));
0 commit comments