Skip to content

Commit a2a7c4d

Browse files
authored
Allow unspecialized existential shader parameters (dynamic dispatch). (shader-slang#1529)
* Allow unspecialized existential shader parameters (dynamic dispatch). * Fixes. * Fixes * disable cuda test
1 parent 7f567df commit a2a7c4d

16 files changed

+194
-19
lines changed

slang.h

+6
Original file line numberDiff line numberDiff line change
@@ -3110,6 +3110,12 @@ namespace slang
31103110
LayoutRules rules = LayoutRules::Default,
31113111
ISlangBlob** outDiagnostics = nullptr) = 0;
31123112

3113+
/** Get the mangled name for a type RTTI object.
3114+
*/
3115+
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTypeRTTIMangledName(
3116+
TypeReflection* type,
3117+
ISlangBlob** outNameBlob) = 0;
3118+
31133119
/** Get the mangled name for a type witness.
31143120
*/
31153121
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessMangledName(

source/slang/slang-compiler.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -2483,13 +2483,14 @@ SlangResult dissassembleDXILUsingDXC(
24832483
BackEndCompileRequest* compileRequest,
24842484
EndToEndCompileRequest* endToEndReq)
24852485
{
2486-
// If we are about to generate output code, but we still
2486+
// When dynamic dispatch is disabled, the program must
2487+
// be fully specialized by now. So we check if we still
24872488
// have unspecialized generic/existential parameters,
2488-
// then there is a problem.
2489+
// and report them as an error.
24892490
//
24902491
auto program = compileRequest->getProgram();
24912492
auto specializationParamCount = program->getSpecializationParamCount();
2492-
if( specializationParamCount != 0 )
2493+
if (compileRequest->disableDynamicDispatch && specializationParamCount != 0)
24932494
{
24942495
auto sink = compileRequest->getSink();
24952496

source/slang/slang-compiler.h

+6
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,9 @@ namespace Slang
12051205
SlangInt targetIndex = 0,
12061206
slang::LayoutRules rules = slang::LayoutRules::Default,
12071207
ISlangBlob** outDiagnostics = nullptr) override;
1208+
SLANG_NO_THROW SlangResult SLANG_MCALL getTypeRTTIMangledName(
1209+
slang::TypeReflection* type,
1210+
ISlangBlob** outNameBlob) override;
12081211
SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessMangledName(
12091212
slang::TypeReflection* type,
12101213
slang::TypeReflection* interfaceType,
@@ -1763,6 +1766,9 @@ namespace Slang
17631766
// If true will disable generics/existential value specialization pass.
17641767
bool disableSpecialization = false;
17651768

1769+
// If true will disable generating dynamic dispatch code.
1770+
bool disableDynamicDispatch = false;
1771+
17661772
String m_dumpIntermediatePrefix;
17671773

17681774
private:

source/slang/slang-emit-c-like.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -3888,8 +3888,9 @@ void CLikeSourceEmitter::computeEmitActions(IRModule* module, List<EmitAction>&
38883888
{
38893889
if( as<IRType>(inst) )
38903890
{
3891-
// Don't emit a type unless it is actually used.
3892-
continue;
3891+
// Don't emit a type unless it is actually used or is marked public.
3892+
if (!inst->findDecoration<IRPublicDecoration>())
3893+
continue;
38933894
}
38943895

38953896
ensureGlobalInst(&ctx, inst, EmitAction::Level::Definition);

source/slang/slang-emit-cpp.cpp

+12-7
Original file line numberDiff line numberDiff line change
@@ -1615,12 +1615,12 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
16151615
auto witnessTableItems = witnessTable->getChildren();
16161616
_maybeEmitWitnessTableTypeDefinition(interfaceType);
16171617

1618-
// Define a global variable for the witness table.
1619-
m_writer->emit("extern \"C\" ");
1618+
// Declare a global variable for the witness table.
1619+
m_writer->emit("extern \"C\" { SLANG_PRELUDE_SHARED_LIB_EXPORT extern ");
16201620
emitSimpleType(interfaceType);
16211621
m_writer->emit(" ");
16221622
m_writer->emit(getName(witnessTable));
1623-
m_writer->emit(";\n");
1623+
m_writer->emit("; }\n");
16241624

16251625
// The actual definition of this witness table global variable
16261626
// is deferred until the entire `Context` class is emitted, so
@@ -1636,6 +1636,9 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
16361636
{
16371637
auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
16381638
List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable);
1639+
m_writer->emit("extern \"C\"\n{\n");
1640+
m_writer->indent();
1641+
m_writer->emit("SLANG_PRELUDE_SHARED_LIB_EXPORT\n");
16391642
emitSimpleType(interfaceType);
16401643
m_writer->emit(" ");
16411644
m_writer->emit(getName(witnessTable));
@@ -1679,6 +1682,8 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
16791682
}
16801683
m_writer->dedent();
16811684
m_writer->emit("\n};\n");
1685+
m_writer->dedent();
1686+
m_writer->emit("\n}\n");
16821687
}
16831688
}
16841689

@@ -1698,18 +1703,18 @@ void CPPSourceEmitter::emitInterface(IRInterfaceType* interfaceType)
16981703
void CPPSourceEmitter::emitRTTIObject(IRRTTIObject* rttiObject)
16991704
{
17001705
// Declare the type info object as `extern "C"` first.
1701-
m_writer->emit("extern \"C\" TypeInfo ");
1706+
m_writer->emit("extern \"C\"{ SLANG_PRELUDE_SHARED_LIB_EXPORT extern TypeInfo ");
17021707
m_writer->emit(getName(rttiObject));
1703-
m_writer->emit(";\n");
1708+
m_writer->emit("; }\n");
17041709

17051710
// Now actually define the object.
1706-
m_writer->emit("TypeInfo ");
1711+
m_writer->emit("extern \"C\" { SLANG_PRELUDE_SHARED_LIB_EXPORT TypeInfo ");
17071712
m_writer->emit(getName(rttiObject));
17081713
m_writer->emit(" = {");
17091714
auto typeSizeDecoration = rttiObject->findDecoration<IRRTTITypeSizeDecoration>();
17101715
SLANG_ASSERT(typeSizeDecoration);
17111716
m_writer->emit(typeSizeDecoration->getTypeSize());
1712-
m_writer->emit("};\n");
1717+
m_writer->emit("}; }\n");
17131718
}
17141719

17151720

source/slang/slang-ir-generics-lowering-context.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ namespace Slang
7575
String rttiObjName = exportDecoration->getMangledName();
7676
builder->addExportDecoration(result, rttiObjName.getUnownedSlice());
7777
}
78+
// Make sure the RTTI object for a public struct type has public visiblity.
79+
if (typeInst->findDecoration<IRPublicDecoration>())
80+
{
81+
builder->addPublicDecoration(result);
82+
builder->addKeepAliveDecoration(result);
83+
}
7884
mapTypeToRTTIObject[typeInst] = result;
7985
return result;
8086
}

