Skip to content

Commit 71e35b6

Browse files
author
Tim Foley
authored
Changes required for application adoption of interface-type parameters (shader-slang#963)
* A few changes required for application adoption of interface-type parameters There are a few small changes here that are all related in that they arose from trying to integrate support for specialization via global interface-type shader parameters into a real application. Allow querying the "pending" layout via reflection API ------------------------------------------------------ The naming here isn't ideal, and could probably use a round of "bikeshedding" to arrive at something better, but the basic idea is that when you have a type like: ``` struct MyStuff { int a; IFoo foo; int b; } ``` the fields `a` and `b` get allocated space directly in the "primary" layout for `MyStuff` (at offsets 0 and 4, with `sizeof(MyStuff) == 8`), but the `foo` field can't be allocated space until we know what concrete type will get plugged in there. If we have a concrete type in mind: ``` struct Bar : IFoo { int bar; } ``` then we can know how much space the `foo` field will take up, but we still can't allocate it space directly in `MyStuff`, because we already decided that `sizeof(MyStuff) == 8`. Now imagine we place some `MyStuff` values into constant buffers: ``` cbuffer X { MyStuff x; } cbuffer Y { MyStuff y; float4 z; } ``` In each case we know that we want to place the `MyStuff::foo` field at the end of the containing constant buffer so that it doesn't disrupt the layout of the existing fields. But that means that the offset of `MyStuff::foo` relative to the start of the `MyStuff` isn't fixed, because of unrelated fields like `z` that need to get in between. In our layout code, we handle this by having a notion of a "pending" layout. Once we know how `MyStuff::foo` will be specialized, we can compute both a "primary" and a "pending" layout for `MyStuff`, which basically treats it as if it were two distinct types: ``` struct MyStuff_Primary { int a; int b; } struct MyStuff_Pending { Bar foo; } ``` Layout for an aggregate type like the `X` or `Y` constant buffer then proceeds by computing an aggregate primary layout and an aggregate pending layout, and then finally a constant buffer or parameter block "flushes" all or part of the pending data by appending it to the primary data to get the final layout. What all this means is that a type like `MyStuff` will have two different layouts (a default one for the primary data and a "pending" one for any specialized interface-type fields), and a variable like `Y::y` will also have two variable layouts that specify offsets (one set of offsets for its primary part, and one set of offsets for its pending part). In order to handle interface-type fields with these layout rules, an application needs a way to query the "pending" part of a type or variable layout, which luckily gives it back just another type/variable layout. The API change here is minimal, although actually exploiting the new API correctly in application code could prove challenging. Allow creating of explicitly specialized types ---------------------------------------------- This feature isn't actually implemented all the way through the compiler (I just needed enough to make the API calls go through), but I've added support for specializing a type that has interface-type fields through the reflection API. This maps to an `ExistentialSpecializedType` in the AST, and I'm lowering it to the IR as a `BindExistentialsType`, although that isn't 100% correct for the future. This feature will require a future PR to actually flesh out the implementation work, but I'll wait until that is the sticking point on the application side before I do that. Introduce a tiny `Hasher` abstraction ------------------------------------- While implementing all the boilerplate for a new `Type` subclass (we really need to reduce that work...), I got fed up with how we do hash-code computation and introduced a small utility `Hasher` type that is intended to wrap up the idiom of combining hashes. For now this isn't a major change, but in the future I'd like to expand on the design a bit to clean up some of the warts around how we handle hashing: * The `Hasher` implementation can and should switch from maintaining a single `HashCode` as its state to something that contains a more complete state (larger than the hash code) and just hashes new bytes into that state as it goes. This should make it possible to implement a `Hasher` for more serious hash functions, whether MD5, CityHash, or whatever we decide is good default. * Things that are hashable shouldn't have a `getHashCode()` method, but instead should have something like a `hashInto(Hasher&)` method. This change would have the dual benefits that (1) a composite type can easily hash all the fields that contribute to its identity into the hasher with minimal fuss/boilerplate, and (2) the hashes for composite types will be of higher quality because they can exploit all the bits of the hasher's state to combine the fields, instead of restricting each sub-field to just the bits in a hash code. We should be able to incrementally improve the quality of our design there over future changes, but for now it probably isn't a critical priority. Fixes for legalization of existential types ------------------------------------------- There were some missing cases in the handling of type legalization, such that a global interface-type shader parameter that got specialized to a type that contains *only* resource-type fields would cause a crash in the legalization step. I added a test for this case, and then made `ir-legalize-types.cpp` account for this case (the code to handle it ias a bit of a kludge, and shows that the `declareVars()` routine there is getting to a level of complexity that is worrying. * fixup: review feedback
1 parent 3b9994b commit 71e35b6

14 files changed

+431
-40
lines changed

slang.h

+39
Original file line numberDiff line numberDiff line change
@@ -1812,6 +1812,8 @@ extern "C"
18121812

18131813
SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLayout* type);
18141814

