Skip to content

Commit fab1c9f

Browse files
authored
Support CUDA bindless texture in dynamic dispatch code. (shader-slang#1575)
1 parent 11f3317 commit fab1c9f

14 files changed

+424
-45
lines changed

source/slang/slang-emit-c-like.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -2090,7 +2090,15 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
20902090
emitType(inst->getDataType());
20912091
emitArgs(inst);
20922092
break;
2093-
2093+
case kIROp_makeUInt64:
2094+
m_writer->emit("((");
2095+
emitType(inst->getDataType());
2096+
m_writer->emit("(");
2097+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
2098+
m_writer->emit(") << 32) + ");
2099+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
2100+
m_writer->emit(")");
2101+
break;
20942102
case kIROp_constructVectorFromScalar:
20952103
{
20962104
// Simple constructor call

source/slang/slang-ir-any-value-marshalling.cpp

+50-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ namespace Slang
8585
IRInst* anyValueVar;
8686
// Defines what to do with basic typed data elements.
8787
virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;
88-
88+
// Defines what to do with resource handle elements.
89+
virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;
8990
// Validates that the type fits in the given AnyValueSize.
9091
// After calling emitMarshallingCode, `fieldOffset` will be increased to the required `AnyValue` size.
9192
// If this is larger than the provided AnyValue size, report a dianogstic. We might want to front load
@@ -188,6 +189,11 @@ namespace Slang
188189
break;
189190
}
190191
default:
192+
if (as<IRTextureTypeBase>(dataType) || as<IRSamplerStateTypeBase>(dataType))
193+
{
194+
context->marshalResourceHandle(builder, dataType, concreteTypedVar);
195+
return;
196+
}
191197
SLANG_UNIMPLEMENTED_X("Unimplemented type packing");
192198
break;
193199
}
@@ -243,6 +249,29 @@ namespace Slang
243249
SLANG_UNREACHABLE("unknown basic type");
244250
}
245251
}
252+
253+
virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
254+
{
255+
SLANG_UNUSED(dataType);
256+
if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
257+
{
258+
auto srcVal = builder->emitLoad(concreteVar);
259+
auto uint64Val = builder->emitBitCast(builder->getUInt64Type(), srcVal);
260+
auto lowBits = builder->emitConstructorInst(builder->getUIntType(), 1, &uint64Val);
261+
auto shiftedBits = builder->emitShr(
262+
builder->getUInt64Type(),
263+
uint64Val,
264+
builder->getIntValue(builder->getIntType(), 32));
265+
auto highBits = builder->emitBitCast(builder->getUIntType(), shiftedBits);
266+
auto dstAddr1 = builder->emitFieldAddress(
267+
uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
268+
builder->emitStore(dstAddr1, lowBits);
269+
auto dstAddr2 = builder->emitFieldAddress(
270+
uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
271+
builder->emitStore(dstAddr2, highBits);
272+
fieldOffset += 2;
273+
}
274+
}
246275
};
247276

248277
IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType)
@@ -335,6 +364,26 @@ namespace Slang
335364
SLANG_UNREACHABLE("unknown basic type");
336365
}
337366
}
367+
368+
virtual void marshalResourceHandle(
369+
IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
370+
{
371+
if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
372+
{
373+
auto srcAddr = builder->emitFieldAddress(
374+
uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
375+
auto lowBits = builder->emitLoad(srcAddr);
376+
377+
auto srcAddr1 = builder->emitFieldAddress(
378+
uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
379+
auto highBits = builder->emitLoad(srcAddr1);
380+
381+
auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
382+
combinedBits = builder->emitBitCast(dataType, combinedBits);
383+
builder->emitStore(concreteVar, combinedBits);
384+
fieldOffset += 2;
385+
}
386+
}
338387
};
339388

340389
IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType)

source/slang/slang-ir-inst-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0)
234234

235235
INST(Construct, construct, 0, 0)
236236

237+
INST(makeUInt64, makeUInt64, 2, 0)
237238
INST(makeVector, makeVector, 0, 0)
238239
INST(MakeMatrix, makeMatrix, 0, 0)
239240
INST(makeArray, makeArray, 0, 0)

source/slang/slang-ir-insts.h

+5
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,7 @@ struct IRBuilder
17791779
IRBasicType* getBoolType();
17801780
IRBasicType* getIntType();
17811781
IRBasicType* getUIntType();
1782+
IRBasicType* getUInt64Type();
17821783
IRStringType* getStringType();
17831784