source/slang/slang-ir-insts.h

+6
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ IR_SIMPLE_DECORATION(EarlyDepthStencilDecoration)
259259
IR_SIMPLE_DECORATION(GloballyCoherentDecoration)
260260
IR_SIMPLE_DECORATION(PreciseDecoration)
261261
IR_SIMPLE_DECORATION(PublicDecoration)
262+
IR_SIMPLE_DECORATION(KeepAliveDecoration)
262263

263264

264265
struct IROutputControlPointsDecoration : IRDecoration
@@ -2431,6 +2432,11 @@ struct IRBuilder
24312432
addDecoration(value, kIROp_KeepAliveDecoration);
24322433
}
24332434

2435+
void addPublicDecoration(IRInst* value)
2436+
{
2437+
addDecoration(value, kIROp_PublicDecoration);
2438+
}
2439+
24342440
/// Add a decoration that indicates that the given `inst` depends on the given `dependency`.
24352441
///
24362442
/// This decoration can be used to ensure that a value that an instruction

source/slang/slang-ir-lower-generic-type.cpp

+27-3
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,22 @@ namespace Slang
1010
{
1111
// This is a subpass of generics lowering IR transformation.
1212
// This pass lowers all generic/polymorphic types into IRAnyValueType.
13-
struct GenericVarLoweringContext
13+
struct GenericTypeLoweringContext
1414
{
1515
SharedGenericsLoweringContext* sharedContext;
1616

1717
void processInst(IRInst* inst)
1818
{
19-
// If inst is a type itself, keep its type.
19+
// Ensure public struct types has RTTI object defined.
20+
if (as<IRStructType>(inst))
21+
{
22+
if (inst->findDecoration<IRPublicDecoration>())
23+
{
24+
sharedContext->maybeEmitRTTIObject(inst);
25+
}
26+
}
27+
28+
// Don't modify type insts themselves.
2029
if (as<IRType>(inst))
2130
return;
2231

@@ -28,6 +37,21 @@ namespace Slang
2837
auto newType = sharedContext->lowerType(builder, inst->getFullType());
2938
if (newType != inst->getFullType())
3039
inst->setFullType((IRType*)newType);
40+
41+
switch (inst->op)
42+
{
43+
default:
44+
break;
45+
case kIROp_StructField:
46+
{
47+
// Translate the struct field type.
48+
auto structField = static_cast<IRStructField*>(inst);
49+
auto loweredFieldType =
50+
sharedContext->lowerType(builder, structField->getFieldType());
51+
structField->setOperand(1, loweredFieldType);
52+
}
53+
break;
54+
}
3155
}
3256

3357
void processModule()
@@ -62,7 +86,7 @@ namespace Slang
6286

6387
void lowerGenericType(SharedGenericsLoweringContext* sharedContext)
6488
{
65-
GenericVarLoweringContext context;
89+
GenericTypeLoweringContext context;
6690
context.sharedContext = sharedContext;
6791
context.processModule();
6892
}

source/slang/slang-lower-to-ir.cpp

+30-1
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,11 @@ static void addLinkageDecoration(
10381038
{
10391039
builder->addExportDecoration(inst, mangledName);
10401040
}
1041+
if (decl->findModifier<PublicModifier>())
1042+
{
1043+
builder->addPublicDecoration(inst);
1044+
builder->addKeepAliveDecoration(inst);
1045+
}
10411046
}
10421047

10431048
static void addLinkageDecoration(
@@ -5161,6 +5166,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
51615166
return LoweredValInfo::simple(inst);
51625167
}
51635168

5169+
bool isPublicType(Type* type)
5170+
{
5171+
if (auto declRefType = as<DeclRefType>(type))
5172+
{
5173+
if (declRefType->declRef.getDecl()->findModifier<PublicModifier>())
5174+
return true;
5175+
}
5176+
return false;
5177+
}
51645178

51655179
void lowerWitnessTable(
51665180
IRGenContext* subContext,
@@ -5211,6 +5225,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
52115225
astReqWitnessTable->witnessedType,
52125226
astReqWitnessTable->baseType);
52135227
subBuilder->addExportDecoration(irSatisfyingWitnessTable, mangledName.getUnownedSlice());
5228+
if (isPublicType(astReqWitnessTable->witnessedType))
5229+
{
5230+
subBuilder->addPublicDecoration(irSatisfyingWitnessTable);
5231+
subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable);
5232+
}
5233+
52145234
// Recursively lower the sub-table.
52155235
lowerWitnessTable(
52165236
subContext,
@@ -5327,6 +5347,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
53275347
// Create the IR-level witness table
53285348
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType);
53295349
addLinkageDecoration(context, irWitnessTable, inheritanceDecl, mangledName.getUnownedSlice());
5350+
if (parentDecl->findModifier<PublicModifier>())
5351+
{
5352+
subBuilder->addPublicDecoration(irWitnessTable);
5353+
subBuilder->addKeepAliveDecoration(irWitnessTable);
5354+
}
53305355

53315356
// Register the value now, rather than later, to avoid any possible infinite recursion.
53325357
setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable));
@@ -6154,6 +6179,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
61546179
SLANG_UNREACHABLE("associatedtype should have been handled by visitAssocTypeDecl.");
61556180
}
61566181

