Skip to content

Commit a01c09c

Browse files
authored
Literal folding on other operators (shader-slang#1314)
* Fold prefix operators if they prefix an int literal. * Make test case a bit more convoluted. * Remove ++ and -- as not appropriate for folding of literals. * Set output buffer name.
1 parent 78acd32 commit a01c09c

4 files changed

+93
-16
lines changed

source/slang/slang-parser.cpp

+48-13
Original file line numberDiff line numberDiff line change
@@ -4706,29 +4706,63 @@ namespace Slang
47064706
}
47074707
}
47084708

4709+
static IRIntegerValue _foldIntegerPrefixOp(TokenType tokenType, IRIntegerValue value)
4710+
{
4711+
switch (tokenType)
4712+
{
4713+
case TokenType::OpNot: return !value;
4714+
case TokenType::OpBitNot: return ~value;
4715+
case TokenType::OpAdd: return value;
4716+
case TokenType::OpSub: return -value;
4717+
default:
4718+
{
4719+
SLANG_ASSERT(!"Unexpected op");
4720+
return value;
4721+
}
4722+
}
4723+
}
4724+
4725+
static IRFloatingPointValue _foldFloatPrefixOp(TokenType tokenType, IRFloatingPointValue value)
4726+
{
4727+
switch (tokenType)
4728+
{
4729+
case TokenType::OpNot: return !value;
4730+
case TokenType::OpAdd: return value;
4731+
case TokenType::OpSub: return -value;
4732+
default:
4733+
{
4734+
SLANG_ASSERT(!"Unexpected op");
4735+
return value;
4736+
}
4737+
}
4738+
}
4739+
47094740
static RefPtr<Expr> parsePrefixExpr(Parser* parser)
47104741
{
4711-
switch( peekTokenType(parser) )
4742+
auto tokenType = peekTokenType(parser);
4743+
switch( tokenType )
47124744
{
47134745
default:
47144746
return parsePostfixExpr(parser);
47154747

4748+
47164749
case TokenType::OpInc:
47174750
case TokenType::OpDec:
4751+
{
4752+
RefPtr<PrefixExpr> prefixExpr = new PrefixExpr();
4753+
parser->FillPosition(prefixExpr.Ptr());
4754+
prefixExpr->FunctionExpr = parseOperator(parser);
4755+
4756+
auto arg = parsePrefixExpr(parser);
4757+
4758+
prefixExpr->Arguments.add(arg);
4759+
return prefixExpr;
4760+
}
47184761
case TokenType::OpNot:
47194762
case TokenType::OpBitNot:
47204763
case TokenType::OpAdd:
4721-
{
4722-
RefPtr<PrefixExpr> prefixExpr = new PrefixExpr();
4723-
parser->FillPosition(prefixExpr.Ptr());
4724-
prefixExpr->FunctionExpr = parseOperator(parser);
4725-
prefixExpr->Arguments.add(parsePrefixExpr(parser));
4726-
return prefixExpr;
4727-
}
47284764
case TokenType::OpSub:
47294765
{
4730-
// Special case prefix sub (aka neg), so if it's on a literal, it produces a new literal
4731-
47324766
RefPtr<PrefixExpr> prefixExpr = new PrefixExpr();
47334767
parser->FillPosition(prefixExpr.Ptr());
47344768
prefixExpr->FunctionExpr = parseOperator(parser);
@@ -4739,7 +4773,7 @@ namespace Slang
47394773
{
47404774
RefPtr<IntegerLiteralExpr> newLiteral = new IntegerLiteralExpr(*intLit);
47414775

4742-
IRIntegerValue value = -newLiteral->value;
4776+
IRIntegerValue value = _foldIntegerPrefixOp(tokenType, newLiteral->value);
47434777

47444778
// Need to get the basic type, so we can fit to underlying type
47454779
if (auto basicExprType = as<BasicExpressionType>(intLit->type.type))
@@ -4753,13 +4787,14 @@ namespace Slang
47534787
else if (auto floatLit = as<FloatingPointLiteralExpr>(arg))
47544788
{
47554789
RefPtr<FloatingPointLiteralExpr> newLiteral = new FloatingPointLiteralExpr(*floatLit);
4756-
newLiteral->value = -newLiteral->value;
4790+
newLiteral->value = _foldFloatPrefixOp(tokenType, floatLit->value);
47574791
return newLiteral;
47584792
}
4759-
4793+
47604794
prefixExpr->Arguments.add(arg);
47614795
return prefixExpr;
47624796
}
4797+
47634798
break;
47644799
}
47654800
}

tests/compute/static-const-array.slang

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// static-const-array.slang
22

33
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
4-
//TEST_DISABLED(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
4+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
5+
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -slang -compute
56

6-
7-
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out
7+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name outputBuffer
88
RWStructuredBuffer<int> outputBuffer;
99

1010
static const int kArray[] = { 16, 1, 32, 2 };
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// static-const-array.slang
2+
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -output-using-type
4+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -output-using-type
5+
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -slang -compute -output-using-type
6+
7+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out, name outputBuffer
8+
RWStructuredBuffer<float> outputBuffer;
9+
10+
static const float3 kArray[8] =
11+
{
12+
float3(-0.4706069, -0.4427112, - - + 0.6461146),
13+
float3(-0.9057375, +0.3003471, +0.9542373),
14+
float3(-0.3487388, +0.4037880, +0.5335386),
15+
float3(+0.1023042, +0.6439373, +0.6520134),
16+
float3(+0.5699277, +0.3513750, +0.6695386),
17+
float3(+0.2939128, -0.1131226, +0.3149309),
18+
float3(+0.7836658, -0.4208784, +0.8895339),
19+
float3(+0.1564120, -0.8198990, +0.8346850)
20+
};
21+
22+
float test(int val)
23+
{
24+
return kArray[val].x + kArray[val].y + kArray[val].z;
25+
}
26+
27+
[numthreads(8, 1, 1)]
28+
void computeMain(uint3 tid : SV_DispatchThreadID)
29+
{
30+
int inVal = tid.x;
31+
float outVal = test(inVal);
32+
outputBuffer[inVal] = outVal;
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
type: float
2+
-0.267204
3+
0.348847
4+
0.588588
5+
1.398255
6+
1.590841
7+
0.495721
8+
1.252321
9+
0.171198

0 commit comments

Comments
 (0)