Skip to content

Commit 57b09a8

Browse files
authored
Use and() and or() functions for logical-AND and OR (shader-slang#6310)
* Use and() and or() functions for logical-AND and OR With this commit, Slang will emit function calls to `and()` and `or()` for the logical-AND and logical-OR when the operands are non-scalar and the target profile is SM6.0 and above. This is required change from SM6.0. For WGSL, there is no operator overloadings of `&&` and `||` when the operands are non-scalar. Unlike HLSL, WGSL also don't have `and()` nor `or()`. Alternatively, we can use `select()`.
1 parent 79aebc1 commit 57b09a8

12 files changed

+287
-23
lines changed

lock

Whitespace-only changes.

source/slang/hlsl.meta.slang

+6-2
Original file line numberDiff line numberDiff line change
@@ -6228,7 +6228,9 @@ bool all(vector<T,N> x)
62286228
case hlsl:
62296229
__intrinsic_asm "all";
62306230
case metal:
6231-
__intrinsic_asm "all";
6231+
if (__isBool<T>())
6232+
__intrinsic_asm "all";
6233+
__intrinsic_asm "all(bool$N0($0))";
62326234
case glsl:
62336235
__intrinsic_asm "all(bvec$N0($0))";
62346236
case spirv:
@@ -6256,7 +6258,9 @@ bool all(vector<T,N> x)
62566258
};
62576259
}
62586260
case wgsl:
6259-
__intrinsic_asm "all";
6261+
if (__isBool<T>())
6262+
__intrinsic_asm "all";
6263+
__intrinsic_asm "all(vec$N0<bool>($0))";
62606264
default:
62616265
bool result = true;
62626266
for(int i = 0; i < N; ++i)

source/slang/slang-emit-hlsl.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,53 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
821821
}
822822
break;
823823
}
824+
case kIROp_And:
825+
case kIROp_Or:
826+
{
827+
// SM6.0 requires to use `and()` and `or()` functions for the logical-AND and
828+
// logical-OR, respectively, with non-scalar operands.
829+
auto targetProfile = getTargetProgram()->getOptionSet().getProfile();
830+
if (targetProfile.getVersion() < ProfileVersion::DX_6_0)
831+
return false;
832+
833+
if (as<IRBasicType>(inst->getDataType()))
834+
return false;
835+
836+
if (inst->getOp() == kIROp_And)
837+
{
838+
m_writer->emit("and(");
839+
}
840+
else
841+
{
842+
m_writer->emit("or(");
843+
}
844+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
845+
m_writer->emit(", ");
846+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
847+
m_writer->emit(")");
848+
return true;
849+
}
850+
case kIROp_Select:
851+
{
852+
// SM6.0 requires to use `select()` instead of the ternary operator "?:" when the
853+
// operands are non-scalar.
854+
auto targetProfile = getTargetProgram()->getOptionSet().getProfile();
855+
if (targetProfile.getVersion() < ProfileVersion::DX_6_0)
856+
return false;
857+
858+
if (as<IRBasicType>(inst->getDataType()))
859+
return false;
860+
861+
m_writer->emit("select(");
862+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
863+
m_writer->emit(", ");
864+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
865+
m_writer->emit(", ");
866+
emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
867+
m_writer->emit(")");
868+
return true;
869+
}
870+
824871
case kIROp_BitCast:
825872
{
826873
// For simplicity, we will handle all bit-cast operations

source/slang/slang-emit-wgsl.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,40 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
13121312
}
13131313
break;
13141314

1315+
case kIROp_And:
1316+
case kIROp_Or:
1317+
{
1318+
// WGSL doesn't have operator overloadings for `&&` and `||` when the operands are
1319+
// non-scalar. Unlike HLSL, WGSL doesn't have `and()` and `or()`.
1320+
auto vecType = as<IRVectorType>(inst->getDataType());
1321+
if (!vecType)
1322+
return false;
1323+
1324+
// The function signature for `select` in WGSL is different from others:
1325+
// @const @must_use fn select(f: T, t: T, cond: bool) -> T
1326+
if (inst->getOp() == kIROp_And)
1327+
{
1328+
m_writer->emit("select(vec");
1329+
m_writer->emit(getIntVal(vecType->getElementCount()));
1330+
m_writer->emit("<bool>(false), ");
1331+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
1332+
m_writer->emit(", ");
1333+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1334+
m_writer->emit(")");
1335+
}
1336+
else
1337+
{
1338+
m_writer->emit("select(");
1339+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
1340+
m_writer->emit(", vec");
1341+
m_writer->emit(getIntVal(vecType->getElementCount()));
1342+
m_writer->emit("<bool>(true), ");
1343+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1344+
m_writer->emit(")");
1345+
}
1346+
return true;
1347+
}
1348+
13151349
case kIROp_BitCast:
13161350
{
13171351
// In WGSL there is a built-in bitcast function!

source/slang/slang-emit.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "slang-ir-insts.h"
5353
#include "slang-ir-layout.h"
5454
#include "slang-ir-legalize-array-return-type.h"
55+
#include "slang-ir-legalize-binary-operator.h"
5556
#include "slang-ir-legalize-global-values.h"
5657
#include "slang-ir-legalize-image-subscript.h"
5758
#include "slang-ir-legalize-mesh-outputs.h"
@@ -1469,6 +1470,10 @@ Result linkAndOptimizeIR(
14691470
floatNonUniformResourceIndex(irModule, NonUniformResourceIndexFloatMode::Textual);
14701471
}
14711472

1473+
if (isD3DTarget(targetRequest) || isKhronosTarget(targetRequest) ||
1474+
isWGPUTarget(targetRequest) || isMetalTarget(targetRequest))
1475+
legalizeLogicalAndOr(irModule->getModuleInst());
1476+
14721477
// Legalize non struct parameters that are expected to be structs for HLSL.
14731478
if (isD3DTarget(targetRequest))
14741479
legalizeNonStructParameterToStructForHLSL(irModule);

source/slang/slang-ir-insts.h

+3
Original file line numberDiff line numberDiff line change
@@ -4520,6 +4520,9 @@ struct IRBuilder
45204520
IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1);
45214521
IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1);
45224522