6182+
bool isPublicType = decl->findModifier<PublicModifier>() != nullptr;
6183+
61576184
// Given a declaration of a type, we need to make sure
61586185
// to output "witness tables" for any interfaces this
61596186
// type has declared conformance to.
@@ -6171,18 +6198,20 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
61716198

61726199
// Emit any generics that should wrap the actual type.
61736200
auto outerGeneric = emitOuterGenerics(subContext, decl, decl);
6174-
61756201

61766202
IRStructType* irStruct = subBuilder->createStructType();
61776203
addNameHint(context, irStruct, decl);
61786204
addLinkageDecoration(context, irStruct, decl);
6205+
61796206
subBuilder->setInsertInto(irStruct);
61806207

61816208
// A `struct` that inherits from another `struct` must start
61826209
// with a member for the direct base type.
61836210
//
61846211
for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() )
61856212
{
6213+
if (isPublicType)
6214+
ensureDecl(context, inheritanceDecl);
61866215
auto superType = inheritanceDecl->base;
61876216
if(auto superDeclRefType = as<DeclRefType>(superType))
61886217
{

source/slang/slang-options.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,10 @@ struct OptionsParser
563563
{
564564
requestImpl->getBackEndReq()->disableSpecialization = true;
565565
}
566+
else if (argStr == "-disable-dynamic-dispatch")
567+
{
568+
requestImpl->getBackEndReq()->disableDynamicDispatch = true;
569+
}
566570
else if (argStr == "-verbose-paths")
567571
{
568572
requestImpl->getSink()->setFlag(DiagnosticSink::Flag::VerbosePath);

source/slang/slang-type-layout.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -3613,6 +3613,22 @@ static TypeLayoutResult _createTypeLayout(
36133613
typeLayout->type = type;
36143614
typeLayout->rules = rules;
36153615

3616+
if (isCPUTarget(context.targetReq) || isCUDATarget(context.targetReq))
3617+
{
3618+
LayoutSize fixedSize = 16;
3619+
if (auto anyValueAttr =
3620+
interfaceDeclRef.getDecl()->findModifier<AnyValueSizeAttribute>())
3621+
{
3622+
fixedSize += anyValueAttr->size;
3623+
}
3624+
else
3625+
{
3626+
// The interface type does not have an `[anyValueSize]` attribute,
3627+
// assume a default of 8 bytes.
3628+
fixedSize += 8;
3629+
}
3630+
typeLayout->addResourceUsage(LayoutResourceKind::Uniform, fixedSize);
3631+
}
36163632
typeLayout->addResourceUsage(LayoutResourceKind::ExistentialTypeParam, 1);
36173633
typeLayout->addResourceUsage(LayoutResourceKind::ExistentialObjectParam, 1);
36183634

source/slang/slang.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,20 @@ SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout(
623623
return asExternal(typeLayout);
624624
}
625625

626+
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeRTTIMangledName(
627+
slang::TypeReflection* type, ISlangBlob** outNameBlob)
628+
{
629+
auto internalType = asInternal(type);
630+
if (auto declRefType = as<DeclRefType>(internalType))
631+
{
632+
auto name = getMangledName(internalType->getASTBuilder(), declRefType->declRef);
633+
Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name);
634+
*outNameBlob = blob.detach();
635+
return SLANG_OK;
636+
}
637+
return SLANG_FAIL;
638+
}
639+
626640
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessMangledName(
627641
slang::TypeReflection* type, slang::TypeReflection* interfaceType, ISlangBlob** outNameBlob)
628642
{
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Test using interface typed shader parameters with dynamic dispatch.
2+
3+
//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization
4+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization
5+
6+
[anyValueSize(8)]
7+
interface IInterface
8+
{
9+
int run(int input);
10+
}
11+
12+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
13+
RWStructuredBuffer<int> gOutputBuffer;
14+
15+
//TEST_INPUT:cbuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb
16+
ConstantBuffer<IInterface> gCb;
17+
18+
[numthreads(4, 1, 1)]
19+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
20+
{
21+
let tid = dispatchThreadID.x;
22+
23+
let inputVal : int = tid;
24+
let outputVal = gCb.run(inputVal);
25+
26+
gOutputBuffer[tid] = outputVal;
27+
}
28+
29+
// No type input for dynamic dispatch //TEST_INPUT: globalExistentialType MyImpl
30+
// Type must be marked `public` to ensure it is visible in the generated DLL.
31+
public struct MyImpl : IInterface
32+
{
33+
int val;
34+
int run(int input)
35+
{
36+
return input + val;
37+
}
38+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1
2+
2
3+
3
4+
4

0 commit comments

Comments
 (0)