1815+
SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_getPendingDataTypeLayout(SlangReflectionTypeLayout* type);
1816+
18151817
// Variable Reflection
18161818

18171819
SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* var);
@@ -1843,6 +1845,9 @@ extern "C"
18431845
SLANG_API SlangStage spReflectionVariableLayout_getStage(
18441846
SlangReflectionVariableLayout* var);
18451847

1848+
1849+
SLANG_API SlangReflectionVariableLayout* spReflectionVariableLayout_getPendingDataLayout(SlangReflectionVariableLayout* var);
1850+
18461851
// Shader Parameter Reflection
18471852

18481853
typedef SlangReflectionVariableLayout SlangReflectionParameter;
@@ -1897,6 +1902,14 @@ extern "C"
18971902
SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection);
18981903
SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection);
18991904

1905+
SLANG_API SlangReflectionType* spReflection_specializeType(
1906+
SlangReflection* reflection,
1907+
SlangReflectionType* type,
1908+
SlangInt specializationArgCount,
1909+
SlangReflectionType* const* specializationArgs,
1910+
ISlangBlob** outDiagnostics);
1911+
1912+
19001913
#ifdef __cplusplus
19011914
}
19021915

@@ -2232,6 +2245,13 @@ namespace slang
22322245
return spReflectionTypeLayout_getGenericParamIndex(
22332246
(SlangReflectionTypeLayout*) this);
22342247
}
2248+
2249+
TypeLayoutReflection* getPendingDataTypeLayout()
2250+
{
2251+
return (TypeLayoutReflection*) spReflectionTypeLayout_getPendingDataTypeLayout(
2252+
(SlangReflectionTypeLayout*) this);
2253+
}
2254+
22352255
};
22362256

22372257
struct Modifier
@@ -2350,6 +2370,11 @@ namespace slang
23502370
{
23512371
return spReflectionVariableLayout_getStage((SlangReflectionVariableLayout*) this);
23522372
}
2373+
2374+
VariableLayoutReflection* getPendingDataLayout()
2375+
{
2376+
return (VariableLayoutReflection*) spReflectionVariableLayout_getPendingDataLayout((SlangReflectionVariableLayout*) this);
2377+
}
23532378
};
23542379

23552380
struct EntryPointReflection
@@ -2487,6 +2512,20 @@ namespace slang
24872512
(SlangReflection*) this,
24882513
name);
24892514
}
2515+
2516+
TypeReflection* specializeType(
2517+
TypeReflection* type,
2518+
SlangInt specializationArgCount,
2519+
TypeReflection* const* specializationArgs,
2520+
ISlangBlob** outDiagnostics)
2521+
{
2522+
return (TypeReflection*) spReflection_specializeType(
2523+
(SlangReflection*) this,
2524+
(SlangReflectionType*) type,
2525+
specializationArgCount,
2526+
(SlangReflectionType* const*) specializationArgs,
2527+
outDiagnostics);
2528+
}
24902529
};
24912530
}
24922531

source/core/hash.h

