Skip to content

Commit aedf617

Browse files
author
Tim Foley
authored
Initial support for dynamic dispatch using "tagged union" types (shader-slang#772)
* Initial support for dynamic dispatch using "tagged union" types Suppose a user declares some generic shader code, like the following: ```hlsl interface IFrobnicator { ... } type_param T : IFrobincator; ParameterBlock<T : IFrobnicator> gFrobnicator; ... gFrobincator.frobnicate(value); ``` and then they have some concrete implementations of the required interface: ```hlsl struct A : IFrobnicator { ... } struct B : IFrobnicator { ... } ``` The current Slang compiler allows them to generate distinct compiled kernels for the case of `T=A` and the case of `T=B`. This means that the decision of which implementation to use must be made at or before the time when a shader gets bound in the application. This change adds a new ability where the Slang compiler can generate code to handle the case where `T` might be *either* `A` or `B`, and which case it is will be determined dynamically at runtime. This means a single compiled kernel can handle both cases, and the decision about which code path to run can be made any time before the shader executes. This new option is supported by defining a *tagged union* type. Via the API, the user specifies that `T` should be specialized to `__TaggedUnion(A,B)` (the double underscore indicates that this is an experimental and unsupported feature at present). We refer to the types `A` and `B` here as the "case" types of the tagged union. Conceptually, the compiler synthesizes a type something like: ```hlsl struct TU { union { A a; B b; } payload; uint tag; } ``` The user can then allocate a constant buffer to hold their tagged union type, and when they pick a concrete type to use (say `B`), they fill in the first `sizeof(B)` bytes of their buffer with data describing a `B` instance, and then set the `tag` field to the appopriate 0-based index of the case type they chose (in this case the `B` case gets the tag value `1`). Actually implementing tagged unions takes a few main steps: * Type parsing was extended to special-case `__TaggedUnion` as a contextual keyword. This is really only intended to be used when parsing types from the API or command-line, and Bad Things are likely to happen if a user ever puts it directly in their code. Eventually construction of tagged unions should be an API feature and not part of the language syntax. * Semantic checking was extended to recognize that a tagged union like `__TaggedUnion(A,B)` shoud support an interface like `IFrobnicator` whenever all of the case types suport it, as long as the interface is "safe" for use with tagged unions (which means it doesn't use a few of the advancd langauge features like associated types). * The IR was extended with instructions to represent tagged union types and to extract their tag and the payload for the different cases as needed. * IR generation was extended to synthesize implementations of interface methods for any interface that a tagged union needs to support. Right now the implementation is simplistic and only handles simple method requirements, which it does by emitting a `switch` instruction to pick between the different cases. * A new IR pass was introduced to "desugar" any tagged union types used in the code. The downstream HLSL and GLSL compilers don't support `union`s, so we have to instead emit a tagged union as a "bag of bits" and implement loading the data for particular cases from it manually. * Final code emit mostly Just Works after the above steps, but we had to introduce an explicit IR instruction for bit-casting to handle the output of the desugaring pass. There are a bunch of gaps and caveats in this implementation, but that seems reasonable for something that is an experimental feature. The various `TODO` comments and assertion failures in unimplemented cases are intended, so that this work can be checked in even if it isn't feature-complete. * fixup: missing files * fixup: typos
1 parent 8e47a38 commit aedf617

28 files changed

+1999
-27
lines changed

source/slang/check.cpp

+158-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
namespace Slang
1111
{
12+
RefPtr<TypeType> getTypeType(
13+
Type* type);
1214

1315
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
1416
bool isEffectivelyStatic(
@@ -677,7 +679,7 @@ namespace Slang
677679
auto baseExprType = baseExpr->type.type;
678680
RefPtr<SharedTypeExpr> baseTypeExpr = new SharedTypeExpr();
679681
baseTypeExpr->base.type = baseExprType;
680-
baseTypeExpr->type = new TypeType(baseExprType);
682+
baseTypeExpr->type.type = getTypeType(baseExprType);
681683

682684
auto expr = new StaticMemberExpr();
683685
expr->loc = loc;
@@ -2071,9 +2073,7 @@ namespace Slang
20712073
{
20722074
RefPtr<TypeCastExpr> castExpr = createImplicitCastExpr();
20732075

2074-
auto typeType = new TypeType();
2075-
typeType->setSession(getSession());
2076-
typeType->type = toType;
2076+
auto typeType = getTypeType(toType);
20772077

20782078
auto typeExpr = new SharedTypeExpr();
20792079
typeExpr->type.type = typeType;
@@ -5547,6 +5547,60 @@ namespace Slang
55475547
return witness;
55485548
}
55495549

5550+
/// Is the given interface one that a tagged-union type can conform to?
5551+
///
5552+
/// If a tagged union type `__TaggedUnion(A,B)` is going to be
5553+
/// plugged in for a type parameter `T : IFoo` then we need to
5554+
/// be sure that the interface `IFoo` doesn't have anything
5555+
/// that could lead to unsafe/unsound behavior. This function
5556+
/// checks that all the requirements on the interfaceare safe ones.
5557+
///
5558+
bool isInterfaceSafeForTaggedUnion(
5559+
DeclRef<InterfaceDecl> interfaceDeclRef)
5560+
{
5561+
for( auto memberDeclRef : getMembers(interfaceDeclRef) )
5562+
{
5563+
if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef))
5564+
return false;
5565+
}
5566+
5567+
return true;
5568+
}
5569+
5570+
/// Is the given interface requirement one that a tagged-union type can satisfy?
5571+
///
5572+
/// Unsafe requirements include any `static` requirements,
5573+
/// any associated types, and also any requirements that make
5574+
/// use of the `This` type (once we support it).
5575+
///
5576+
bool isInterfaceRequirementSafeForTaggedUnion(
5577+
DeclRef<InterfaceDecl> interfaceDeclRef,
5578+
DeclRef<Decl> requirementDeclRef)
5579+
{
5580+
if(auto callableDeclRef = requirementDeclRef.As<CallableDecl>())
5581+
{
5582+
// A `static` method requirement can't be satisfied by a
5583+
// tagged union, because there is no tag to dispatch on.
5584+
//
5585+
if(requirementDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
5586+
return false;
5587+
5588+
// TODO: We will eventually want to check that any callable
5589+
// requirements do not use the `This` type or any associated
5590+
// types in ways that could lead to errors.
5591+
//
5592+
// For now we are disallowing interfaces that have associated
5593+
// types completely, and we haven't implemented the `This`
5594+
// type, so we should be safe.
5595+
5596+
return true;
5597+
}
5598+
else
5599+
{
5600+
return false;
5601+
}
5602+
}
5603+
55505604
bool doesTypeConformToInterfaceImpl(
55515605
RefPtr<Type> originalType,
55525606
RefPtr<Type> type,
@@ -5661,6 +5715,69 @@ namespace Slang
56615715
}
56625716
}
56635717
}
5718+
else if(auto taggedUnionType = type->As<TaggedUnionType>())
5719+
{
5720+
// A tagged union type conforms to an interface if all of
5721+
// the constituent types in the tagged union conform.
5722+
//
5723+
// We will iterate over the "case" types in the tagged
5724+
// union, and check if they conform to the interface.
5725+
// Along the way we will collect the conformance witness
5726+
// values *if* we are being asked to produce a witness
5727+
// value for the tagged union itself (that is, if
5728+
// `outWitness` is non-null).
5729+
//
5730+
List<RefPtr<Val>> caseWitnesses;
5731+
for(auto caseType : taggedUnionType->caseTypes)
5732+
{
5733+
RefPtr<Val> caseWitness;
5734+
5735+
if(!doesTypeConformToInterfaceImpl(
5736+
caseType,
5737+
caseType,
5738+
interfaceDeclRef,
5739+
outWitness ? &caseWitness : nullptr,
5740+
nullptr))
5741+
{
5742+
return false;
5743+
}
5744+
5745+
if(outWitness)
5746+
{
5747+
caseWitnesses.Add(caseWitness);
5748+
}
5749+
}
5750+
5751+
// We also need to validate the requirements on
5752+
// the interface to make sure that they are suitable for
5753+
// use with a tagged-union type.
5754+
//
5755+
// For example, if the interface includes a `static` method
5756+
// (which can therefore be called without a particular instance),
5757+
// then we wouldn't know what implementation of that method
5758+
// to use because there is no tag value to dispatch on.
5759+
//
5760+
// We will start out being conservative about what we accept
5761+
// here, just to keep things simple.
5762+
//
5763+
if(!isInterfaceSafeForTaggedUnion(interfaceDeclRef))
5764+
return false;
5765+
5766+
// If we reach this point then we have a concrete
5767+
// witness for each of the case types, and that is
5768+
// enough to build a witness for the tagged union.
5769+
//
5770+
if(outWitness)
5771+
{
5772+
RefPtr<TaggedUnionSubtypeWitness> taggedUnionWitness = new TaggedUnionSubtypeWitness();
5773+
taggedUnionWitness->sub = taggedUnionType;
5774+
taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef);
5775+
taggedUnionWitness->caseWitnesses.SwapWith(caseWitnesses);
5776+
5777+
*outWitness = taggedUnionWitness;
5778+
}
5779+
return true;
5780+
}
56645781

