Skip to content

Commit 0441643

Browse files
authored
Add support for vector/scalar compares for GLSL (shader-slang#903)
* * leftSide and rightSide set op to nullptr, before was just uninitialized * Added support for GLSL for vector/scalar comparisons * Added test * * Remove unneeded precedence code. * Simplify function to _maybeEmitGLSLCast * * Take into account precedence & closing of brackets in same way as function call, if function call used for vector comparison (as on GLSL)
1 parent 3b33c1b commit 0441643

File tree

3 files changed

+96
-18
lines changed

3 files changed

+96
-18
lines changed

source/slang/emit.cpp

+55-18
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,7 @@ struct EmitVisitor
14691469
EOpInfo leftSide(EOpInfo const& outerPrec, EOpInfo const& prec)
14701470
{
14711471
EOpInfo result;
1472+
result.op = nullptr;
14721473
result.leftPrecedence = outerPrec.leftPrecedence;
14731474
result.rightPrecedence = prec.leftPrecedence;
14741475
return result;
@@ -1477,6 +1478,7 @@ struct EmitVisitor
14771478
EOpInfo rightSide(EOpInfo const& prec, EOpInfo const& outerPrec)
14781479
{
14791480
EOpInfo result;
1481+
result.op = nullptr;
14801482
result.leftPrecedence = prec.rightPrecedence;
14811483
result.rightPrecedence = outerPrec.rightPrecedence;
14821484
return result;
@@ -3599,34 +3601,69 @@ struct EmitVisitor
35993601
}
36003602
}
36013603

3602-
void emitComparison(EmitContext* ctx, IRInst* inst, IREmitMode mode, EOpInfo& inOutOuterPrec, const EOpInfo& opPrec, bool* needCloseOut)
3604+
void _maybeEmitGLSLCast(EmitContext* ctx, IRType* castType, IRInst* inst, IREmitMode mode)
36033605
{
3604-
*needCloseOut = maybeEmitParens(inOutOuterPrec, opPrec);
3605-
3606-
if (getTarget(ctx) == CodeGenTarget::GLSL
3607-
&& as<IRVectorType>(inst->getOperand(0)->getDataType())
3608-
&& as<IRVectorType>(inst->getOperand(1)->getDataType()))
3606+
// Wrap in cast if a cast type is specified
3607+
if (castType)
36093608
{
3610-
const char* funcName = getGLSLVectorCompareFunctionName(inst->op);
3611-
SLANG_ASSERT(funcName);
3612-
3613-
emit(funcName);
3609+
emitIRType(ctx, castType);
36143610
emit("(");
3615-
emitIROperand(ctx, inst->getOperand(0), mode, leftSide(inOutOuterPrec, opPrec));
3616-
emit(",");
3617-
emitIROperand(ctx, inst->getOperand(1), mode, rightSide(inOutOuterPrec, opPrec));
3611+
3612+
// Emit the operand
3613+
emitIROperand(ctx, inst, mode, kEOp_General);
3614+
36183615
emit(")");
36193616
}
36203617
else
36213618
{
3622-
emitIROperand(ctx, inst->getOperand(0), mode, leftSide(inOutOuterPrec, opPrec));
3623-
emit(" ");
3624-
emit(opPrec.op);
3625-
emit(" ");
3626-
emitIROperand(ctx, inst->getOperand(1), mode, rightSide(inOutOuterPrec, opPrec));
3619+
// Emit the operand
3620+
emitIROperand(ctx, inst, mode, kEOp_General);
36273621
}
36283622
}
36293623

3624+
void emitComparison(EmitContext* ctx, IRInst* inst, IREmitMode mode, EOpInfo& ioOuterPrec, const EOpInfo& opPrec, bool* needCloseOut)
3625+
{
3626+
if (getTarget(ctx) == CodeGenTarget::GLSL)
3627+
{
3628+
IRInst* left = inst->getOperand(0);
3629+
IRInst* right = inst->getOperand(1);
3630+
3631+
auto leftVectorType = as<IRVectorType>(left->getDataType());
3632+
auto rightVectorType = as<IRVectorType>(right->getDataType());
3633+
3634+
// If either side is a vector handle as a vector
3635+
if (leftVectorType || rightVectorType)
3636+
{
3637+
const char* funcName = getGLSLVectorCompareFunctionName(inst->op);
3638+
SLANG_ASSERT(funcName);
3639+
3640+
// Determine the vector type
3641+
const auto vecType = leftVectorType ? leftVectorType : rightVectorType;
3642+
3643+
// Handle as a function call
3644+
auto prec = kEOp_Postfix;
3645+
*needCloseOut = maybeEmitParens(ioOuterPrec, prec);
3646+
3647+
emit(funcName);
3648+
emit("(");
3649+
_maybeEmitGLSLCast(ctx, (leftVectorType ? nullptr : vecType), left, mode);
3650+
emit(",");
3651+
_maybeEmitGLSLCast(ctx, (rightVectorType ? nullptr : vecType), right, mode);
3652+
emit(")");
3653+
3654+
return;
3655+
}
3656+
}
3657+
3658+
*needCloseOut = maybeEmitParens(ioOuterPrec, opPrec);
3659+
3660+
emitIROperand(ctx, inst->getOperand(0), mode, leftSide(ioOuterPrec, opPrec));
3661+
emit(" ");
3662+
emit(opPrec.op);
3663+
emit(" ");
3664+
emitIROperand(ctx, inst->getOperand(1), mode, rightSide(ioOuterPrec, opPrec));
3665+
}
3666+
36303667
void emitIRInstExpr(
36313668
EmitContext* ctx,
36323669
IRInst* inst,
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//TEST(compute):COMPARE_COMPUTE:-dx12 -compute
2+
//TEST(compute):COMPARE_COMPUTE:-vk -compute
3+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
4+
5+
// Test doing vector comparisons
6+
RWStructuredBuffer<int> outputBuffer;
7+
8+
[numthreads(4, 4, 1)]
9+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
10+
{
11+
uint2 threadInGroup = dispatchThreadID.xy;
12+
13+
int r = 0;
14+
if(all((threadInGroup & 1) == 0))
15+
{
16+
r = 0;
17+
}
18+
else
19+
{
20+
r = 1;
21+
}
22+
23+
int index = threadInGroup.x + threadInGroup.y * 4;
24+
outputBuffer[index] = r;
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
0
2+
1
3+
0
4+
1
5+
1
6+
1
7+
1
8+
1
9+
0
10+
1
11+
0
12+
1
13+
1
14+
1
15+
1
16+
1

0 commit comments

Comments
 (0)