+28
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
namespace Slang
99
{
10+
typedef int HashCode;
11+
1012
inline int GetHashCode(double key)
1113
{
1214
return FloatAsInt((float)key);
@@ -120,6 +122,32 @@ namespace Slang
120122
{
121123
return (left * 16777619) ^ right;
122124
}
125+
126+
struct Hasher
127+
{
128+
public:
129+
Hasher() {}
130+
131+
template<typename T>
132+
void hashValue(T const& value)
133+
{
134+
m_hashCode = combineHash(m_hashCode, GetHashCode(value));
135+
}
136+
137+
template<typename T>
138+
void hashObject(T const& object)
139+
{
140+
m_hashCode = combineHash(m_hashCode, object->GetHashCode());
141+
}
142+
143+
HashCode getResult() const
144+
{
145+
return m_hashCode;
146+
}
147+
148+
private:
149+
HashCode m_hashCode = 0;
150+
};
123151
}
124152

125153
#endif

source/slang/check.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -10763,6 +10763,41 @@ static bool doesParameterMatch(
1076310763
Slang::_specializeExistentialTypeParams(getLinkage(), m_globalExistentialSlots, args, sink);
1076410764
}
1076510765

10766+
Type* Linkage::specializeType(
10767+
Type* unspecializedType,
10768+
Int argCount,
10769+
Type* const* args,
10770+
DiagnosticSink* sink)
10771+
{
10772+
// TODO: We should cache and re-use specialized types
10773+
// when the exact same arguments are provided again later.
10774+
10775+
SemanticsVisitor visitor(this, sink);
10776+
10777+
10778+
ExistentialTypeSlots slots;
10779+
_collectExistentialTypeParamsRec(slots, unspecializedType);
10780+
10781+
assert(slots.paramTypes.getCount() == argCount);
10782+
10783+
for( Int aa = 0; aa < argCount; ++aa )
10784+
{
10785+
auto argType = args[aa];
10786+
10787+
ExistentialTypeSlots::Arg arg;
10788+
arg.type = argType;
10789+
arg.witness = visitor.tryGetSubtypeWitness(argType, slots.paramTypes[aa]);
10790+
slots.args.add(arg);
10791+
}
10792+
10793+
RefPtr<ExistentialSpecializedType> specializedType = new ExistentialSpecializedType();
10794+
specializedType->baseType = unspecializedType;
10795+
specializedType->slots = slots;
10796+
10797+
m_specializedTypes.add(specializedType);
10798+
10799+
return specializedType;
10800+
}
1076610801

1076710802
/// Specialize a program to global generic arguments
1076810803
RefPtr<Program> createSpecializedProgram(

source/slang/compiler.h

+8-25
Original file line numberDiff line numberDiff line change
@@ -133,31 +133,6 @@ namespace Slang
133133
ComPtr<ISlangBlob> blob;
134134
};
135135

136-
/// Collects information about existential type parameters and their arguments.
137-
struct ExistentialTypeSlots
138-
{
139-
/// For each type parameter, holds the interface/existential type that constrains it.
140-
List<RefPtr<Type>> paramTypes;
141-
142-
/// An argument for an existential type parameter.
143-
///
144-
/// Comprises a concrete type and a witness for its conformance to the desired
145-
/// interface/existential type for the corresponding parameter.
146-
///
147-
struct Arg
148-
{
149-
RefPtr<Type> type;
150-
RefPtr<Val> witness;
151-
};
152-
153-
/// Any arguments provided for the existential type parameters.
154-
///
155-
/// It is possible for `args` to be empty even if `paramTypes` is non-empty;
156-
/// that situation represents an unspecialized program or entry point.
157-
///
158-
List<Arg> args;
159-
};
160-
161136
/// Information collected about global or entry-point shader parameters
162137
struct ShaderParamInfo
163138
{
@@ -665,6 +640,12 @@ namespace Slang
665640

666641
RefPtr<Expr> parseTypeString(String typeStr, RefPtr<Scope> scope);
667642

643+
Type* specializeType(
644+
Type* unspecializedType,
645+
Int argCount,
646+
Type* const* args,
647+
DiagnosticSink* sink);
648+
668649
/// Add a mew target amd return its index.
669650
UInt addTarget(
670651
CodeGenTarget target);
@@ -754,6 +735,8 @@ namespace Slang
754735

755736
/// Is the given module in the middle of being imported?
756737
bool isBeingImported(Module* module);
738+
739+
List<RefPtr<Type>> m_specializedTypes;
757740
};
758741

759742
/// Shared functionality between front- and back-end compile requests.

source/slang/diagnostics.h

+2
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ namespace Slang
227227
/// During propagation of an exception for an internal
228228
/// error, note that this source location was involved
229229
void noteInternalErrorLoc(SourceLoc const& loc);
230+
231+
SlangResult getBlobIfNeeded(ISlangBlob** outBlob);
230232
};
231233

