Skip to content

Commit

Permalink
Implement bitcast for 64-bit date type (#5895)
Browse files Browse the repository at this point in the history
Close #5470

* implement bitcast for 64-bit date type
* Move 'ensurePrelude' to base class to remove duplication
* Assert on 'double' type for Metal target, as Metal doesn't have 'double' support
  • Loading branch information
kaizhangNV authored Dec 18, 2024
1 parent 49e912a commit 6f57e47
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 42 deletions.
11 changes: 11 additions & 0 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5137,4 +5137,15 @@ void CLikeSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink)
executeEmitActions(actions);
}

void CLikeSourceEmitter::ensurePrelude(const char* preludeText)
{
IRStringLit* stringLit;
if (!m_builtinPreludes.tryGetValue(preludeText, stringLit))
{
IRBuilder builder(m_irModule);
stringLit = builder.getStringValue(UnownedStringSlice(preludeText));
m_builtinPreludes[preludeText] = stringLit;
}
m_requiredPreludes.add(stringLit);
}
} // namespace Slang
5 changes: 5 additions & 0 deletions source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ class CLikeSourceEmitter : public SourceEmitterBase

String _emitLiteralOneWithType(int bitWidth);


virtual void ensurePrelude(const char* preludeText);

CodeGenContext* m_codeGenContext = nullptr;
IRModule* m_irModule = nullptr;

Expand Down Expand Up @@ -723,6 +726,8 @@ class CLikeSourceEmitter : public SourceEmitterBase
{
String requireComputeDerivatives;
} m_requiredAfter;

Dictionary<const char*, IRStringLit*> m_builtinPreludes;
};

} // namespace Slang
Expand Down
26 changes: 26 additions & 0 deletions source/slang/slang-emit-glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2030,6 +2030,32 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
emitType(inst->getDataType());
}
break;
case BaseType::Int64:
if (fromType == BaseType::Double)
{
m_writer->emit("int64_t(doubleBitsToInt64(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit("))");
return true;
}
else
{
emitType(inst->getDataType());
}
break;
case BaseType::UInt64:
if (fromType == BaseType::Double)
{
m_writer->emit("uint64_t(doubleBitsToUint64(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit("))");
return true;
}
else
{
emitType(inst->getDataType());
}
break;
case BaseType::Half:
switch (fromType)
{
Expand Down
37 changes: 28 additions & 9 deletions source/slang/slang-emit-hlsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
namespace Slang
{

static const char* kHLSLBuiltInPrelude64BitCast = R"(
uint64_t _slang_asuint64(double x)
{
uint32_t low;
uint32_t high;
asuint(x, low, high);
return ((uint64_t)high << 32) | low;
}
double _slang_asdouble(uint64_t x)
{
uint32_t low = x & 0xFFFFFFFF;
uint32_t high = x >> 32;
return asdouble(low, high);
}
)";

void HLSLSourceEmitter::_emitHLSLDecorationSingleString(
const char* name,
IRFunc* entryPoint,
Expand Down Expand Up @@ -822,17 +839,12 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
//
// There is no current function (it seems)
// for bit-casting an `int16_t` to a `half`.
//
// TODO: There is an `asdouble` function
// for converting two 32-bit integer values into
// one `double`. We could use that for
// bit casts of 64-bit values with a bit of
// extra work, but doing so might be best
// handled in an IR pass that legalizes
// bit-casts.
//
m_writer->emit("asfloat");
break;
case BaseType::Double:
ensurePrelude(kHLSLBuiltInPrelude64BitCast);
m_writer->emit("_slang_asdouble");
break;
}
m_writer->emit("(");
int closeCount = 1;
Expand All @@ -844,6 +856,8 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
diagnoseUnhandledInst(inst);
break;

case BaseType::Int64:
case BaseType::UInt64:
case BaseType::UInt:
case BaseType::Int:
case BaseType::Bool:
Expand All @@ -860,6 +874,11 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
m_writer->emit("asuint16(");
closeCount++;
break;
case BaseType::Double:
ensurePrelude(kHLSLBuiltInPrelude64BitCast);
m_writer->emit("_slang_asuint64(");
closeCount++;
break;
}

emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
Expand Down
16 changes: 3 additions & 13 deletions source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,6 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(
}
}

void MetalSourceEmitter::ensurePrelude(const char* preludeText)
{
IRStringLit* stringLit;
if (!m_builtinPreludes.tryGetValue(preludeText, stringLit))
{
IRBuilder builder(m_irModule);
stringLit = builder.getStringValue(UnownedStringSlice(preludeText));
m_builtinPreludes[preludeText] = stringLit;
}
m_requiredPreludes.add(stringLit);
}

void MetalSourceEmitter::emitMemoryOrderOperand(IRInst* inst)
{
auto memoryOrder = (IRMemoryOrder)getIntVal(inst);
Expand Down Expand Up @@ -1065,7 +1053,6 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_UInt8Type:
case kIROp_UIntType:
case kIROp_FloatType:
case kIROp_DoubleType:
case kIROp_HalfType:
{
m_writer->emit(getDefaultBuiltinTypeName(type->getOp()));
Expand Down Expand Up @@ -1093,6 +1080,9 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit(getName(type));
return;

case kIROp_DoubleType:
SLANG_UNEXPECTED("'double' type emitted");
return;
case kIROp_VectorType:
{
auto vecType = (IRVectorType*)type;
Expand Down
4 changes: 0 additions & 4 deletions source/slang/slang-emit-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@ class MetalSourceEmitter : public CLikeSourceEmitter

virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }

Dictionary<const char*, IRStringLit*> m_builtinPreludes;

protected:
RefPtr<MetalExtensionTracker> m_extensionTracker;

void ensurePrelude(const char* preludeText);

void emitMemoryOrderOperand(IRInst* inst);
virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type)
SLANG_OVERRIDE;
Expand Down
12 changes: 0 additions & 12 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,6 @@ fn _slang_getNan() -> f32
}
)";

void WGSLSourceEmitter::ensurePrelude(const char* preludeText)
{
IRStringLit* stringLit;
if (!m_builtinPreludes.tryGetValue(preludeText, stringLit))
{
IRBuilder builder(m_irModule);
stringLit = builder.getStringValue(UnownedStringSlice(preludeText));
m_builtinPreludes[preludeText] = stringLit;
}
m_requiredPreludes.add(stringLit);
}

void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
const SwitchRegion::Case* const currentCase,
const bool isDefault)
Expand Down
4 changes: 0 additions & 4 deletions source/slang/slang-emit-wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ class WGSLSourceEmitter : public CLikeSourceEmitter
void emit(const AddressSpace addressSpace);

virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;
Dictionary<const char*, IRStringLit*> m_builtinPreludes;

protected:
void ensurePrelude(const char* preludeText);

private:
bool maybeEmitSystemSemantic(IRInst* inst);
Expand Down
43 changes: 43 additions & 0 deletions tests/compute/bitcast-64bit.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -emit-spirv-via-glsl -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -d3d12 -profile cs_6_6 -use-dxil -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=gOutputBuffer
RWStructuredBuffer<uint64_t> gOutputBuffer;

int64_t icast(double x)
{
return bit_cast<int64_t>(x);
}

int64_t icast(uint64_t x)
{
return bit_cast<int64_t>(x);
}

uint64_t ucast(double x)
{
return bit_cast<uint64_t>(x);
}

uint64_t ucast(int64_t x)
{
return bit_cast<uint64_t>(x);
}

[numthreads(1, 1, 1)]
[shader("compute")]
void computeMain()
{
double t1 = -1.0;
uint64_t t2 = 2;
gOutputBuffer[0] = icast(t1); // 0xBFF0000000000000 => 13830554455654793216
gOutputBuffer[1] = icast(t2); // 0x0000000000000002 => 2

double t3 = 3.0;
int64_t t4 = -4;
gOutputBuffer[2] = ucast(t3); // 0x4008000000000000 => 4613937818241073152
gOutputBuffer[3] = ucast(t4); // 0xFFFFFFFFFFFFFFFC => 18446744073709551612
}
5 changes: 5 additions & 0 deletions tests/compute/bitcast-64bit.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type: uint64_t
13830554455654793216
2
4613937818241073152
18446744073709551612

0 comments on commit 6f57e47

Please sign in to comment.