Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WGSL: Convert signed vector shift amounts to unsigned #6023

Merged
merged 8 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1372,10 +1372,10 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
case kIROp_Rsh:
case kIROp_Lsh:
{
// Shift amounts must be an unsigned type in WGSL
// Shift amounts must be an unsigned type in WGSL.
// We ensure this during legalization.
// https://www.w3.org/TR/WGSL/#bit-expr
IRInst* const shiftAmount = inst->getOperand(1);
IRType* const shiftAmountType = shiftAmount->getDataType();
SLANG_ASSERT(inst->getOperand(1)->getDataType()->getOp() != kIROp_IntType);

// Dawn complains about mixing '<<' and '|', '^' and a bunch of other bit operators
// without a paranthesis, so we'll always emit paranthesis around the shift amount.
Expand All @@ -1392,18 +1392,9 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
m_writer->emit(info.op);
m_writer->emit(" ");

if (shiftAmountType->getOp() == kIROp_IntType)
{
m_writer->emit("bitcast<u32>(");
emitOperand(inst->getOperand(1), rightSide(outerPrec, info));
m_writer->emit(")");
}
else
{
m_writer->emit("(");
emitOperand(inst->getOperand(1), rightSide(outerPrec, info));
m_writer->emit(")");
}
m_writer->emit("(");
emitOperand(inst->getOperand(1), rightSide(outerPrec, info));
m_writer->emit(")");

maybeCloseParens(needClose);

Expand Down
121 changes: 121 additions & 0 deletions source/slang/slang-ir-legalize-binary-operator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#include "slang-ir-legalize-binary-operator.h"

#include "slang-ir-insts.h"

namespace Slang
{

void legalizeBinaryOp(IRInst* inst)
{
// For shifts, ensure that the shift amount is unsigned, as required by
// https://www.w3.org/TR/WGSL/#bit-expr.
if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
{
IRInst* shiftAmount = inst->getOperand(1);
IRType* shiftAmountType = shiftAmount->getDataType();
if (auto shiftAmountVectorType = as<IRVectorType>(shiftAmountType))
{
IRType* shiftAmountElementType = shiftAmountVectorType->getElementType();
IntInfo opIntInfo = getIntTypeInfo(shiftAmountElementType);
if (opIntInfo.isSigned)
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
opIntInfo.isSigned = false;
shiftAmountElementType = builder.getType(getIntTypeOpFromInfo(opIntInfo));
shiftAmountVectorType = builder.getVectorType(
shiftAmountElementType,
shiftAmountVectorType->getElementCount());
IRInst* newShiftAmount = builder.emitCast(shiftAmountVectorType, shiftAmount);
builder.replaceOperand(inst->getOperands() + 1, newShiftAmount);
}
}
else if (isIntegralType(shiftAmountType))
{
IntInfo opIntInfo = getIntTypeInfo(shiftAmountType);
if (opIntInfo.isSigned)
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
opIntInfo.isSigned = false;
shiftAmountType = builder.getType(getIntTypeOpFromInfo(opIntInfo));
IRInst* newShiftAmount = builder.emitCast(shiftAmountType, shiftAmount);
builder.replaceOperand(inst->getOperands() + 1, newShiftAmount);
}
}
}

auto isVectorOrMatrix = [](IRType* type)
{
switch (type->getOp())
{
case kIROp_VectorType:
case kIROp_MatrixType:
return true;
default:
return false;
}
};
if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
as<IRBasicType>(inst->getOperand(1)->getDataType()))
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
IRType* compositeType = inst->getOperand(0)->getDataType();
IRInst* scalarValue = inst->getOperand(1);
// Retain the scalar type for shifts
if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
{
auto vectorType = as<IRVectorType>(compositeType);
compositeType =
builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
}
auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
builder.replaceOperand(inst->getOperands() + 1, newRhs);
}
else if (
as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
isVectorOrMatrix(inst->getOperand(1)->getDataType()))
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
IRType* compositeType = inst->getOperand(1)->getDataType();
IRInst* scalarValue = inst->getOperand(0);
// Retain the scalar type for shifts
if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
{
auto vectorType = as<IRVectorType>(compositeType);
compositeType =
builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
}
auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
builder.replaceOperand(inst->getOperands(), newLhs);
}
else if (
isIntegralType(inst->getOperand(0)->getDataType()) &&
isIntegralType(inst->getOperand(1)->getDataType()))
{
// Unless the operator is a shift, and if the integer operands differ in signedness,
// then convert the signed one to unsigned.
// We're assuming that the cases where this is bad have already been caught by
// common validation checks.
IntInfo opIntInfo[2] = {
getIntTypeInfo(inst->getOperand(0)->getDataType()),
getIntTypeInfo(inst->getOperand(1)->getDataType())};
bool isShift = inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh;
bool signednessDiffers = opIntInfo[0].isSigned != opIntInfo[1].isSigned;
if (!isShift && signednessDiffers)
{
int signedOpIndex = (int)opIntInfo[1].isSigned;
opIntInfo[signedOpIndex].isSigned = false;
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newOp = builder.emitCast(
builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])),
inst->getOperand(signedOpIndex));
builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp);
}
}
}

} // namespace Slang
16 changes: 16 additions & 0 deletions source/slang/slang-ir-legalize-binary-operator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