4523+
IRInst* emitAnd(IRType* type, IRInst* left, IRInst* right);
4524+
IRInst* emitOr(IRType* type, IRInst* left, IRInst* right);
4525+
45234526
IRSPIRVAsmOperand* emitSPIRVAsmOperandLiteral(IRInst* literal);
45244527
IRSPIRVAsmOperand* emitSPIRVAsmOperandInst(IRInst* inst);
45254528
IRSPIRVAsmOperand* createSPIRVAsmOperandInst(IRInst* inst);

source/slang/slang-ir-legalize-binary-operator.cpp

+97
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,101 @@ void legalizeBinaryOp(IRInst* inst)
118118
}
119119
}
120120

121+
void legalizeLogicalAndOr(IRInst* inst)
122+
{
123+
switch (inst->getOp())
124+
{
125+
case kIROp_And:
126+
case kIROp_Or:
127+
{
128+
IRBuilder builder(inst);
129+
builder.setInsertBefore(inst);
130+
131+
// Logical-AND and logical-OR takes boolean types as its operands.
132+
// If they are not, legalize them by casting to boolean type.
133+
//
134+
SLANG_ASSERT(inst->getOperandCount() == 2);
135+
for (UInt i = 0; i < 2; i++)
136+
{
137+
auto operand = inst->getOperand(i);
138+
auto operandDataType = operand->getDataType();
139+
140+
if (auto vecType = as<IRVectorType>(operandDataType))
141+
{
142+
if (!as<IRBoolType>(vecType->getElementType()))
143+
{
144+
// Cast operand to vector<bool,N>
145+
auto elemCount = vecType->getElementCount();
146+
auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
147+
auto v = builder.emitCast(vb, operand);
148+
builder.replaceOperand(inst->getOperands() + i, v);
149+
}
150+
}
151+
else if (!as<IRBoolType>(operandDataType))
152+
{
153+
// Cast operand to bool
154+
auto s = builder.emitCast(builder.getBoolType(), operand);
155+
builder.replaceOperand(inst->getOperands() + i, s);
156+
}
157+
}
158+
159+
// Legalize the return type; mostly for SPIRV.
160+
// The return type of OpLogicalOr must be boolean type.
161+
// If not, we need to recreate the instruction with boolean return type.
162+
// Then, we have to cast it back to the original type so that other instrucitons that
163+
// use have the matching types.
164+
//
165+
auto dataType = inst->getDataType();
166+
auto lhs = inst->getOperand(0);
167+
auto rhs = inst->getOperand(1);
168+
IRInst* newInst = nullptr;
169+
170+
if (auto vecType = as<IRVectorType>(dataType))
171+
{
172+
if (!as<IRBoolType>(vecType->getElementType()))
173+
{
174+
// Return type should be vector<bool,N>
175+
auto elemCount = vecType->getElementCount();
176+
auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
177+
178+
if (inst->getOp() == kIROp_And)
179+
{
180+
newInst = builder.emitAnd(vb, lhs, rhs);
181+
}
182+
else
183+
{
184+
newInst = builder.emitOr(vb, lhs, rhs);
185+
}
186+
newInst = builder.emitCast(dataType, newInst);
187+
}
188+
}
189+
else if (!as<IRBoolType>(dataType))
190+
{
191+
// Return type should be bool
192+
if (inst->getOp() == kIROp_And)
193+
{
194+
newInst = builder.emitAnd(builder.getBoolType(), lhs, rhs);
195+
}
196+
else
197+
{
198+
newInst = builder.emitOr(builder.getBoolType(), lhs, rhs);
199+
}
200+
newInst = builder.emitCast(dataType, newInst);
201+
}
202+
203+
if (newInst && inst != newInst)
204+
{
205+
inst->replaceUsesWith(newInst);
206+
inst->removeAndDeallocate();
207+
}
208+
}
209+
break;
210+
}
211+
212+
for (auto child : inst->getModifiableChildren())
213+
{
214+
legalizeLogicalAndOr(child);
215+
}
216+
}
217+
121218
} // namespace Slang

