Skip to content

Commit c8e36bd

Browse files
authored
Hotfix/bool fix (shader-slang#907)
* * Handle ! for bool vector in glsl * Handle operators that have a boolean return value * || or && take bool * * Add comment in bool-op.slang test about doing || or && on vector types not supported for GLSL targets
1 parent 0441643 commit c8e36bd

File tree

6 files changed

+146
-24
lines changed

6 files changed

+146
-24
lines changed

source/slang/core.meta.slang

+7-4
Original file line numberDiff line numberDiff line change
@@ -1103,22 +1103,25 @@ for (auto op : unaryOps)
11031103
if ((type.flags & op.flags) == 0)
11041104
continue;
11051105

1106+
char const* resultType = type.name;
1107+
if (op.flags & BOOL_RESULT) resultType = "bool";
1108+
11061109
char const* fixity = (op.flags & POSTFIX) != 0 ? "__postfix " : "__prefix ";
11071110
char const* qual = (op.flags & ASSIGNMENT) != 0 ? "in out " : "";
11081111

11091112
// scalar version
11101113
sb << fixity;
1111-
sb << "__intrinsic_op(" << int(op.opCode) << ") " << type.name << " operator" << op.opName << "(" << qual << type.name << " value);\n";
1114+
sb << "__intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << qual << type.name << " value);\n";
11121115

11131116
// vector version
11141117
sb << "__generic<let N : int> ";
11151118
sb << fixity;
1116-
sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << type.name << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n";
1119+
sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n";
11171120

11181121
// matrix version
11191122
sb << "__generic<let N : int, let M : int> ";
11201123
sb << fixity;
1121-
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << type.name << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n";
1124+
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n";
11221125
}
11231126
}
11241127

@@ -1133,7 +1136,7 @@ for (auto op : binaryOps)
11331136
char const* rightType = leftType;
11341137
char const* resultType = leftType;
11351138

1136-
if (op.flags & COMPARISON) resultType = "bool";
1139+
if (op.flags & BOOL_RESULT) resultType = "bool";
11371140

11381141
char const* leftQual = "";
11391142
if(op.flags & ASSIGNMENT) leftQual = "in out ";

source/slang/core.meta.slang.h

+8-5
Original file line numberDiff line numberDiff line change
@@ -1121,22 +1121,25 @@ for (auto op : unaryOps)
11211121
if ((type.flags & op.flags) == 0)
11221122
continue;
11231123

1124+
char const* resultType = type.name;
1125+
if (op.flags & BOOL_RESULT) resultType = "bool";
1126+
11241127
char const* fixity = (op.flags & POSTFIX) != 0 ? "__postfix " : "__prefix ";
11251128
char const* qual = (op.flags & ASSIGNMENT) != 0 ? "in out " : "";
11261129

11271130
// scalar version
11281131
sb << fixity;
1129-
sb << "__intrinsic_op(" << int(op.opCode) << ") " << type.name << " operator" << op.opName << "(" << qual << type.name << " value);\n";
1132+
sb << "__intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << qual << type.name << " value);\n";
11301133

11311134
// vector version
11321135
sb << "__generic<let N : int> ";
11331136
sb << fixity;
1134-
sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << type.name << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n";
1137+
sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n";
11351138

11361139
// matrix version
11371140
sb << "__generic<let N : int, let M : int> ";
11381141
sb << fixity;
1139-
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << type.name << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n";
1142+
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n";
11401143
}
11411144
}
11421145

@@ -1151,7 +1154,7 @@ for (auto op : binaryOps)
11511154
char const* rightType = leftType;
11521155
char const* resultType = leftType;
11531156

1154-
if (op.flags & COMPARISON) resultType = "bool";
1157+
if (op.flags & BOOL_RESULT) resultType = "bool";
11551158

11561159
char const* leftQual = "";
11571160
if(op.flags & ASSIGNMENT) leftQual = "in out ";
@@ -1202,7 +1205,7 @@ for (auto op : binaryOps)
12021205
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftQual << "matrix<" << leftType << ",N,M> left, " << rightType << " right);\n";
12031206
}
12041207
}
1205-
SLANG_RAW("#line 1187 \"core.meta.slang\"")
1208+
SLANG_RAW("#line 1190 \"core.meta.slang\"")
12061209
SLANG_RAW("\n")
12071210
SLANG_RAW("\n")
12081211
SLANG_RAW("// Operators to apply to `enum` types\n")