17841785
IRAssociatedType* getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes);
@@ -1942,6 +1943,8 @@ struct IRBuilder
19421943
UInt argCount,
19431944
IRInst* const* args);
19441945

1946+
IRInst* emitMakeUInt64(IRInst* low, IRInst* high);
1947+
19451948
// Creates an RTTI object. Result is of `IRRTTIType`.
19461949
IRInst* emitMakeRTTIObject(IRInst* typeInst);
19471950

@@ -2303,6 +2306,8 @@ struct IRBuilder
23032306

23042307
IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right);
23052308
IRInst* emitMul(IRType* type, IRInst* left, IRInst* right);
2309+
IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1);
2310+
IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1);
23062311

23072312
//
23082313
// Decorations

source/slang/slang-ir.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -2230,6 +2230,11 @@ namespace Slang
22302230
return (IRBasicType*)getType(kIROp_UIntType);
22312231
}
22322232

2233+
IRBasicType* IRBuilder::getUInt64Type()
2234+
{
2235+
return (IRBasicType*)getType(kIROp_UInt64Type);
2236+
}
2237+
22332238
IRStringType* IRBuilder::getStringType()
22342239
{
22352240
return (IRStringType*)getType(kIROp_StringType);
@@ -2750,6 +2755,12 @@ namespace Slang
27502755
return inst;
27512756
}
27522757

2758+
IRInst* IRBuilder::emitMakeUInt64(IRInst* low, IRInst* high)
2759+
{
2760+
IRInst* args[2] = {low, high};
2761+
return emitIntrinsicInst(getUInt64Type(), kIROp_makeUInt64, 2, args);
2762+
}
2763+
27532764
IRInst* IRBuilder::emitMakeRTTIObject(IRInst* typeInst)
27542765
{
27552766
auto inst = createInst<IRRTTIObject>(
@@ -3786,6 +3797,20 @@ namespace Slang
37863797
return inst;
37873798
}
37883799

3800+
IRInst* IRBuilder::emitShr(IRType* type, IRInst* left, IRInst* right)
3801+
{
3802+
auto inst = createInst<IRInst>(this, kIROp_Rsh, type, left, right);
3803+
addInst(inst);
3804+
return inst;
3805+
}
3806+
3807+
IRInst* IRBuilder::emitShl(IRType* type, IRInst* left, IRInst* right)
3808+
{
3809+
auto inst = createInst<IRInst>(this, kIROp_Lsh, type, left, right);
3810+
addInst(inst);
3811+
return inst;
3812+
}
3813+
37893814
IRInst* IRBuilder::emitGpuForeach(List<IRInst*> args)
37903815
{
37913816
auto inst = createInst<IRInst>(
@@ -5300,6 +5325,7 @@ namespace Slang
53005325
case kIROp_getAddr:
53015326
case kIROp_GetValueFromExistentialBox:
53025327
case kIROp_Construct:
5328+
case kIROp_makeUInt64:
53035329
case kIROp_makeVector:
53045330
case kIROp_MakeMatrix:
53055331
case kIROp_makeArray:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Test using interface typed shader parameters with texture typed fields.
2+
//TEST(compute):COMPARE_COMPUTE:-cpu
3+
//TEST(compute):COMPARE_COMPUTE:-cuda
4+
5+
[anyValueSize(8)]
6+
interface IInterface
7+
{
8+
float run();
9+
}
10+
11+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
12+
RWStructuredBuffer<uint> gOutputBuffer;
13+
//TEST_INPUT: Texture2D(size=8, content = one):name t2D,bindless
14+
//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) handle(t2D) 0 0], stride=4):name=gCb
15+
StructuredBuffer<IInterface> gCb;
16+
17+
[numthreads(4, 1, 1)]
18+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
19+
{
20+
let tid = dispatchThreadID.x;
21+
22+
let inputVal : int = tid;
23+
IInterface v0 = gCb.Load(0);
24+
SamplerState sampler;
25+
let outputVal = v0.run();
26+
gOutputBuffer[tid] = trunc(outputVal);
27+
}
28+
29+
//TEST_INPUT: globalExistentialType __Dynamic
30+
31+
// Type must be marked `public` to ensure it is visible in the generated DLL.
32+
public struct MyImpl : IInterface
33+
{
34+
Texture2D tex;
35+
SamplerState sampler;
36+
float run()
37+
{
38+
return tex.Sample(sampler, float2(0.0, 0.0)).x;
39+
}
40+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1
2+
1
3+
1
4+
1

0 commit comments

Comments
 (0)