232234
/// An `ISlangWriter` that writes directly to a diagnostic sink.

source/slang/ir-legalize-types.cpp

+39-15
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ static LegalVal declareVars(
127127
LegalVarChain const& varChain,
128128
UnownedStringSlice nameHint,
129129
IRInst* leafVar,
130-
IRGlobalNameInfo* globalNameInfo);
130+
IRGlobalNameInfo* globalNameInfo,
131+
bool isSpecial);
131132

132133
/// Unwrap a value with flavor `wrappedBuffer`
133134
///
@@ -1266,9 +1267,10 @@ static LegalVal legalizeLocalVar(
12661267
IRVar* irLocalVar)
12671268
{
12681269
// Legalize the type for the variable's value
1270+
auto originalValueType = irLocalVar->getDataType()->getValueType();
12691271
auto legalValueType = legalizeType(
12701272
context,
1271-
irLocalVar->getDataType()->getValueType());
1273+
originalValueType);
12721274

12731275
auto originalRate = irLocalVar->getRate();
12741276

@@ -1311,7 +1313,7 @@ static LegalVal legalizeLocalVar(
13111313

13121314
UnownedStringSlice nameHint = findNameHint(irLocalVar);
13131315
context->builder->setInsertBefore(irLocalVar);
1314-
LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nameHint, irLocalVar, nullptr);
1316+
LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nameHint, irLocalVar, nullptr, context->isSpecialType(originalValueType));
13151317

13161318
// Remove the old local var.
13171319
irLocalVar->removeFromParent();
@@ -1345,7 +1347,7 @@ static LegalVal legalizeParam(
13451347
UnownedStringSlice nameHint = findNameHint(originalParam);
13461348

13471349
context->builder->setInsertBefore(originalParam);
1348-
auto newVal = declareVars(context, kIROp_Param, legalParamType, nullptr, LegalVarChain(), nameHint, originalParam, nullptr);
1350+
auto newVal = declareVars(context, kIROp_Param, legalParamType, nullptr, LegalVarChain(), nameHint, originalParam, nullptr, context->isSpecialType(originalParam->getDataType()));
13491351

13501352
originalParam->removeFromParent();
13511353
context->replacedInstructions.add(originalParam);
@@ -2219,12 +2221,31 @@ static LegalVal declareVars(
22192221
IRTypeLegalizationContext* context,
22202222
IROp op,
22212223
LegalType type,
2222-
TypeLayout* typeLayout,
2223-
LegalVarChain const& varChain,
2224+
TypeLayout* inTypeLayout,
2225+
LegalVarChain const& inVarChain,
22242226
UnownedStringSlice nameHint,
22252227
IRInst* leafVar,
2226-
IRGlobalNameInfo* globalNameInfo)
2228+
IRGlobalNameInfo* globalNameInfo,
2229+
bool isSpecial)
22272230
{
2231+
LegalVarChain varChain = inVarChain;
2232+
TypeLayout* typeLayout = inTypeLayout;
2233+
if( isSpecial )
2234+
{
2235+
if( varChain.pendingChain )
2236+
{
2237+
varChain.primaryChain = varChain.pendingChain;
2238+
varChain.pendingChain = nullptr;
2239+
}
2240+
if( typeLayout )
2241+
{
2242+
if( auto pendingTypeLayout = typeLayout->pendingDataTypeLayout )
2243+
{
2244+
typeLayout = pendingTypeLayout;
2245+
}
2246+
}
2247+
}
2248+
22282249
switch (type.flavor)
22292250
{
22302251
case LegalType::Flavor::none:
@@ -2247,16 +2268,17 @@ static LegalVal declareVars(
22472268
varChain,
22482269
nameHint,
22492270
leafVar,
2250-
globalNameInfo);
2271+
globalNameInfo,
2272+
isSpecial);
22512273
return LegalVal::implicitDeref(val);
22522274
}
22532275
break;
22542276

