Skip to content

Commit 54bf54b

Browse files
csyongheTim Foley
authored and
Tim Foley
committed
Add support for global generic parameters (shader-slang#285)
* Add support for global generic parameters (In-progress work) This commit include: 1. Update Slang API to allow specification of generic type arguments in an `EntryPointRequest` 2. Add parsing of `__generic_param` construct, which becomes a GlobalGenericParamDecl, contains members of `GenericTypeConstraintDecl`. 3. Semantics checking will check whether the provided type arguments conform to the interfaces as defined by the generic parameter, and store SubtypeWitness values in the EntryPointRequest, which will be used by `specializeIRForEntryPoint` when generating final IR. 4. Add a new type of substitution - `GlobalGenericParamSubstitution` for subsittuting references to `__generic_param` decls or to its member `GenericTypeConsraintDecl` with the actual type argument or witness tables. 5. Update `IRSpecContext` to apply `GlobalGenericParamSubstitution` when specializing the IR for an EntryPointRequest. 6. Update `render-test` to take additional `type` inputs, which specifies the type arguments to substitute into the global `__generic_param` types. This commit does not include ProgramLayout specialization. * IR: pass through `[unroll]` attribute (shader-slang#284) The initial lowering was adding an `IRLoopControlDecoration` to the instruction at the head of a loop, but this was getting dropped when the IR gets cloned for a particular entry point. The fix was simply to add a case for loop-control decorations to `cloneDecoration`. * fix warnings * IR: support `CompileTimeForStmt` (shader-slang#286) This statement type is a bit of a hack, to support loops that *must* be unrolled. The AST-to-AST pass handles them by cloning the AST for the loop body N times, and it was easy enough to do the same thing for the IR: emit the instructions for the body N times. The only thing that requires a bit of care is that now we might see the same variable declarations multiple times, so we need to play it safe and overwrite existing entries in our map from declarations to their IR values. Of course a better answer long-term would be to do the actual unrolling in the IR. This is especially true because we might some day want to support compile-time/must-unroll loops in functions, where the loop counter comes in as a parameter (but must still be compile-time-constant at every call site). * Add support for global generic parameters (In-progress work) This commit include: 1. Update Slang API to allow specification of generic type arguments in an `EntryPointRequest` 2. Add parsing of `__generic_param` construct, which becomes a GlobalGenericParamDecl, contains members of `GenericTypeConstraintDecl`. 3. Semantics checking will check whether the provided type arguments conform to the interfaces as defined by the generic parameter, and store SubtypeWitness values in the EntryPointRequest, which will be used by `specializeIRForEntryPoint` when generating final IR. 4. Add a new type of substitution - `GlobalGenericParamSubstitution` for subsittuting references to `__generic_param` decls or to its member `GenericTypeConsraintDecl` with the actual type argument or witness tables. 5. Update `IRSpecContext` to apply `GlobalGenericParamSubstitution` when specializing the IR for an EntryPointRequest. 6. Update `render-test` to take additional `type` inputs, which specifies the type arguments to substitute into the global `__generic_param` types. progress on parameter binding * Add a more contrived test case for specializing parameter bindings * update render-test to align buffers to 256 bytes (to get rid of D3D complains on minimal buffer size). * adding one more test case for parameter binding specialization. * Cleanup according to @tfoleyNV 's suggestions. * fix a bug introduced in the cleanup
1 parent 0298a04 commit 54bf54b

35 files changed

+1123
-274
lines changed

slang.h

+28
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,18 @@ extern "C"
377377
char const* name,
378378
SlangProfileID profile);
379379

380+
/** Add an entry point in a particular translation unit,
381+
with additional arguments that specify the concrete
382+
type names for global generic type parameters.
383+
*/
384+
SLANG_API int spAddEntryPointEx(
385+
SlangCompileRequest* request,
386+
int translationUnitIndex,
387+
char const* name,
388+
SlangProfileID profile,
389+
int genericTypeNameCount,
390+
char const** genericTypeNames);
391+
380392
/** Execute the compilation request.
381393
382394
Returns zero on success, non-zero on failure.
@@ -588,6 +600,9 @@ extern "C"
588600
// HLSL register `space`, Vulkan GLSL `set`
589601
SLANG_PARAMETER_CATEGORY_REGISTER_SPACE,
590602

603+
// A parameter whose type is to be specialized by a global generic type argument
604+
SLANG_PARAMETER_CATEGORY_GENERIC,
605+
591606
//
592607
SLANG_PARAMETER_CATEGORY_COUNT,
593608
};
@@ -695,6 +710,8 @@ extern "C"
695710
SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* reflection);
696711

697712
SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* reflection, SlangUInt index);
713+
SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection);
714+
SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection);
698715

699716
#ifdef __cplusplus
700717
}
@@ -848,6 +865,7 @@ namespace slang
848865
SpecializationConstant = SLANG_PARAMETER_CATEGORY_SPECIALIZATION_CONSTANT,
849866
PushConstantBuffer = SLANG_PARAMETER_CATEGORY_PUSH_CONSTANT_BUFFER,
850867
RegisterSpace = SLANG_PARAMETER_CATEGORY_REGISTER_SPACE,
868+
GenericResource = SLANG_PARAMETER_CATEGORY_GENERIC,
851869
};
852870

853871
struct TypeLayoutReflection
@@ -1102,6 +1120,16 @@ namespace slang
11021120
{
11031121
return (EntryPointReflection*) spReflection_getEntryPointByIndex((SlangReflection*) this, index);
11041122
}
1123+
1124+
SlangUInt getGlobalConstantBufferBinding()
1125+
{
1126+
return spReflection_getGlobalConstantBufferBinding((SlangReflection*)this);
1127+
}
1128+
1129+
size_t getGlobalConstantBufferSize()
1130+
{
1131+
return spReflection_getGlobalConstantBufferSize((SlangReflection*)this);
1132+
}
11051133
};
11061134
}
11071135

source/core/list.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ namespace Slang
487487
if (predicate(buffer[i]))
488488
return i;
489489
}
490-
return -1;
490+
return (UInt)-1;
491491
}
492492

493493
template<typename Func>
@@ -498,7 +498,7 @@ namespace Slang
498498
if (predicate(buffer[i-1]))
499499
return i-1;
500500
}
501-
return -1;
501+
return (UInt)-1;
502502
}
503503

504504
template<typename T2>
@@ -509,7 +509,7 @@ namespace Slang
509509
if (buffer[i] == val)
510510
return i;
511511
}
512-
return -1;
512+
return (UInt)-1;
513513
}
514514

515515
template<typename T2>
@@ -520,7 +520,7 @@ namespace Slang
520520
if(buffer[i-1] == val)
521521
return i-1;
522522
}
523-
return -1;
523+
return (UInt)-1;
524524
}
525525

526526
void Sort()

source/slang/check.cpp

+101-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ namespace Slang
148148
return expr->type->As<DeclRefType>();
149149
}
150150

151-
152151
RefPtr<Expr> ConstructDeclRefExpr(
153152
DeclRef<Decl> declRef,
154153
RefPtr<Expr> baseExpr,
@@ -1998,6 +1997,22 @@ namespace Slang
19981997
decl->SetCheckState(DeclCheckState::Checked);
19991998
}
20001999

2000+
void visitGlobalGenericParamDecl(GlobalGenericParamDecl * decl)
2001+
{
2002+
if (decl->IsChecked(DeclCheckState::Checked)) return;
2003+
decl->SetCheckState(DeclCheckState::CheckedHeader);
2004+
// global generic param only allowed in global scope
2005+
auto program = decl->ParentDecl->As<ModuleDecl>();
2006+
if (!program)
2007+
getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly);
2008+
// Now check all of the member declarations.
2009+
for (auto member : decl->Members)
2010+
{
2011+
checkDecl(member);
2012+
}
2013+
decl->SetCheckState(DeclCheckState::Checked);
2014+
}
2015+
20012016
void visitAssocTypeDecl(AssocTypeDecl* decl)
20022017
{
20032018
if (decl->IsChecked(DeclCheckState::Checked)) return;
@@ -3703,6 +3718,19 @@ namespace Slang
37033718
return true;
37043719
}
37053720
}
3721+
// if an inheritance decl is not found, try to find a GenericTypeConstraintDecl
3722+
for (auto genConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(aggTypeDeclRef))
3723+
{
3724+
EnsureDecl(genConstraintDeclRef.getDecl());
3725+
auto inheritedType = GetSup(genConstraintDeclRef);
3726+
TypeWitnessBreadcrumb breadcrumb;
3727+
breadcrumb.prev = inBreadcrumbs;
3728+
breadcrumb.declRef = genConstraintDeclRef;
3729+
if (doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb))
3730+
{
3731+
return true;
3732+
}
3733+
}
37063734
}
37073735
else if( auto genericTypeParamDeclRef = declRef.As<GenericTypeParamDecl>() )
37083736
{
@@ -6582,6 +6610,78 @@ namespace Slang
65826610
// that we don't have to re-do this effort again later.
65836611
entryPoint->decl = entryPointFuncDecl;
65846612

6613+
// Lookup generic parameter types in global scope
6614+
for (auto name : entryPoint->genericParameterTypeNames)
6615+
{
6616+
if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName))
6617+
{
6618+
// If there doesn't appear to be any such declaration, then
6619+
// we need to diagnose it as an error, and then bail out.
6620+
sink->diagnose(translationUnitSyntax, Diagnostics::entryPointTypeParameterNotFound, name);
6621+
return;
6622+
}
6623+
RefPtr<Type> type;
6624+
if (auto aggType = firstDeclWithName->As<AggTypeDecl>())
6625+
{
6626+
type = DeclRefType::Create(entryPoint->compileRequest->mSession, DeclRef<Decl>(aggType, nullptr));
6627+
}
6628+
else if (auto typeDefDecl = firstDeclWithName->As<TypeDefDecl>())
6629+
{
6630+
type = GetType(DeclRef<TypeDefDecl>(typeDefDecl, nullptr));
6631+
}
6632+
else
6633+
{
6634+
sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name);
6635+
return;
6636+
}
6637+
entryPoint->genericParameterTypes.Add(type);
6638+
}
6639+
// check that user-provioded type arguments conforms to the generic type
6640+
// parameter declaration of this translation unit
6641+
6642+
// collect global generic parameters from all imported modules
6643+
List<RefPtr<GlobalGenericParamDecl>> globalGenericParams;
6644+
// add current translation unit first
6645+
{
6646+
auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>();
6647+
for (auto p : globalGenParams)
6648+
globalGenericParams.Add(p);
6649+
}
6650+
// add imported modules
6651+
for (auto moduleDecl : entryPoint->compileRequest->loadedModulesList)
6652+
{
6653+
auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>();
6654+
for (auto p : globalGenParams)
6655+
globalGenericParams.Add(p);
6656+
}
6657+
if (globalGenericParams.Count() != entryPoint->genericParameterTypes.Count())
6658+
{
6659+
sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, globalGenericParams.Count(),
6660+
entryPoint->genericParameterTypes.Count());
6661+
return;
6662+
}
6663+
// if number of entry-point type arguments matches parameters, try find
6664+
// SubtypeWitness for each argument
6665+
int index = 0;
6666+
for (auto & gParam : globalGenericParams)
6667+
{
6668+
for (auto constraint : gParam->getMembersOfType<GenericTypeConstraintDecl>())
6669+
{
6670+
auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr));
6671+
SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit);
6672+
auto witness = visitor.tryGetSubtypeWitness(entryPoint->genericParameterTypes[index], interfaceType);
6673+
if (!witness)
6674+
{
6675+
sink->diagnose(gParam,
6676+
Diagnostics::typeArgumentDoesNotConformToInterface, gParam->nameAndLoc.name, entryPoint->genericParameterTypes[index],
6677+
interfaceType);
6678+
}
6679+
entryPoint->genericParameterWitnesses.Add(witness);
6680+
}
6681+
index++;
6682+
}
6683+
if (sink->errorCount != 0)
6684+
return;
65856685
// TODO: after all that work, we are now in a position to start
65866686
// validating the declaration itself. E.g., we should check if
65876687
// the declared input/output parameters have suitable semantics,

source/slang/compiler.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "parser.h"
1212
#include "preprocessor.h"
1313
#include "syntax-visitors.h"
14-
14+
#include "type-layout.h"
1515
#include "reflection.h"
1616
#include "emit.h"
1717

@@ -160,7 +160,7 @@ namespace Slang
160160
entryPoint,
161161
targetReq->layout.Ptr(),
162162
CodeGenTarget::HLSL,
163-
targetReq->target);
163+
targetReq);
164164
}
165165
}
166166

@@ -207,7 +207,7 @@ namespace Slang
207207
entryPoint,
208208
targetReq->layout.Ptr(),
209209
CodeGenTarget::GLSL,
210-
targetReq->target);
210+
targetReq);
211211
}
212212
}
213213

source/slang/compiler.h

+11-1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ namespace Slang
100100

101101
// The name of the entry point function (e.g., `main`)
102102
Name* name;
103+
104+
// The type names we want to substitute into the
105+
// global generic type parameters
106+
List<Name*> genericParameterTypeNames;
103107

104108
// The profile that the entry point will be compiled for
105109
// (this is a combination of the target state, and also
@@ -123,6 +127,11 @@ namespace Slang
123127
// it should not be assumed to be available in cases
124128
// where any errors were diagnosed.
125129
RefPtr<FuncDecl> decl;
130+
131+
// The declaration of the global generic parameter types
132+
// This will be filled in as part of semantic analysis.
133+
List<RefPtr<Type>> genericParameterTypes;
134+
List<RefPtr<Val>> genericParameterWitnesses;
126135
};
127136

128137
enum class PassThroughMode : SlangPassThrough
@@ -319,7 +328,8 @@ namespace Slang
319328
int addEntryPoint(
320329
int translationUnitIndex,
321330
String const& name,
322-
Profile profile);
331+
Profile profile,
332+
List<String> const & genericTypeNames);
323333

324334
UInt addTarget(
325335
CodeGenTarget target);

source/slang/decl-defs.h

+5
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ END_SYNTAX_CLASS()
126126
SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl)
127127
END_SYNTAX_CLASS()
128128

129+
// A '__generic_param' declaration, which defines a generic
130+
// entry-point parameter. Is a container of GenericTypeConstraintDecl
131+
SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl)
132+
END_SYNTAX_CLASS()
133+
129134
// A scope for local declarations (e.g., as part of a statement)
130135
SIMPLE_SYNTAX_CLASS(ScopeDecl, ContainerDecl)
131136

source/slang/diagnostic-defs.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of
196196

197197
// 303xx: interfaces and associated types
198198
DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.")
199-
199+
DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'__generic_param' can only be defined global scope.")
200200
// TODO: need to assign numbers to all these extra diagnostics...
201201

202202
DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')")
@@ -244,11 +244,17 @@ DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches en
244244
DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'")
245245
DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function")
246246

247+
DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'")
248+
DIAGNOSTIC(38005, Error, entryPointTypeSymbolNotAType, "entry-point type parameter '$0' must be declared as a type")
249+
247250
DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'")
248251
DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type")
249252
DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration")
250253
DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration")
251254

255+
DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.")
256+
DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.")
257+
252258
//
253259
// 4xxxx - IL code generation.
254260
//
@@ -264,7 +270,6 @@ DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value seman
264270
//
265271
// 5xxxx - Target code generation.
266272
//
267-
268273
DIAGNOSTIC(50020, Error, unknownStageType, "Unknown stage type '$0'.")
269274
DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.")
270275
DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.")

source/slang/emit.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -3481,9 +3481,9 @@ struct EmitVisitor
34813481
break;
34823482

34833483
case LayoutResourceKind::RegisterSpace:
3484+
case LayoutResourceKind::GenericResource:
34843485
// ignore
34853486
break;
3486-
34873487
default:
34883488
{
34893489
Emit(": register(");
@@ -6771,7 +6771,7 @@ EntryPointLayout* findEntryPointLayout(
67716771
StructTypeLayout* getGlobalStructLayout(
67726772
ProgramLayout* programLayout)
67736773
{
6774-
auto globalScopeLayout = programLayout->globalScopeLayout;
6774+
auto globalScopeLayout = programLayout->globalScopeLayout->typeLayout;
67756775
if( auto gs = globalScopeLayout.As<StructTypeLayout>() )
67766776
{
67776777
return gs.Ptr();
@@ -6816,13 +6816,13 @@ String emitEntryPoint(
68166816
EntryPointRequest* entryPoint,
68176817
ProgramLayout* programLayout,
68186818
CodeGenTarget target,
6819-
CodeGenTarget finalTarget)
6819+
TargetRequest* targetRequest)
68206820
{
68216821
auto translationUnit = entryPoint->getTranslationUnit();
68226822

68236823
SharedEmitContext sharedContext;
68246824
sharedContext.target = target;
6825-
sharedContext.finalTarget = finalTarget;
6825+
sharedContext.finalTarget = targetRequest->target;
68266826
sharedContext.entryPoint = entryPoint;
68276827

68286828
if (entryPoint)
@@ -6890,7 +6890,8 @@ String emitEntryPoint(
68906890
auto lowered = specializeIRForEntryPoint(
68916891
entryPoint,
68926892
programLayout,
6893-
target);
6893+
target,
6894+
targetRequest);
68946895

68956896
// If the user specified the flag that they want us to dump
68966897
// IR, then do it here, for the target-specific, but

source/slang/emit.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ namespace Slang
2626
// The target language to generate code in (e.g., HLSL/GLSL)
2727
CodeGenTarget target,
2828

29-
// The "final" target that we are being asked to compile for
30-
// (e.g., SPIR-V, DXBC, ...).
31-
CodeGenTarget finalTarget);
29+
// The full target request
30+
TargetRequest* targetRequest);
3231
}
3332
#endif

0 commit comments

Comments
 (0)