56655782
// default is failure
56665783
return false;
@@ -8090,6 +8207,23 @@ namespace Slang
80908207
return expr;
80918208
}
80928209

8210+
RefPtr<Expr> visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr)
8211+
{
8212+
// We have an expression of the form `__TaggedUnion(A, B, ...)`
8213+
// which will evaluate to a tagged-union type over `A`, `B`, etc.
8214+
//
8215+
RefPtr<TaggedUnionType> type = new TaggedUnionType();
8216+
expr->type = QualType(getTypeType(type));
8217+
8218+
for( auto& caseTypeExpr : expr->caseTypes )
8219+
{
8220+
caseTypeExpr = CheckProperType(caseTypeExpr);
8221+
type->caseTypes.Add(caseTypeExpr.type);
8222+
}
8223+
8224+
return expr;
8225+
}
8226+
80938227

80948228

80958229

@@ -9039,7 +9173,7 @@ namespace Slang
90399173
scopesToTry.Add(module->moduleDecl->scope);
90409174

90419175
List<RefPtr<Type>> globalGenericArgs;
9042-
for (auto name : entryPoint->genericParameterTypeNames)
9176+
for (auto name : entryPoint->genericArgStrings)
90439177
{
90449178
// parse type name
90459179
RefPtr<Type> type;
@@ -9059,6 +9193,25 @@ namespace Slang
90599193
return;
90609194
}
90619195