22552277
case LegalType::Flavor::pair:
22562278
{
22572279
auto pairType = type.getPair();
2258-
auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, nameHint, leafVar, globalNameInfo);
2259-
auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, nameHint, leafVar, globalNameInfo);
2280+
auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, false);
2281+
auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, true);
22602282
return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo);
22612283
}
22622284

@@ -2305,7 +2327,8 @@ static LegalVal declareVars(
23052327
newVarChain,
23062328
fieldNameHint,
23072329
ee.key,
2308-
globalNameInfo);
2330+
globalNameInfo,
2331+
true);
23092332

23102333
TuplePseudoVal::Element element;
23112334
element.key = ee.key;
@@ -2348,9 +2371,10 @@ static LegalVal legalizeGlobalVar(
23482371
IRGlobalVar* irGlobalVar)
23492372
{
23502373
// Legalize the type for the variable's value
2374+
auto originalValueType = irGlobalVar->getDataType()->getValueType();
23512375
auto legalValueType = legalizeType(
23522376
context,
2353-
irGlobalVar->getDataType()->getValueType());
2377+
originalValueType);
23542378

23552379
switch (legalValueType.flavor)
23562380
{
@@ -2373,7 +2397,7 @@ static LegalVal legalizeGlobalVar(
23732397

23742398
UnownedStringSlice nameHint = findNameHint(irGlobalVar);
23752399
context->builder->setInsertBefore(irGlobalVar);
2376-
LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalVar, &globalNameInfo);
2400+
LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalVar, &globalNameInfo, context->isSpecialType(originalValueType));
23772401

23782402
// Register the new value as the replacement for the old
23792403
registerLegalizedValue(context, irGlobalVar, newVal);
@@ -2417,7 +2441,7 @@ static LegalVal legalizeGlobalConstant(
24172441

24182442
UnownedStringSlice nameHint = findNameHint(irGlobalConstant);
24192443
context->builder->setInsertBefore(irGlobalConstant);
2420-
LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalConstant, &globalNameInfo);
2444+
LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalConstant, &globalNameInfo, context->isSpecialType(irGlobalConstant->getDataType()));
24212445

24222446
// Register the new value as the replacement for the old
24232447
registerLegalizedValue(context, irGlobalConstant, newVal);
@@ -2466,7 +2490,7 @@ static LegalVal legalizeGlobalParam(
24662490

24672491
UnownedStringSlice nameHint = findNameHint(irGlobalParam);
24682492
context->builder->setInsertBefore(irGlobalParam);
2469-
LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, irGlobalParam, &globalNameInfo);
2493+
LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, irGlobalParam, &globalNameInfo, context->isSpecialType(irGlobalParam->getDataType()));
24702494

24712495
// Register the new value as the replacement for the old
24722496
registerLegalizedValue(context, irGlobalParam, newVal);

source/slang/lower-to-ir.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,24 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
16081608
return LoweredValInfo::simple(irType);
16091609
}
16101610

1611+
LoweredValInfo visitExistentialSpecializedType(ExistentialSpecializedType* type)
1612+
{
1613+
auto irBaseType = lowerType(context, type->baseType);
1614+
1615+
List<IRInst*> slotArgs;
1616+
for(auto arg : type->slots.args)
1617+
{
1618+
auto irArgType = lowerType(context, arg.type);
1619+
auto irArgWitness = lowerSimpleVal(context, arg.witness);
1620+
1621+
slotArgs.add(irArgType);
1622+
slotArgs.add(irArgWitness);
1623+
}
1624+
1625+
auto irType = getBuilder()->getBindExistentialsType(irBaseType, slotArgs.getCount(), slotArgs.getBuffer());
1626+
return LoweredValInfo::simple(irType);
1627+
}
1628+
16111629
// We do not expect to encounter the following types in ASTs that have
16121630
// passed front-end semantic checking.
16131631
#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); }

0 commit comments

Comments
 (0)