Skip to content

Commit b2ad8e9

Browse files
authored
Add API to control interface specialization. (shader-slang#1925)
1 parent 33f7e15 commit b2ad8e9

14 files changed

+470
-7
lines changed

slang.h

+33
Original file line numberDiff line numberDiff line change
@@ -3050,6 +3050,7 @@ namespace slang
30503050
typedef ISlangBlob IBlob;
30513051

30523052
struct IComponentType;
3053+
struct ITypeConformance;
30533054
struct IGlobalSession;
30543055
struct IModule;
30553056
struct ISession;
@@ -4023,6 +4024,32 @@ namespace slang
40234024
*/
40244025
virtual SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(
40254026
SlangCompileRequest** outCompileRequest) = 0;
4027+
4028+
4029+
/** Creates a `IComponentType` that represents a type's conformance to an interface.
4030+
The retrieved `ITypeConformance` objects can be included in a composite `IComponentType`
4031+
to explicitly specify which implementation types should be included in the final compiled
4032+
code. For example, if an module defines `IMaterial` interface and `AMaterial`,
4033+
`BMaterial`, `CMaterial` types that implements the interface, the user can exclude
4034+
`CMaterial` implementation from the resulting shader code by explcitly adding
4035+
`AMaterial:IMaterial` and `BMaterial:IMaterial` conformances to a composite
4036+
`IComponentType` and get entry point code from it. The resulting code will not have
4037+
anything related to `CMaterial` in the dynamic dispatch logic. If the user does not
4038+
explicitly include any `TypeConformances` to an interface type, all implementations to
4039+
that interface will be included by default. By linking a `ITypeConformance`, the user is
4040+
also given the opportunity to specify the dispatch ID of the implementation type. If
4041+
`conformanceIdOverride` is -1, there will be no override behavior and Slang will
4042+
automatically assign IDs to implementation types. The automatically assigned IDs can be
4043+
queried via `ISession::getTypeConformanceWitnessSequentialID`.
4044+
4045+
Returns SLANG_OK if succeeds, or SLANG_FAIL if `type` does not conform to `interfaceType`.
4046+
*/
4047+
virtual SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType(
4048+
slang::TypeReflection* type,
4049+
slang::TypeReflection* interfaceType,
4050+
ITypeConformance** outConformance,
4051+
SlangInt conformanceIdOverride,
4052+
ISlangBlob** outDiagnostics) = 0;
40264053
};
40274054

40284055
#define SLANG_UUID_ISession ISession::getTypeGuid()
@@ -4204,6 +4231,12 @@ namespace slang
42044231

42054232
#define SLANG_UUID_IEntryPoint IEntryPoint::getTypeGuid()
42064233