source/slang/emit.cpp

+28-5
Original file line numberDiff line numberDiff line change
@@ -3621,6 +3621,33 @@ struct EmitVisitor
36213621
}
36223622
}
36233623

3624+
void emitNot(EmitContext* ctx, IRInst* inst, IREmitMode mode, EOpInfo& ioOuterPrec, bool* outNeedClose)
3625+
{
3626+
IRInst* operand = inst->getOperand(0);
3627+
3628+
if (getTarget(ctx) == CodeGenTarget::GLSL)
3629+
{
3630+
if (auto vectorType = as<IRVectorType>(operand->getDataType()))
3631+
{
3632+
// Handle as a function call
3633+
auto prec = kEOp_Postfix;
3634+
*outNeedClose = maybeEmitParens(ioOuterPrec, prec);
3635+
3636+
emit("not(");
3637+
emitIROperand(ctx, operand, mode, kEOp_General);
3638+
emit(")");
3639+
return;
3640+
}
3641+
}
3642+
3643+
auto prec = kEOp_Prefix;
3644+
*outNeedClose = maybeEmitParens(ioOuterPrec, prec);
3645+
3646+
emit("!");
3647+
emitIROperand(ctx, operand, mode, rightSide(prec, ioOuterPrec));
3648+
}
3649+
3650+
36243651
void emitComparison(EmitContext* ctx, IRInst* inst, IREmitMode mode, EOpInfo& ioOuterPrec, const EOpInfo& opPrec, bool* needCloseOut)
36253652
{
36263653
if (getTarget(ctx) == CodeGenTarget::GLSL)
@@ -3839,11 +3866,7 @@ struct EmitVisitor
38393866

38403867
case kIROp_Not:
38413868
{
3842-
auto prec = kEOp_Prefix;
3843-
needClose = maybeEmitParens(outerPrec, prec);
3844-
3845-
emit("!");
3846-
emitIROperand(ctx, inst->getOperand(0), mode, rightSide(prec, outerPrec));
3869+
emitNot(ctx, inst, mode, outerPrec, &needClose);
38473870
}
38483871
break;
38493872

source/slang/slang-stdlib.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace Slang
4646
{
4747
SINT_MASK = 1 << 0,
4848
FLOAT_MASK = 1 << 1,
49-
COMPARISON = 1 << 2,
49+
BOOL_RESULT = 1 << 2,
5050
BOOL_MASK = 1 << 3,
5151
UINT_MASK = 1 << 4,
5252
ASSIGNMENT = 1 << 5,
@@ -203,7 +203,7 @@ namespace Slang
203203
static const OpInfo unaryOps[] = {
204204
{ kIRPseudoOp_Pos, "+", ARITHMETIC_MASK },
205205
{ kIROp_Neg, "-", ARITHMETIC_MASK },
206-
{ kIROp_Not, "!", ANY_MASK },
206+
{ kIROp_Not, "!", BOOL_MASK | BOOL_RESULT },
207207
{ kIROp_BitNot, "~", INT_MASK },
208208
{ kIRPseudoOp_PreInc, "++", ARITHMETIC_MASK | ASSIGNMENT },
209209
{ kIRPseudoOp_PreDec, "--", ARITHMETIC_MASK | ASSIGNMENT },
@@ -217,19 +217,19 @@ namespace Slang
217217
{ kIROp_Mul, "*", ARITHMETIC_MASK },
218218
{ kIROp_Div, "/", ARITHMETIC_MASK },
219219
{ kIROp_Mod, "%", INT_MASK },
220-
{ kIROp_And, "&&", LOGICAL_MASK },
221-
{ kIROp_Or, "||", LOGICAL_MASK },
220+
{ kIROp_And, "&&", BOOL_MASK | BOOL_RESULT},
221+
{ kIROp_Or, "||", BOOL_MASK | BOOL_RESULT },
222222
{ kIROp_BitAnd, "&", LOGICAL_MASK },
223223
{ kIROp_BitOr, "|", LOGICAL_MASK },
224224
{ kIROp_BitXor, "^", LOGICAL_MASK },
225225
{ kIROp_Lsh, "<<", INT_MASK },
226226
{ kIROp_Rsh, ">>", INT_MASK },
227-
{ kIROp_Eql, "==", ANY_MASK | COMPARISON },
228-
{ kIROp_Neq, "!=", ANY_MASK | COMPARISON },
229-
{ kIROp_Greater, ">", ARITHMETIC_MASK | COMPARISON },
230-
{ kIROp_Less, "<", ARITHMETIC_MASK | COMPARISON },
231-
{ kIROp_Geq, ">=", ARITHMETIC_MASK | COMPARISON },
232-
{ kIROp_Leq, "<=", ARITHMETIC_MASK | COMPARISON },
227+
{ kIROp_Eql, "==", ANY_MASK | BOOL_RESULT },
228+
{ kIROp_Neq, "!=", ANY_MASK | BOOL_RESULT },
229+
{ kIROp_Greater, ">", ARITHMETIC_MASK | BOOL_RESULT },
230+
{ kIROp_Less, "<", ARITHMETIC_MASK | BOOL_RESULT },
231+
{ kIROp_Geq, ">=", ARITHMETIC_MASK | BOOL_RESULT },
232+
{ kIROp_Leq, "<=", ARITHMETIC_MASK | BOOL_RESULT },
233233
{ kIRPseudoOp_AddAssign, "+=", ASSIGNMENT | ARITHMETIC_MASK },
234234
{ kIRPseudoOp_SubAssign, "-=", ASSIGNMENT | ARITHMETIC_MASK },
235235
{ kIRPseudoOp_MulAssign, "*=", ASSIGNMENT | ARITHMETIC_MASK },

tests/bugs/bool-op.slang

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// enum.slang
2+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
3+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
4+
5+
// Confirm operations that produce bools - such as comparisons, or && ||, ! work correctly
6+
7+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
8+
RWStructuredBuffer<int> outputBuffer;
9+
10+
[numthreads(16, 1, 1)]
11+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
12+
{
13+
uint tid = dispatchThreadID.x;;
14+
15+
uint uv = tid;
16+
int iv = int(tid);
17+
18+
bool2 bv2 = { tid & 1, tid > 10 };
19+
20+
float2 f2 = { float(tid), float(tid + 1) };
21+
22+
bool bv = tid > 6;
23+
let not_bv2 = !bv2;
24+
25+
int r = 0;
26+
if (bv)
27+
{
28+
r |= 0x0001;
29+
}
30+
if (!bv)
31+
{
32+
r |= 0x0002;
33+
}
34+
35+
if (iv)
36+
{
37+
r|= 0x0004;
38+
}
39+
if (!iv)
40+
{
41+
r|= 0x0008;
42+
}
43+
44+
if (uv)
45+
{
46+
r |= 0x0010;
47+
}
48+
if (!uv)
49+
{
50+
r |= 0x0020;
51+
}
52+
53+
if (all(bv2))
54+
{
55+
r |= 0x0040;
56+
}
57+
if (all(not_bv2))
58+
{
59+
r |= 0x0080;
60+
}
61+
62+
if (any(!f2))
63+
{
64+
r |= 0x0100;
65+
}
66+
67+
// TODO(JS): Support on GLSL targets
68+
// This doesn't currently work on GLSL targets, and because there
69+
// do not appear to be any vector bitwise operations. Could be achieved
70+
// by deconstructing, and reconstructing the vec result
71+
/* if (all(f2 || bv2))
72+
{
73+
r |= 0x0200;
74+
} */
75+
76+
outputBuffer[tid] = r;
77+
}

tests/bugs/bool-op.slang.expected.txt

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
1AA
2+
16
3+
96
4+
16
5+
96
6+
16
7+
96
8+
15
9+
95
10+
15
11+
95
12+
55
13+
15
14+
55
15+
15
16+
55

0 commit comments

Comments
 (0)