diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 3b2cf12d0b..30a7af9385 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1372,10 +1372,10 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu case kIROp_Rsh: case kIROp_Lsh: { - // Shift amounts must be an unsigned type in WGSL + // Shift amounts must be an unsigned type in WGSL. + // We ensure this during legalization. // https://www.w3.org/TR/WGSL/#bit-expr - IRInst* const shiftAmount = inst->getOperand(1); - IRType* const shiftAmountType = shiftAmount->getDataType(); + SLANG_ASSERT(inst->getOperand(1)->getDataType()->getOp() != kIROp_IntType); // Dawn complains about mixing '<<' and '|', '^' and a bunch of other bit operators // without a paranthesis, so we'll always emit paranthesis around the shift amount. @@ -1392,18 +1392,9 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(info.op); m_writer->emit(" "); - if (shiftAmountType->getOp() == kIROp_IntType) - { - m_writer->emit("bitcast("); - emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); - m_writer->emit(")"); - } - else - { - m_writer->emit("("); - emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); - m_writer->emit(")"); - } + m_writer->emit("("); + emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); + m_writer->emit(")"); maybeCloseParens(needClose); diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp new file mode 100644 index 0000000000..a1affb7e9d --- /dev/null +++ b/source/slang/slang-ir-legalize-binary-operator.cpp @@ -0,0 +1,121 @@ +#include "slang-ir-legalize-binary-operator.h" + +#include "slang-ir-insts.h" + +namespace Slang +{ + +void legalizeBinaryOp(IRInst* inst) +{ + // For shifts, ensure that the shift amount is unsigned, as required by + // https://www.w3.org/TR/WGSL/#bit-expr. + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + IRInst* shiftAmount = inst->getOperand(1); + IRType* shiftAmountType = shiftAmount->getDataType(); + if (auto shiftAmountVectorType = as(shiftAmountType)) + { + IRType* shiftAmountElementType = shiftAmountVectorType->getElementType(); + IntInfo opIntInfo = getIntTypeInfo(shiftAmountElementType); + if (opIntInfo.isSigned) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + opIntInfo.isSigned = false; + shiftAmountElementType = builder.getType(getIntTypeOpFromInfo(opIntInfo)); + shiftAmountVectorType = builder.getVectorType( + shiftAmountElementType, + shiftAmountVectorType->getElementCount()); + IRInst* newShiftAmount = builder.emitCast(shiftAmountVectorType, shiftAmount); + builder.replaceOperand(inst->getOperands() + 1, newShiftAmount); + } + } + else if (isIntegralType(shiftAmountType)) + { + IntInfo opIntInfo = getIntTypeInfo(shiftAmountType); + if (opIntInfo.isSigned) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + opIntInfo.isSigned = false; + shiftAmountType = builder.getType(getIntTypeOpFromInfo(opIntInfo)); + IRInst* newShiftAmount = builder.emitCast(shiftAmountType, shiftAmount); + builder.replaceOperand(inst->getOperands() + 1, newShiftAmount); + } + } + } + + auto isVectorOrMatrix = [](IRType* type) + { + switch (type->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: + return true; + default: + return false; + } + }; + if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && + as(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + IRType* compositeType = inst->getOperand(0)->getDataType(); + IRInst* scalarValue = inst->getOperand(1); + // Retain the scalar type for shifts + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + auto vectorType = as(compositeType); + compositeType = + builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); + } + auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); + builder.replaceOperand(inst->getOperands() + 1, newRhs); + } + else if ( + as(inst->getOperand(0)->getDataType()) && + isVectorOrMatrix(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + IRType* compositeType = inst->getOperand(1)->getDataType(); + IRInst* scalarValue = inst->getOperand(0); + // Retain the scalar type for shifts + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + auto vectorType = as(compositeType); + compositeType = + builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); + } + auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); + builder.replaceOperand(inst->getOperands(), newLhs); + } + else if ( + isIntegralType(inst->getOperand(0)->getDataType()) && + isIntegralType(inst->getOperand(1)->getDataType())) + { + // Unless the operator is a shift, and if the integer operands differ in signedness, + // then convert the signed one to unsigned. + // We're assuming that the cases where this is bad have already been caught by + // common validation checks. + IntInfo opIntInfo[2] = { + getIntTypeInfo(inst->getOperand(0)->getDataType()), + getIntTypeInfo(inst->getOperand(1)->getDataType())}; + bool isShift = inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh; + bool signednessDiffers = opIntInfo[0].isSigned != opIntInfo[1].isSigned; + if (!isShift && signednessDiffers) + { + int signedOpIndex = (int)opIntInfo[1].isSigned; + opIntInfo[signedOpIndex].isSigned = false; + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newOp = builder.emitCast( + builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])), + inst->getOperand(signedOpIndex)); + builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp); + } + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-legalize-binary-operator.h b/source/slang/slang-ir-legalize-binary-operator.h new file mode 100644 index 0000000000..71c3197183 --- /dev/null +++ b/source/slang/slang-ir-legalize-binary-operator.h @@ -0,0 +1,16 @@ +#pragma once + +namespace Slang +{ + +struct IRInst; + +// Ensures: +// - Shift amounts are over unsigned scalar types. +// - If one operand is a composite type (vector or matrix), and the other one is a scalar +// type, then the scalar is converted to a composite type. +// - If 'inst' is not a shift, and if operands are integers of mixed signedness, then the +// signed operand is converted to unsigned. +void legalizeBinaryOp(IRInst* inst); + +} // namespace Slang diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index ce5b34c3e6..5bfa62e4af 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -2,6 +2,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" +#include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-varying-params.h" #include "slang-ir-specialize-address-space.h" #include "slang-ir-util.h" @@ -2120,6 +2121,40 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner } }; +static void processInst(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + legalizeBinaryOp(inst); + break; + + default: + for (auto child : inst->getModifiableChildren()) + { + processInst(child); + } + } +} + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) { List entryPoints; @@ -2145,6 +2180,8 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) MetalAddressSpaceAssigner metalAddressSpaceAssigner; specializeAddressSpace(module, &metalAddressSpaceAssigner); + + processInst(module->getModuleInst()); } } // namespace Slang diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index f76a0541c2..effc06f3ef 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1,6 +1,7 @@ #include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" +#include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-varying-params.h" #include "slang-ir-util.h" @@ -1487,64 +1488,6 @@ struct LegalizeWGSLEntryPointContext switchInst->removeAndDeallocate(); } - void legalizeBinaryOp(IRInst* inst) - { - auto isVectorOrMatrix = [](IRType* type) - { - switch (type->getOp()) - { - case kIROp_VectorType: - case kIROp_MatrixType: - return true; - default: - return false; - } - }; - if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && - as(inst->getOperand(1)->getDataType())) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newRhs = builder.emitMakeCompositeFromScalar( - inst->getOperand(0)->getDataType(), - inst->getOperand(1)); - builder.replaceOperand(inst->getOperands() + 1, newRhs); - } - else if ( - as(inst->getOperand(0)->getDataType()) && - isVectorOrMatrix(inst->getOperand(1)->getDataType())) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newLhs = builder.emitMakeCompositeFromScalar( - inst->getOperand(1)->getDataType(), - inst->getOperand(0)); - builder.replaceOperand(inst->getOperands(), newLhs); - } - else if ( - isIntegralType(inst->getOperand(0)->getDataType()) && - isIntegralType(inst->getOperand(1)->getDataType())) - { - // If integer operands differ in signedness, convert the signed one to unsigned. - // We're assuming that the cases where this is bad have already been caught by - // common validation checks. - IntInfo opIntInfo[2] = { - getIntTypeInfo(inst->getOperand(0)->getDataType()), - getIntTypeInfo(inst->getOperand(1)->getDataType())}; - if (opIntInfo[0].isSigned != opIntInfo[1].isSigned) - { - int signedOpIndex = (int)opIntInfo[1].isSigned; - opIntInfo[signedOpIndex].isSigned = false; - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newOp = builder.emitCast( - builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])), - inst->getOperand(signedOpIndex)); - builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp); - } - } - } - void processInst(IRInst* inst) { switch (inst->getOp()) diff --git a/tests/metal/byte-address-buffer.slang b/tests/metal/byte-address-buffer.slang index 24802815e7..d4b58061fe 100644 --- a/tests/metal/byte-address-buffer.slang +++ b/tests/metal/byte-address-buffer.slang @@ -20,11 +20,11 @@ struct TestStruct void main_kernel(uint3 tid: SV_DispatchThreadID) { // CHECK: uint [[WORD0:[a-zA-Z0-9_]+]] = as_type({{.*}}[(int(0))>>2]); - // CHECK: uint8_t [[A:[a-zA-Z0-9_]+]] = uint8_t([[WORD0]] >> int(0) & 255U); + // CHECK: uint8_t [[A:[a-zA-Z0-9_]+]] = uint8_t([[WORD0]] >> 0U & 255U); // CHECK: uint [[WORD1:[a-zA-Z0-9_]+]] = as_type({{.*}}[(int(0))>>2]); - // CHECK: half [[H:[a-zA-Z0-9_]+]] = as_type(ushort([[WORD1]] >> int(16) & 65535U)); + // CHECK: half [[H:[a-zA-Z0-9_]+]] = as_type(ushort([[WORD1]] >> 16U & 65535U)); - // CHECK: {{.*}}[(int(128))>>2] = as_type(({{.*}} & 4294967040U) | uint([[A]]) << int(0)); - // CHECK: {{.*}}[(int(128))>>2] = as_type(({{.*}} & 65535U) | uint(as_type([[H]])) << int(16)); + // CHECK: {{.*}}[(int(128))>>2] = as_type(({{.*}} & 4294967040U) | uint([[A]]) << 0U); + // CHECK: {{.*}}[(int(128))>>2] = as_type(({{.*}} & 65535U) | uint(as_type([[H]])) << 16U); buffer.Store(128, buffer.Load(0)); } diff --git a/tests/wgsl/bitshifts.slang b/tests/wgsl/bitshifts.slang new file mode 100644 index 0000000000..50d2fc43da --- /dev/null +++ b/tests/wgsl/bitshifts.slang @@ -0,0 +1,92 @@ +//TEST(compute):COMPARE_COMPUTE:-shaderobj + +//TEST_INPUT:ubuffer(data=[3 7 8 10], stride=4):name=inputBuffer +RWStructuredBuffer inputBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + int amount = 1; + + outputBuffer[0] = inputBuffer[0] >> amount; + + int2 v2 = int2(inputBuffer[0], inputBuffer[1]) >> amount; + outputBuffer[1] = v2[0]; + outputBuffer[2] = v2[1]; + + int3 v3 = int3(inputBuffer[0], inputBuffer[1], inputBuffer[2]) >> amount; + outputBuffer[3] = v3[0]; + outputBuffer[4] = v3[1]; + outputBuffer[5] = v3[2]; + + int4 v4 = int4(inputBuffer[0], inputBuffer[1], inputBuffer[2], inputBuffer[3]) >> amount; + outputBuffer[6] = v4[0]; + outputBuffer[7] = v4[1]; + outputBuffer[8] = v4[2]; + outputBuffer[9] = v4[3]; + + outputBuffer[10] = inputBuffer[0] << amount; + + v2 = int2(inputBuffer[0], inputBuffer[1]) << amount; + outputBuffer[11] = v2[0]; + outputBuffer[12] = v2[1]; + + v3 = int3(inputBuffer[0], inputBuffer[1], inputBuffer[2]) << amount; + outputBuffer[13] = v3[0]; + outputBuffer[14] = v3[1]; + outputBuffer[15] = v3[2]; + + v4 = int4(inputBuffer[0], inputBuffer[1], inputBuffer[2], inputBuffer[3]) << amount; + outputBuffer[16] = v4[0]; + outputBuffer[17] = v4[1]; + outputBuffer[18] = v4[2]; + outputBuffer[19] = v4[3]; + + v2 = inputBuffer[0] >> int2(amount); + outputBuffer[20] = v2[0]; + outputBuffer[21] = v2[1]; + + v3 = inputBuffer[1] >> int3(amount); + outputBuffer[22] = v3[0]; + outputBuffer[23] = v3[1]; + outputBuffer[24] = v3[2]; + + v4 = inputBuffer[2] >> int4(amount); + outputBuffer[25] = v4[0]; + outputBuffer[26] = v4[1]; + outputBuffer[27] = v4[2]; + outputBuffer[28] = v4[3]; + + v2 = inputBuffer[0] << int2(amount); + outputBuffer[29] = v2[0]; + outputBuffer[30] = v2[1]; + + v3 = inputBuffer[1] << int3(amount); + outputBuffer[31] = v3[0]; + outputBuffer[32] = v3[1]; + outputBuffer[33] = v3[2]; + + v4 = inputBuffer[2] << int4(amount); + outputBuffer[34] = v4[0]; + outputBuffer[35] = v4[1]; + outputBuffer[36] = v4[2]; + outputBuffer[37] = v4[3]; + + v2 = int2(inputBuffer[0], inputBuffer[1]) >> int2(1, 2); + outputBuffer[38] = v2[0]; + outputBuffer[39] = v2[1]; + + v3 = int3(inputBuffer[0], inputBuffer[1], inputBuffer[2]) >> int3(1, 2, 3); + outputBuffer[40] = v3[0]; + outputBuffer[41] = v3[1]; + outputBuffer[42] = v3[2]; + + v4 = int4(inputBuffer[0], inputBuffer[1], inputBuffer[2], inputBuffer[4]) >> int4(1, 2, 3, 4); + outputBuffer[43] = v4[0]; + outputBuffer[44] = v4[1]; + outputBuffer[45] = v4[2]; + outputBuffer[46] = v4[3]; +} diff --git a/tests/wgsl/bitshifts.slang.expected.txt b/tests/wgsl/bitshifts.slang.expected.txt new file mode 100644 index 0000000000..ff1ae84e4c --- /dev/null +++ b/tests/wgsl/bitshifts.slang.expected.txt @@ -0,0 +1,47 @@ +1 +1 +3 +1 +3 +4 +1 +3 +4 +5 +6 +6 +E +6 +E +10 +6 +E +10 +14 +1 +1 +3 +3 +3 +4 +4 +4 +4 +6 +6 +E +E +E +10 +10 +10 +10 +1 +1 +1 +1 +1 +1 +1 +1 +0