9196+
// The following is a bit of a hack.
9197+
//
9198+
// Back-end code generation relies on us having computed layouts for all tagged
9199+
// unions that end up being used in the code, which means we need a way to find
9200+
// all such types that get used in a module (and the stuff it imports).
9201+
//
9202+
// The Right Way to handle this would probably be to have each `ModuleDecl` track
9203+
// any tagged union types that get created in the context of that module, and
9204+
// then combine those lists later.
9205+
//
9206+
// For now we are assuming a tagged union type only comes into existence
9207+
// as a (top-level) argument for a generic type parameter, so that we
9208+
// can check for them here and cache them on the entry point request.
9209+
//
9210+
if( auto taggedUnionType = type->As<TaggedUnionType>() )
9211+
{
9212+
entryPoint->taggedUnionTypes.Add(taggedUnionType);
9213+
}
9214+
90629215
globalGenericArgs.Add(type);
90639216
}
90649217

source/slang/compiler.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,8 @@ namespace Slang
122122
// The name of the entry point function (e.g., `main`)
123123
Name* name;
124124

125-
// The type names we want to substitute into the
126-
// global generic type parameters
127-
List<String> genericParameterTypeNames;
125+
/// Source code for the generic arguments to use for the generic parameters of the entry point.
126+
List<String> genericArgStrings;
128127

129128
// The profile that the entry point will be compiled for
130129
// (this is a combination of the target stage, and also
@@ -156,6 +155,9 @@ namespace Slang
156155
RefPtr<FuncDecl> decl;
157156

158157
RefPtr<Substitutions> globalGenericSubst;
158+
159+
/// Any tagged union types that were referenced by the generic arguments of the entry point.
160+
List<RefPtr<TaggedUnionType>> taggedUnionTypes;
159161
};
160162

161163
enum class PassThroughMode : SlangPassThrough

source/slang/diagnostic-defs.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is onl
347347
DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration")
348348

349349
DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.")
350-
DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.")
350+
DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.")
351351

352352
DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself")
353353
DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')")

source/slang/emit.cpp

+96-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "ir-specialize.h"
1313
#include "ir-specialize-resources.h"
1414
#include "ir-ssa.h"
15+
#include "ir-union.h"
1516
#include "ir-validate.h"
1617
#include "legalize-types.h"
1718
#include "lower-to-ir.h"
@@ -3922,13 +3923,100 @@ struct EmitVisitor
39223923
}
39233924
break;
39243925