source/slang/slang-ir-legalize-binary-operator.h

+5
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ struct IRInst;
1313
// signed operand is converted to unsigned.
1414
void legalizeBinaryOp(IRInst* inst);
1515

16+
// The logical binary operators such as AND and OR takes boolean types are its input.
17+
// If they are in integer type, as an example, we need to explicitly cast to bool type.
18+
// Also the return type from the logical operators should be a boolean type.
19+
void legalizeLogicalAndOr(IRInst* inst);
20+
1621
} // namespace Slang

source/slang/slang-ir.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -6020,6 +6020,20 @@ IRInst* IRBuilder::emitShl(IRType* type, IRInst* left, IRInst* right)
60206020
return inst;
60216021
}
60226022

6023+
IRInst* IRBuilder::emitAnd(IRType* type, IRInst* left, IRInst* right)
6024+
{
6025+
auto inst = createInst<IRInst>(this, kIROp_And, type, left, right);
6026+
addInst(inst);
6027+
return inst;
6028+
}
6029+
6030+
IRInst* IRBuilder::emitOr(IRType* type, IRInst* left, IRInst* right)
6031+
{
6032+
auto inst = createInst<IRInst>(this, kIROp_Or, type, left, right);
6033+
addInst(inst);
6034+
return inst;
6035+
}
6036+
60236037
IRInst* IRBuilder::emitGetNativePtr(IRInst* value)
60246038
{
60256039
auto valueType = value->getDataType();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain
2+
//TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain
3+
//TEST(compute):SIMPLE(filecheck=WGS):-target wgsl -stage compute -entry computeMain
4+
//TEST(compute):SIMPLE(filecheck=MTL):-target metal -stage compute -entry computeMain
5+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -output-using-type -xslang -Wno-30056
6+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -Wno-30056
7+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056
8+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -output-using-type -xslang -Wno-30056
9+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -output-using-type -xslang -Wno-30056
10+
11+
// Testnig logical-AND, logical-OR and ternary operator with non-scalar operands
12+
13+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
14+
RWStructuredBuffer<int> outputBuffer;
15+
16+
static int result = 0;
17+
18+
bool2 assignFunc(int index)
19+
{
20+
result += 10;
21+
return bool2(true);
22+
}
23+
24+
[numthreads(4, 1, 1)]
25+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
26+
{
27+
int index = dispatchThreadID.x;
28+
29+
// No short-circuiting for vector types
30+
31+
//SM5:(all({{.*}}&&
32+
//SM6:(all(and(
33+
//WGS:(all(select(vec2<bool>(false),
34+
//MTL:(all({{.*}}&&
35+
if (all(bool2(index >= 1) && assignFunc(index)))
36+
{
37+
result++;
38+
}
39+
40+
// Intentionally using non-boolean type for testing.
41+
42+
//SM5:(all({{.*}}||
43+
//SM6:(or(vector<bool,2>(
44+
//WGS:(select({{.*}}, vec2<bool>(true), vec2<bool>(
45+
//MTL:(all(bool2({{.*}}||
46+
if (all(int2(index >= 2) || !assignFunc(index)))
47+
{
48+
result++;
49+
}
50+
51+
//SM5:(all({{.*}}?{{.*}}:
52+
//SM6:(all(select(
53+
//WGS:(all(select(vec2<bool>(false),
54+
//MTL:(all(select(bool2(false)
55+
if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false)))
56+
{
57+
result++;
58+
}
59+
60+
outputBuffer[index] = result;
61+
62+
//CHK:30
63+
//CHK-NEXT:31
64+
//CHK-NEXT:32
65+
//CHK-NEXT:33
66+
}

tests/compute/logic-short-circuit-evaluation.slang

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
//TEST(compute):COMPARE_COMPUTE:-dx12 -compute -shaderobj
2-
//TEST(compute):COMPARE_COMPUTE:-vk -compute -shaderobj
3-
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj
4-
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -compile-arg -O3 -shaderobj
5-
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj
3+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj
4+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj
5+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj
6+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj
67

78
// Test doing vector comparisons
89

@@ -25,4 +26,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
2526

2627
// Only the last 4 elements will be 1.
2728
(index < 12) || assignFunc(index);
29+
30+
//CHK-COUNT-4: 1
31+
//CHK-COUNT-8: 0
32+
//CHK-COUNT-4: 1
2833
}

tests/compute/logic-short-circuit-evaluation.slang.expected.txt

-16
This file was deleted.

0 commit comments

Comments
 (0)