4234+
struct ITypeConformance : public IComponentType
4235+
{
4236+
SLANG_COM_INTERFACE(0x73eb3147, 0xe544, 0x41b5, { 0xb8, 0xf0, 0xa2, 0x44, 0xdf, 0x21, 0x94, 0xb })
4237+
};
4238+
#define SLANG_UUID_ITypeConformance ITypeConformance::getTypeGuid()
4239+
42074240
/** A module is the granularity of shader code compilation and loading.
42084241
42094242
In most cases a module corresponds to a single compile "translation unit."

source/slang/slang-compiler.cpp

+86
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,92 @@ namespace Slang
296296
return empty;
297297
}
298298

299+
TypeConformance::TypeConformance(
300+
Linkage* linkage,
301+
SubtypeWitness* witness,
302+
Int confomrmanceIdOverride,
303+
DiagnosticSink* sink)
304+
: ComponentType(linkage)
305+
, m_subtypeWitness(witness)
306+
, m_conformanceIdOverride(confomrmanceIdOverride)
307+
{
308+
addDepedencyFromWitness(witness);
309+
m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink);
310+
}
311+
312+
void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness)
313+
{
314+
if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness))
315+
{
316+
auto declModule = getModule(declaredWitness->declRef.getDecl());
317+
m_moduleDependency.addDependency(declModule);
318+
m_pathDependency.addDependency(declModule);
319+
if (m_requirementSet.Add(declModule))
320+
{
321+
m_requirements.add(declModule);
322+
}
323+
// TODO: handle the specialization arguments in declaredWitness->declRef.substitutions.
324+
}
325+
else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness))
326+
{
327+
addDepedencyFromWitness(transitiveWitness->midToSup);
328+
addDepedencyFromWitness(transitiveWitness->subToMid);
329+
}
330+
else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness))
331+
{
332+
auto left = as<SubtypeWitness>(conjunctionWitness->leftWitness);
333+
if (left)
334+
addDepedencyFromWitness(left);
335+
auto right = as<SubtypeWitness>(conjunctionWitness->rightWitness);
336+
if (right)
337+
addDepedencyFromWitness(right);
338+
}
339+
}
340+
341+
ISlangUnknown* TypeConformance::getInterface(const Guid& guid)
342+
{
343+
if (guid == slang::ITypeConformance::getTypeGuid())
344+
return static_cast<slang::ITypeConformance*>(this);
345+
346+
return Super::getInterface(guid);
347+
}
348+
349+
List<Module*> const& TypeConformance::getModuleDependencies()
350+
{
351+
return m_moduleDependency.getModuleList();
352+
}
353+
354+
List<String> const& TypeConformance::getFilePathDependencies()
355+
{
356+
return m_pathDependency.getFilePathList();
357+
}
358+
359+
Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); }
360+
361+
RefPtr<ComponentType> TypeConformance::getRequirement(Index index)
362+
{
363+
return m_requirements[index];
364+
}
365+
366+
void TypeConformance::acceptVisitor(
367+
ComponentTypeVisitor* visitor,
368+
ComponentType::SpecializationInfo* specializationInfo)
369+
{
370+
SLANG_UNUSED(specializationInfo);
371+
visitor->visitTypeConformance(this);
372+
}
373+
374+
RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl(
375+
SpecializationArg const* args,
376+
Index argCount,
377+
DiagnosticSink* sink)
378+
{
379+
SLANG_UNUSED(args);
380+
SLANG_UNUSED(argCount);
381+
SLANG_UNUSED(sink);
382+
return nullptr;
383+
}
384+
299385
//
300386

301387
Profile Profile::lookUp(UnownedStringSlice const& name)

source/slang/slang-compiler.h

+126
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,125 @@ namespace Slang
850850
Profile m_profile;
851851
};
852852

853+
class TypeConformance
854+
: public ComponentType
855+
, public slang::ITypeConformance
856+
{
857+
typedef ComponentType Super;
858+
859+
public:
860+
SLANG_REF_OBJECT_IUNKNOWN_ALL
861+
862+
ISlangUnknown* getInterface(const Guid& guid);
863+
864+
TypeConformance(
865+
Linkage* linkage,
866+
SubtypeWitness* witness,
867+
Int confomrmanceIdOverride,
868+
DiagnosticSink* sink);
869+
870+
// Forward `IComponentType` methods
871+
872+
SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE
873+
{
874+
return Super::getSession();
875+
}
876+
877+
SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL
878+
getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE
879+
{
880+
return Super::getLayout(targetIndex, outDiagnostics);
881+
}
882+
883+
SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode(
884+
SlangInt entryPointIndex,
885+
SlangInt targetIndex,
886+
slang::IBlob** outCode,
887+
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
888+
{
889+
return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
890+
}
891+
892+
SLANG_NO_THROW SlangResult SLANG_MCALL specialize(
893+
slang::SpecializationArg const* specializationArgs,
894+
SlangInt specializationArgCount,
895+
slang::IComponentType** outSpecializedComponentType,
896+
ISlangBlob** outDiagnostics) SLANG_OVERRIDE
897+
{
898+
return Super::specialize(
899+
specializationArgs,
900+
specializationArgCount,
901+
outSpecializedComponentType,
902+
outDiagnostics);
903+
}
904+
905+
SLANG_NO_THROW SlangResult SLANG_MCALL link(
906+
slang::IComponentType** outLinkedComponentType,
907+
ISlangBlob** outDiagnostics) SLANG_OVERRIDE
908+
{
909+
return Super::link(outLinkedComponentType, outDiagnostics);
910+
}
911+
912+
SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(
913+
int entryPointIndex,
914+
int targetIndex,
915+
ISlangSharedLibrary** outSharedLibrary,
916+
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
917+
{
918+
return Super::getEntryPointHostCallable(
919+
entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
920+
}
921+
922+
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE;
923+
List<String> const& getFilePathDependencies() SLANG_OVERRIDE;
924+
925+
SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; }
926+
927+
/// Get the existential type parameter at `index`.
928+
SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE
929+
{
930+
static SpecializationParam emptyParam;
931+
return emptyParam;
932+
}
933+
934+
Index getRequirementCount() SLANG_OVERRIDE;
935+
RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
936+
Index getEntryPointCount() SLANG_OVERRIDE { return 0; };
937+
RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE
938+
{
939+
SLANG_UNUSED(index);
940+
return nullptr;
941+
}
942+
String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; }
943+
944+
Index getShaderParamCount() SLANG_OVERRIDE { return 0; }
945+
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE
946+
{
947+
SLANG_UNUSED(index);
948+
return ShaderParamInfo();
949+
}
950+
951+
SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; }
952+
IRModule* getIRModule() { return m_irModule.Ptr(); }
953+
protected:
954+
void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
955+
SLANG_OVERRIDE;
956+
957+
RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
958+
SpecializationArg const* args,
959+
Index argCount,
960+
DiagnosticSink* sink) SLANG_OVERRIDE;
961+
private:
962+
SubtypeWitness* m_subtypeWitness;
963+
ModuleDependencyList m_moduleDependency;
964+
FilePathDependencyList m_pathDependency;
965+
List<RefPtr<Module>> m_requirements;
966+
HashSet<Module*> m_requirementSet;
967+
RefPtr<IRModule> m_irModule;
968+
Int m_conformanceIdOverride;
969+
void addDepedencyFromWitness(SubtypeWitness* witness);
970+
};
971+
853972
enum class PassThroughMode : SlangPassThroughIntegral
854973
{
855974
None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler
@@ -1319,6 +1438,12 @@ namespace Slang
13191438
slang::TypeReflection* type,
13201439
slang::TypeReflection* interfaceType,
13211440
uint32_t* outId) override;
1441+
SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType(
1442+
slang::TypeReflection* type,
1443+
slang::TypeReflection* interfaceType,
1444+
slang::ITypeConformance** outConformance,
1445+
SlangInt conformanceIdOverride,
1446+
ISlangBlob** outDiagnostics) override;
13221447
SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(
13231448
SlangCompileRequest** outCompileRequest) override;
13241449

@@ -1756,6 +1881,7 @@ namespace Slang
17561881
virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0;
17571882
virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0;
17581883
virtual void visitSpecialized(SpecializedComponentType* specialized) = 0;
1884+
virtual void visitTypeConformance(TypeConformance* conformance) = 0;
17591885

17601886
protected:
17611887
// These helpers can be used to recurse into the logical children of a

source/slang/slang-ir-link.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
274274

275275
default:
276276
{
277-
// In the deafult case, assume that we have some sort of "hoistable"
277+
// In the default case, assume that we have some sort of "hoistable"
278278
// instruction that requires us to create a clone of it.
279279

280280
UInt argCount = originalValue->getOperandCount();
@@ -439,6 +439,8 @@ static void cloneExtraDecorations(
439439

440440
case kIROp_BindExistentialSlotsDecoration:
441441
case kIROp_LayoutDecoration:
442+
case kIROp_PublicDecoration:
443+
case kIROp_SequentialIDDecoration:
442444
if(!clonedInst->findDecorationImpl(decoration->getOp()))
443445
{
444446
cloneInst(context, builder, decoration);

source/slang/slang-ir-specialize.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -2158,6 +2158,11 @@ IRInst* specializeGenericImpl(
21582158
if( auto returnValInst = as<IRReturnVal>(ii) )
21592159
{
21602160
auto specializedVal = findCloneForOperand(&env, returnValInst->getVal());
2161+
2162+
// Clone decorations on the orignal `specialize` inst over to the newly specialized
2163+
// value.
2164+
cloneInstDecorationsAndChildren(
2165+
&env, &sharedBuilderStorage, specializeInst, specializedVal);
21612166
return specializedVal;
21622167
}
21632168

0 commit comments

Comments
 (0)