3926+
case kIROp_BitCast:
3927+
{
3928+
// TODO: we can simplify the logic for arbitrary bitcasts
3929+
// by always bitcasting the source to a `uint*` type (if it
3930+
// isn't already) and then bitcasting that to the destination
3931+
// type (if it isn't already `uint*`.
3932+
//
3933+
// For now we are assuming the source type is *already*
3934+
// a `uint*` type of the appropriate size.
3935+
//
3936+
// auto fromType = extractBaseType(inst->getOperand(0)->getDataType());
3937+
auto toType = extractBaseType(inst->getDataType());
3938+
switch(getTarget(ctx))
3939+
{
3940+
case CodeGenTarget::GLSL:
3941+
switch(toType)
3942+
{
3943+
default:
3944+
emit("/* unhandled */");
3945+
break;
3946+
3947+
case BaseType::UInt:
3948+
break;
3949+
3950+
case BaseType::Int:
3951+
emitIRType(ctx, inst->getDataType());
3952+
break;
3953+
3954+
case BaseType::Float:
3955+
emit("uintBitsToFloat(");
3956+
break;
3957+
}
3958+
break;
3959+
3960+
case CodeGenTarget::HLSL:
3961+
switch(toType)
3962+
{
3963+
default:
3964+
emit("/* unhandled */");
3965+
break;
3966+
3967+
case BaseType::UInt:
3968+
break;
3969+
case BaseType::Int:
3970+
emit("(");
3971+
emitIRType(ctx, inst->getDataType());
3972+
emit(")");
3973+
break;
3974+
case BaseType::Float:
3975+
emit("asfloat");
3976+
break;
3977+
}
3978+
break;
3979+
3980+
3981+
default:
3982+
SLANG_UNEXPECTED("unhandled codegen target");
3983+
break;
3984+
}
3985+
3986+
emit("(");
3987+
emitIROperand(ctx, inst->getOperand(0), mode, kEOp_General);
3988+
emit(")");
3989+
}
3990+
break;
3991+
39253992
default:
39263993
emit("/* unhandled */");
39273994
break;
39283995
}
39293996
maybeCloseParens(needClose);
39303997
}
39313998

3999+
BaseType extractBaseType(IRType* inType)
4000+
{
4001+
auto type = inType;
4002+
for(;;)
4003+
{
4004+
if(auto irBaseType = as<IRBasicType>(type))
4005+
{
4006+
return irBaseType->getBaseType();
4007+
}
4008+
else if(auto vecType = as<IRVectorType>(type))
4009+
{
4010+
type = vecType->getElementType();
4011+
continue;
4012+
}
4013+
else
4014+
{
4015+
return BaseType::Void;
4016+
}
4017+
}
4018+
}
4019+
39324020
void emitIRInst(
39334021
EmitContext* ctx,
39344022
IRInst* inst,
@@ -6565,6 +6653,14 @@ String emitEntryPoint(
65656653
#endif
65666654
validateIRModuleIfEnabled(compileRequest, irModule);
65676655

6656+
// Desguar any union types, since these will be illegal on
6657+
// various targets.
6658+
//
6659+
desugarUnionTypes(irModule);
6660+
#if 0
6661+
dumpIRIfEnabled(compileRequest, irModule, "UNIONS DESUGARED");
6662+
#endif
6663+
validateIRModuleIfEnabled(compileRequest, irModule);
65686664

65696665

65706666
// Any code that makes use of existential (interface) types
@@ -6595,8 +6691,6 @@ String emitEntryPoint(
65956691
//
65966692
specializeGenerics(irModule);
65976693

6598-
6599-
66006694
// Debugging code for IR transformations...
66016695
#if 0
66026696
dumpIRIfEnabled(compileRequest, irModule, "SPECIALIZED");

source/slang/expr-defs.h

+11
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,14 @@ RAW(
193193
DeclRef<VarDeclBase> declRef;
194194
)
195195
END_SYNTAX_CLASS()
196+
197+
/// A type expression of the form `__TaggedUnion(A, ...)`.
198+
///
199+
/// An expression of this form will resolve to a `TaggedUnionType`
200+
/// when checked.
201+
///
202+
SYNTAX_CLASS(TaggedUnionTypeExpr, Expr)
203+
RAW(
204+
List<TypeExp> caseTypes;
205+
)
206+
END_SYNTAX_CLASS()

0 commit comments

Comments
 (0)