namespace Slang
{

struct IRInst;

// Ensures:
// - Shift amounts are over unsigned scalar types.
// - If one operand is a composite type (vector or matrix), and the other one is a scalar
// type, then the scalar is converted to a composite type.
// - If 'inst' is not a shift, and if operands are integers of mixed signedness, then the
// signed operand is converted to unsigned.
void legalizeBinaryOp(IRInst* inst);

} // namespace Slang
37 changes: 37 additions & 0 deletions source/slang/slang-ir-metal-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-legalize-varying-params.h"
#include "slang-ir-specialize-address-space.h"
#include "slang-ir-util.h"
Expand Down Expand Up @@ -2120,6 +2121,40 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner
}
};

static void processInst(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
case kIROp_Div:
case kIROp_FRem:
case kIROp_IRem:
case kIROp_And:
case kIROp_Or:
case kIROp_BitAnd:
case kIROp_BitOr:
case kIROp_BitXor:
case kIROp_Lsh:
case kIROp_Rsh:
case kIROp_Eql:
case kIROp_Neq:
case kIROp_Greater:
case kIROp_Less:
case kIROp_Geq:
case kIROp_Leq:
legalizeBinaryOp(inst);
break;

default:
for (auto child : inst->getModifiableChildren())
{
processInst(child);
}
}
}

void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
{
List<EntryPointInfo> entryPoints;
Expand All @@ -2145,6 +2180,8 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)

MetalAddressSpaceAssigner metalAddressSpaceAssigner;
specializeAddressSpace(module, &metalAddressSpaceAssigner);

processInst(module->getModuleInst());
}

} // namespace Slang
59 changes: 1 addition & 58 deletions source/slang/slang-ir-wgsl-legalize.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "slang-ir-wgsl-legalize.h"

#include "slang-ir-insts.h"
#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-varying-params.h"
#include "slang-ir-util.h"
Expand Down Expand Up @@ -1487,64 +1488,6 @@ struct LegalizeWGSLEntryPointContext
switchInst->removeAndDeallocate();
}

void legalizeBinaryOp(IRInst* inst)
{
auto isVectorOrMatrix = [](IRType* type)
{
switch (type->getOp())
{
case kIROp_VectorType:
case kIROp_MatrixType:
return true;
default:
return false;
}
};
if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
as<IRBasicType>(inst->getOperand(1)->getDataType()))
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newRhs = builder.emitMakeCompositeFromScalar(
inst->getOperand(0)->getDataType(),
inst->getOperand(1));
builder.replaceOperand(inst->getOperands() + 1, newRhs);
}
else if (
as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
isVectorOrMatrix(inst->getOperand(1)->getDataType()))
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newLhs = builder.emitMakeCompositeFromScalar(
inst->getOperand(1)->getDataType(),
inst->getOperand(0));
builder.replaceOperand(inst->getOperands(), newLhs);
}
else if (
isIntegralType(inst->getOperand(0)->getDataType()) &&
isIntegralType(inst->getOperand(1)->getDataType()))
{
// If integer operands differ in signedness, convert the signed one to unsigned.
// We're assuming that the cases where this is bad have already been caught by
// common validation checks.
IntInfo opIntInfo[2] = {
getIntTypeInfo(inst->getOperand(0)->getDataType()),
getIntTypeInfo(inst->getOperand(1)->getDataType())};
if (opIntInfo[0].isSigned != opIntInfo[1].isSigned)
{
int signedOpIndex = (int)opIntInfo[1].isSigned;
opIntInfo[signedOpIndex].isSigned = false;
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newOp = builder.emitCast(
builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])),
inst->getOperand(signedOpIndex));
builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp);
}
}
}

void processInst(IRInst* inst)
{
switch (inst->getOp())
Expand Down
8 changes: 4 additions & 4 deletions tests/metal/byte-address-buffer.slang
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ struct TestStruct
void main_kernel(uint3 tid: SV_DispatchThreadID)
{
// CHECK: uint [[WORD0:[a-zA-Z0-9_]+]] = as_type<uint>({{.*}}[(int(0))>>2]);
// CHECK: uint8_t [[A:[a-zA-Z0-9_]+]] = uint8_t([[WORD0]] >> int(0) & 255U);
// CHECK: uint8_t [[A:[a-zA-Z0-9_]+]] = uint8_t([[WORD0]] >> 0U & 255U);
// CHECK: uint [[WORD1:[a-zA-Z0-9_]+]] = as_type<uint>({{.*}}[(int(0))>>2]);
// CHECK: half [[H:[a-zA-Z0-9_]+]] = as_type<half>(ushort([[WORD1]] >> int(16) & 65535U));
// CHECK: half [[H:[a-zA-Z0-9_]+]] = as_type<half>(ushort([[WORD1]] >> 16U & 65535U));

// CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 4294967040U) | uint([[A]]) << int(0));
// CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 65535U) | uint(as_type<ushort>([[H]])) << int(16));
// CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 4294967040U) | uint([[A]]) << 0U);
// CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 65535U) | uint(as_type<ushort>([[H]])) << 16U);
buffer.Store(128, buffer.Load<TestStruct>(0));
}
Loading
Loading