From 37cbbe8639a5a9c8693e747e0533349a8907e429 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 7 Feb 2025 21:48:34 -0500 Subject: [PATCH 01/13] WIP --- source/slang/hlsl.meta.slang | 15 +++++-- source/slang/slang-emit-c-like.cpp | 4 ++ source/slang/slang-ir-inst-defs.h | 5 +++ source/slang/slang-ir-insts.h | 9 +++++ .../slang-ir-legalize-varying-params.cpp | 14 +++++++ .../slang/slang-ir-legalize-varying-params.h | 2 + source/slang/slang-ir-wgsl-legalize.cpp | 40 +++++++++++++++++-- 7 files changed, 82 insertions(+), 7 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a10e747c01..5a562b5d5e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6,6 +6,9 @@ typedef uint UINT; __intrinsic_op($(kIROp_RequireGLSLExtension)) void __requireGLSLExtension(String extensionName); +__intrinsic_op($(kIROp_ImplicitSystemValue)) +uint __implicitSystemValue(String systemValueName); + //@public: /// Represents an interface for buffer data layout. /// This interface is used as a base for defining specific data layouts for buffers. 
@@ -15036,7 +15039,8 @@ uint WaveActiveCountBits(bool value) __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] uint WaveGetLaneCount() { __target_switch @@ -15050,6 +15054,8 @@ uint WaveGetLaneCount() OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupSize:uint) }; + case wgsl: + return __implicitSystemValue("SV_WaveLaneCount"); } } @@ -15057,7 +15063,8 @@ uint WaveGetLaneCount() __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] uint WaveGetLaneIndex() { __target_switch @@ -15071,6 +15078,8 @@ uint WaveGetLaneIndex() OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint) }; + case wgsl: + return __implicitSystemValue("SV_WaveLaneIndex"); } } @@ -20554,7 +20563,7 @@ ${ // In order to support this approach, we need intrinsics that // can magically fetch the binding information for a resource. // -// TODO: These operations are kind of *screaming* for us to +// TODO: These operations are kind of *screaming* for us tohlsl.m // have a built-in `interface` that all of the opaque resource // types conform to, so that we can define builtins that work // for any resource type. 
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index db2c0150f7..5651632349 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3055,6 +3055,10 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO emitOperand(as(inst)->getOperand(0), getInfo(EmitOp::General)); break; } + case kIROp_ImplicitSystemValue: + { + break; + } default: diagnoseUnhandledInst(inst); break; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 55880eab5d..3ff1ba3dfc 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -671,6 +671,11 @@ INST(RequireComputeDerivative, RequireComputeDerivative, 0, 0) INST(StaticAssert, StaticAssert, 2, 0) INST(Printf, Printf, 1, 0) +// Built-in inputs/outputs(system values) that are implicitly added. +// These must be passed in as entry function parameters for the target language(eg. WGSL and Metal), but +// do not explicitly originate from decorated entry point function parameters in Slang. +INST(ImplicitSystemValue, ImplicitSystemValue, 1, 0) + // Quad control execution modes. 
INST(RequireMaximallyReconverges, RequireMaximallyReconverges, 0, 0) INST(RequireQuadDerivatives, RequireQuadDerivatives, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9c3892c0e2..e6d3a8383a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3499,6 +3499,15 @@ struct IRStaticAssert : IRInst IR_LEAF_ISA(StaticAssert) }; +struct IRImplicitSystemValue : IRInst +{ + IR_LEAF_ISA(ImplicitSystemValue) + UnownedStringSlice getSystemValueName() + { + return as(getOperand(0))->getStringSlice(); + } +}; + struct IREmbeddedDownstreamIR : IRInst { IR_LEAF_ISA(EmbeddedDownstreamIR) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 3b65ee59af..05d640cee4 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3854,6 +3854,20 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext break; } + case SystemValueSemanticName::WaveLaneCount: + { + result.systemValueName = toSlice("subgroup_size"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + + case SystemValueSemanticName::WaveLaneIndex: + { + result.systemValueName = toSlice("subgroup_invocation_id"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + default: { m_sink->diagnose( diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index e742f30936..0a7c3be8e7 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -68,6 +68,8 @@ void depointerizeInputParams(IRFunc* entryPoint); M(Target, SV_Target) \ M(StartVertexLocation, SV_StartVertexLocation) \ M(StartInstanceLocation, SV_StartInstanceLocation) \ + M(WaveLaneCount, SV_WaveLaneCount) \ + M(WaveLaneIndex, SV_WaveLaneIndex) \ /* end */ /// A known system-value semantic name that can be applied to a 
parameter diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index efa028703c..d0256ae60d 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -121,7 +121,7 @@ static void legalizeSwitch(IRSwitch* switchInst) switchInst->removeAndDeallocate(); } -static void processInst(IRInst* inst) +static void processInst(IRInst* inst, IRModule* module, const List& entryPoints) { switch (inst->getOp()) { @@ -157,13 +157,42 @@ static void processInst(IRInst* inst) legalizeBinaryOp(inst); break; + case kIROp_ImplicitSystemValue: + { + const auto implicitSysVal = as(inst); + printf("Legalizing implicit sysval!\n"); + printf("%s\n", implicitSysVal->getSystemValueName().begin()); + + + for (const auto entryPoint : entryPoints) + { + // builder.addParam + // entryPoint.entryPointFunc->ad + + + IRBuilder builder(entryPoint.entryPointFunc); + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + auto param = builder.emitParam(builder.getUIntType()); + builder.addSemanticDecoration(param, implicitSysVal->getSystemValueName()); + + inst->replaceUsesWith(param); + inst->removeAndDeallocate(); + } + // param->add + + break; + } + + case kIROp_Func: legalizeFunc(static_cast(inst)); [[fallthrough]]; + default: for (auto child : inst->getModifiableChildren()) { - processInst(child); + processInst(child, module, entryPoints); } } } @@ -213,12 +242,15 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) info.entryPointDecor = entryPointDecor; info.entryPointFunc = func; entryPoints.add(info); + + // processInst(func); } + // Go through every instruction in the module and legalize them as needed. + processInst(module->getModuleInst(), module, entryPoints); + legalizeEntryPointVaryingParamsForWGSL(module, sink, entryPoints); - // Go through every instruction in the module and legalize them as needed. 
- processInst(module->getModuleInst()); // Some global insts are illegal, e.g. function calls. // We need to inline and remove those. From a9b57b00304242a449c15296dd7ff0ae0e27bc2b Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sun, 9 Feb 2025 13:23:26 -0500 Subject: [PATCH 02/13] more WIP --- source/slang/slang-ir-call-graph.cpp | 4 + .../slang/slang-ir-legalize-system-values.cpp | 92 +++++++++++++ .../slang/slang-ir-legalize-system-values.h | 12 ++ source/slang/slang-ir-wgsl-legalize.cpp | 128 ++++++++---------- 4 files changed, 168 insertions(+), 68 deletions(-) create mode 100644 source/slang/slang-ir-legalize-system-values.cpp create mode 100644 source/slang/slang-ir-legalize-system-values.h diff --git a/source/slang/slang-ir-call-graph.cpp b/source/slang/slang-ir-call-graph.cpp index 47b18be2ed..ea6f1d9fad 100644 --- a/source/slang/slang-ir-call-graph.cpp +++ b/source/slang/slang-ir-call-graph.cpp @@ -2,6 +2,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" +#include "slang-ir.h" namespace Slang { @@ -56,10 +57,13 @@ void buildEntryPointReferenceGraph( } switch (inst->getOp()) { + // Only these instruction types are registered to the entry point reference graph. 
case kIROp_GlobalParam: case kIROp_SPIRVAsmOperandBuiltinVar: + case kIROp_ImplicitSystemValue: registerEntryPointReference(entryPoint, inst); break; + case kIROp_Block: case kIROp_SPIRVAsm: for (auto child : inst->getChildren()) diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp new file mode 100644 index 0000000000..e411e41ce1 --- /dev/null +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -0,0 +1,92 @@ +#include "slang-ir-legalize-system-values.h" + +#include "core/slang-dictionary.h" +#include "slang-diagnostics.h" +#include "slang-ir-insts.h" +#include "slang-ir-legalize-varying-params.h" + +namespace Slang +{ + +class ImplicitSystemValueLegalizationContext +{ +public: + ImplicitSystemValueLegalizationContext( + const Dictionary>& entryPointReferenceGraph, + const List& implicitSystemValueInstructions) + : m_entryPointReferenceGraph(entryPointReferenceGraph) + , m_implicitSystemValueInstructions(implicitSystemValueInstructions) + { + } + + void legalize() + { + for (auto implicitSysVal : m_implicitSystemValueInstructions) + { + for (auto entryPoint : *m_entryPointReferenceGraph.tryGetValue(implicitSysVal)) + { + auto param = getOrCreateSystemValueParam(entryPoint, implicitSysVal); + implicitSysVal->replaceUsesWith(param); + implicitSysVal->removeAndDeallocate(); + } + } + } + +private: + IRParam* getOrCreateSystemValueParam(IRFunc* entryPoint, IRImplicitSystemValue* implicitSysVal) + { + if (!m_entryPointMap.containsKey(entryPoint)) + { + m_entryPointMap.add(entryPoint, SystemValueParamMap()); + } + auto systemValueParamMap = m_entryPointMap.tryGetValue(entryPoint); + + const auto systemValueName = + convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); + SLANG_ASSERT(systemValueName != SystemValueSemanticName::Unknown); + + IRParam* param; + const auto paramPtr = systemValueParamMap->tryGetValue(systemValueName); + if (paramPtr == nullptr) + { + IRBuilder 
builder(entryPoint); + builder.setInsertBefore(entryPoint->getFirstBlock()->getFirstOrdinaryInst()); + + // The new system value parameter's type is harcoded. + // + // Implicit system values are currently only being used for subgroup size and subgroup + // invocation id, both of which are 32-bit unsigned. + SLANG_ASSERT( + (systemValueName == SystemValueSemanticName::WaveLaneCount) || + (systemValueName == SystemValueSemanticName::WaveLaneIndex)); + param = builder.emitParam(builder.getUIntType()); + builder.addSemanticDecoration(param, implicitSysVal->getSystemValueName()); + + systemValueParamMap->add(systemValueName, param); + } + else + { + param = *paramPtr; + } + + return param; + } + + using SystemValueParamMap = Dictionary; + + const Dictionary>& m_entryPointReferenceGraph; + const List& m_implicitSystemValueInstructions; + Dictionary m_entryPointMap; +}; + +void legalizeImplicitSystemValues( + const Dictionary>& entryPointReferenceGraph, + const List& implicitSystemValueInstructions) +{ + ImplicitSystemValueLegalizationContext( + entryPointReferenceGraph, + implicitSystemValueInstructions) + .legalize(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-legalize-system-values.h b/source/slang/slang-ir-legalize-system-values.h new file mode 100644 index 0000000000..617556c40f --- /dev/null +++ b/source/slang/slang-ir-legalize-system-values.h @@ -0,0 +1,12 @@ +// slang-ir-legalize-system-values.h +#pragma once +#include "slang-ir-insts.h" + +namespace Slang +{ + +void legalizeImplicitSystemValues( + const Dictionary>& entryPointReferenceGraph, + const List& implicitSystemValueInstructions); + +} // namespace Slang diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index d0256ae60d..1530844b5c 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1,8 +1,10 @@ #include "slang-ir-wgsl-legalize.h" +#include "slang-ir-call-graph.h" #include "slang-ir-insts.h" 
#include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" +#include "slang-ir-legalize-system-values.h" #include "slang-ir-legalize-varying-params.h" #include "slang-ir.h" @@ -121,81 +123,63 @@ static void legalizeSwitch(IRSwitch* switchInst) switchInst->removeAndDeallocate(); } -static void processInst(IRInst* inst, IRModule* module, const List& entryPoints) +class InstructionLegalizationContext { - switch (inst->getOp()) +public: + void processInst(IRInst* inst) { - case kIROp_Call: - legalizeCall(static_cast(inst)); - break; - - case kIROp_Switch: - legalizeSwitch(as(inst)); - break; - - // For all binary operators, make sure both side of the operator have the same type - // (vector-ness and matrix-ness). - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_Div: - case kIROp_FRem: - case kIROp_IRem: - case kIROp_And: - case kIROp_Or: - case kIROp_BitAnd: - case kIROp_BitOr: - case kIROp_BitXor: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_Eql: - case kIROp_Neq: - case kIROp_Greater: - case kIROp_Less: - case kIROp_Geq: - case kIROp_Leq: - legalizeBinaryOp(inst); - break; - - case kIROp_ImplicitSystemValue: + switch (inst->getOp()) { - const auto implicitSysVal = as(inst); - printf("Legalizing implicit sysval!\n"); - printf("%s\n", implicitSysVal->getSystemValueName().begin()); - - - for (const auto entryPoint : entryPoints) - { - // builder.addParam - // entryPoint.entryPointFunc->ad - - - IRBuilder builder(entryPoint.entryPointFunc); - builder.setInsertBefore( - entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - auto param = builder.emitParam(builder.getUIntType()); - builder.addSemanticDecoration(param, implicitSysVal->getSystemValueName()); + case kIROp_Call: + legalizeCall(static_cast(inst)); + break; - inst->replaceUsesWith(param); - inst->removeAndDeallocate(); - } - // param->add + case kIROp_Switch: + legalizeSwitch(as(inst)); + break; + // For all binary operators, make sure both 
side of the operator have the same type + // (vector-ness and matrix-ness). + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + legalizeBinaryOp(inst); break; - } + case kIROp_ImplicitSystemValue: + implicitSystemValueInstructions.add(as(inst)); + break; - case kIROp_Func: - legalizeFunc(static_cast(inst)); - [[fallthrough]]; + case kIROp_Func: + legalizeFunc(static_cast(inst)); + [[fallthrough]]; - default: - for (auto child : inst->getModifiableChildren()) - { - processInst(child, module, entryPoints); + default: + for (auto child : inst->getModifiableChildren()) + { + processInst(child); + } } } -} + + List implicitSystemValueInstructions; +}; struct GlobalInstInliningContext : public GlobalInstInliningContextGeneric { @@ -242,15 +226,23 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) info.entryPointDecor = entryPointDecor; info.entryPointFunc = func; entryPoints.add(info); - - // processInst(func); } // Go through every instruction in the module and legalize them as needed. - processInst(module->getModuleInst(), module, entryPoints); + InstructionLegalizationContext instContext; + instContext.processInst(module->getModuleInst()); - legalizeEntryPointVaryingParamsForWGSL(module, sink, entryPoints); + // Legalize implicit system values to entry point parameters. 
+ if (instContext.implicitSystemValueInstructions.getCount() != 0) + { + Dictionary> entryPointReferenceGraph; + buildEntryPointReferenceGraph(entryPointReferenceGraph, module); + legalizeImplicitSystemValues( + entryPointReferenceGraph, + instContext.implicitSystemValueInstructions); + } + legalizeEntryPointVaryingParamsForWGSL(module, sink, entryPoints); // Some global insts are illegal, e.g. function calls. // We need to inline and remove those. From ab83ca395a2dea6b16805eac28177741ad5f2013 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Mon, 10 Feb 2025 00:33:32 -0500 Subject: [PATCH 03/13] WIP --- source/slang/slang-ir-call-graph.cpp | 64 ++++- source/slang/slang-ir-call-graph.h | 4 +- .../slang/slang-ir-legalize-system-values.cpp | 264 +++++++++++++++++- .../slang/slang-ir-legalize-system-values.h | 2 + source/slang/slang-ir-wgsl-legalize.cpp | 11 +- 5 files changed, 323 insertions(+), 22 deletions(-) diff --git a/source/slang/slang-ir-call-graph.cpp b/source/slang/slang-ir-call-graph.cpp index ea6f1d9fad..ca68a5f4d7 100644 --- a/source/slang/slang-ir-call-graph.cpp +++ b/source/slang/slang-ir-call-graph.cpp @@ -9,12 +9,15 @@ namespace Slang void buildEntryPointReferenceGraph( Dictionary>& referencingEntryPoints, - IRModule* module) + IRModule* module, + Dictionary>* referencingFunctions, + Dictionary>* referencingCalls) { struct WorkItem { IRFunc* entryPoint; IRInst* inst; + IRFunc* parentFunc; HashCode getHashCode() const { @@ -44,14 +47,52 @@ void buildEntryPointReferenceGraph( referencingEntryPoints.add(inst, _Move(newSet)); } }; - auto visit = [&](IRFunc* entryPoint, IRInst* inst) + + const auto registerFunctionReference = [&](IRFunc* func, IRInst* inst) + { + if (referencingFunctions && func) + { + if (auto set = referencingFunctions->tryGetValue(inst)) + set->add(func); + else + { + HashSet newSet; + newSet.add(func); + referencingFunctions->add(inst, _Move(newSet)); + } + } + }; + + auto registerCallReference = [&](IRCall* call, IRFunc* func) + { + 
if (referencingCalls) + { + if (auto set = referencingCalls->tryGetValue(func)) + set->add(call); + else + { + HashSet newSet; + newSet.add(call); + referencingCalls->add(func, _Move(newSet)); + } + } + }; + + auto visit = [&](IRFunc* entryPoint, IRInst* inst, IRFunc* parentFunc) { if (auto code = as(inst)) { registerEntryPointReference(entryPoint, inst); + registerFunctionReference(parentFunc, inst); + + if (auto func = as(code)) + { + parentFunc = func; + } + for (auto child : code->getChildren()) { - addToWorkList({entryPoint, child}); + addToWorkList({entryPoint, child, parentFunc}); } return; } @@ -62,25 +103,29 @@ void buildEntryPointReferenceGraph( case kIROp_SPIRVAsmOperandBuiltinVar: case kIROp_ImplicitSystemValue: registerEntryPointReference(entryPoint, inst); + registerFunctionReference(parentFunc, inst); break; case kIROp_Block: case kIROp_SPIRVAsm: for (auto child : inst->getChildren()) { - addToWorkList({entryPoint, child}); + addToWorkList({entryPoint, child, parentFunc}); } break; case kIROp_Call: + registerEntryPointReference(entryPoint, inst); + registerFunctionReference(parentFunc, inst); { auto call = as(inst); - addToWorkList({entryPoint, call->getCallee()}); + registerCallReference(call, as(call->getCallee())); + addToWorkList({entryPoint, call->getCallee(), parentFunc}); } break; case kIROp_SPIRVAsmOperandInst: { auto operand = as(inst); - addToWorkList({entryPoint, operand->getValue()}); + addToWorkList({entryPoint, operand->getValue(), parentFunc}); } break; } @@ -92,7 +137,7 @@ void buildEntryPointReferenceGraph( case kIROp_GlobalParam: case kIROp_GlobalVar: case kIROp_SPIRVAsmOperandBuiltinVar: - addToWorkList({entryPoint, operand}); + addToWorkList({entryPoint, operand, parentFunc}); break; } } @@ -103,11 +148,12 @@ void buildEntryPointReferenceGraph( if (globalInst->getOp() == kIROp_Func && globalInst->findDecoration()) { - visit(as(globalInst), globalInst); + auto entryPointFunc = as(globalInst); + visit(entryPointFunc, globalInst, 
nullptr); } } for (Index i = 0; i < workList.getCount(); i++) - visit(workList[i].entryPoint, workList[i].inst); + visit(workList[i].entryPoint, workList[i].inst, workList[i].parentFunc); } HashSet* getReferencingEntryPoints( diff --git a/source/slang/slang-ir-call-graph.h b/source/slang/slang-ir-call-graph.h index 4ee6423566..3dc400bb1d 100644 --- a/source/slang/slang-ir-call-graph.h +++ b/source/slang/slang-ir-call-graph.h @@ -6,7 +6,9 @@ namespace Slang void buildEntryPointReferenceGraph( Dictionary>& referencingEntryPoints, - IRModule* module); + IRModule* module, + Dictionary>* referencingFunctions = nullptr, + Dictionary>* referencingCalls = nullptr); HashSet* getReferencingEntryPoints( Dictionary>& m_referencingEntryPoints, diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp index e411e41ce1..53f676c5cc 100644 --- a/source/slang/slang-ir-legalize-system-values.cpp +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -1,6 +1,7 @@ #include "slang-ir-legalize-system-values.h" #include "core/slang-dictionary.h" +#include "core/slang-string.h" #include "slang-diagnostics.h" #include "slang-ir-insts.h" #include "slang-ir-legalize-varying-params.h" @@ -13,8 +14,12 @@ class ImplicitSystemValueLegalizationContext public: ImplicitSystemValueLegalizationContext( const Dictionary>& entryPointReferenceGraph, + const Dictionary>& functionReferenceGraph, + const Dictionary>& callReferenceGraph, const List& implicitSystemValueInstructions) : m_entryPointReferenceGraph(entryPointReferenceGraph) + , m_functionReferenceGraph(functionReferenceGraph) + , m_callReferenceGraph(callReferenceGraph) , m_implicitSystemValueInstructions(implicitSystemValueInstructions) { } @@ -23,7 +28,8 @@ class ImplicitSystemValueLegalizationContext { for (auto implicitSysVal : m_implicitSystemValueInstructions) { - for (auto entryPoint : *m_entryPointReferenceGraph.tryGetValue(implicitSysVal)) + // for (auto entryPoint : 
*m_entryPointReferenceGraph.tryGetValue(implicitSysVal)) + for (auto entryPoint : *m_functionReferenceGraph.tryGetValue(implicitSysVal)) { auto param = getOrCreateSystemValueParam(entryPoint, implicitSysVal); implicitSysVal->replaceUsesWith(param); @@ -33,13 +39,210 @@ class ImplicitSystemValueLegalizationContext } private: - IRParam* getOrCreateSystemValueParam(IRFunc* entryPoint, IRImplicitSystemValue* implicitSysVal) + using SystemValueParamMap = Dictionary; + + SystemValueParamMap& getParamMap(IRFunc* func) + { + if (auto map = m_functionMap.tryGetValue(func)) + { + return *map; + } + else + { + m_functionMap.add(func, SystemValueParamMap()); + return m_functionMap.getValue(func); + } + } + + // + // Attempt to retrieve a parameter for a specific function and system value type combination. + // + IRParam* tryGetParam(IRFunc* func, SystemValueSemanticName systemValueName) + { + if (auto param = getParamMap(func).tryGetValue(systemValueName)) + { + return *param; + } + else + { + return nullptr; + } + } + + // + // Implicit system values are "global variables" and can be used anywhere within the source + // code. The implementation target(i.e WGSL) however requires system values, aka built-in + // values, to be accessed via parameters to the entry point; they are not globally available. + // + // For any implicit system values found in non entry point functions, we need to ensure that + // they are explicitly passed as parameters from the entry point to the relevant functions. This + // means adding new parameters to the function signatures to include the required system values. + // + struct CreateParamWorkItem + { + /// Function to add a new param to. + IRFunc* func; + + /// System value semantic info. 
+ SystemValueSemanticName systemValueName; + UnownedStringSlice systemValueString; + + Index paramIndex; + }; + + // + // In addition to modifying the function signatures, we need to ensure the calls to the modified + // function include the system value variable. + // + struct ModifyCallWorkItem + { + /// Call to modify. + IRCall* call; + + /// Parameter to add to the call. + // IRParam* paramToAdd; + Index paramIndex; + }; + + // + // Creates a new parameter for functions that require the implicit system value. + // Returns work items of calls that need to be modified as a result of adding the parameters. + // + List createFunctionParams( + IRFunc* parentFunc, + IRImplicitSystemValue* implicitSysVal) + { + // Work list for parameter creation work. + List createParamWorkList; + + // Work list for call replacement work. + List modifyCallWorkList; + + // List to store new parameters that represent system value variables - used to propagate + // the param variables from parameter creation work to call replacement work. + List addedParams; + + const auto systemValueName = + convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); + SLANG_ASSERT(systemValueName != SystemValueSemanticName::Unknown); + + createParamWorkList.add( + {parentFunc, + systemValueName, + implicitSysVal->getSystemValueName(), + addedParams.getCount()}); + addedParams.add(nullptr); + + const auto addWorkItems = [&](const HashSet& calls, CreateParamWorkItem workItem) + { + for (auto call : calls) + { + for (auto callerFunc : m_functionReferenceGraph.getValue(call)) + { + // The caller(of a function that was added a parameter) also requires a + // new parameter to pass in the system value variable to the callee. + workItem.func = callerFunc; + workItem.paramIndex = addedParams.getCount(); + createParamWorkList.add(workItem); + addedParams.add(nullptr); + + // The call needs to be modified to add the new parameter. 
+ modifyCallWorkList.add({call, workItem.paramIndex}); + } + } + }; + + IRBuilder builder(parentFunc); + + // The new system value parameter's type is harcoded. + // + // Implicit system values are currently only being used for subgroup size and + // subgroup invocation id, both of which are 32-bit unsigned. + SLANG_ASSERT( + (systemValueName == SystemValueSemanticName::WaveLaneCount) || + (systemValueName == SystemValueSemanticName::WaveLaneIndex)); + auto paramType = builder.getUIntType(); + + const auto createParamWork = [&](CreateParamWorkItem workItem) + { + // If the parameter for system value type has not been created, create it. + if (!tryGetParam(workItem.func, workItem.systemValueName)) + { + builder.setInsertBefore(workItem.func->getFirstBlock()->getFirstOrdinaryInst()); + + auto param = builder.emitParam(paramType); + + // Add system value semantic decoration if adding to entry point. + if (parentFunc->findDecoration()) + { + builder.addSemanticDecoration(param, workItem.systemValueString); + } + + getParamMap(workItem.func).add(systemValueName, param); + addedParams[workItem.paramIndex] = param; + + if (auto calls = m_callReferenceGraph.tryGetValue(workItem.func)) + { + addWorkItems(*calls, workItem); + } + } + }; + + for (Index i = 0; i < createParamWorkList.getCount(); i++) + { + createParamWork(createParamWorkList[i]); + } + + const auto modifyCallWork = [&](IRCall* call, IRParam* param) + { + List newCallParams; + for (auto arg : call->getArgsList()) + { + newCallParams.add(arg); + } + newCallParams.add(param); + + IRBuilder newBuilder(call); + newBuilder.setInsertAfter(call); + auto newCall = newBuilder.emitCallInst(paramType, call->getCallee(), newCallParams); + + // call->replaceUsesWith(newCall); + // call->transferDecorationsTo(newCall); + // call->removeAndDeallocate(); + }; + + for (const auto workItem : modifyCallWorkList) + { + modifyCallWork(workItem.call, addedParams[workItem.paramIndex]); + } + + return modifyCallWorkList; + } + + 
IRParam* getOrCreateSystemValueParam(IRFunc* parentFunc, IRImplicitSystemValue* implicitSysVal) { - if (!m_entryPointMap.containsKey(entryPoint)) + const auto systemValueName = + convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); + SLANG_ASSERT(systemValueName != SystemValueSemanticName::Unknown); + + if (auto param = tryGetParam(parentFunc, systemValueName)) { - m_entryPointMap.add(entryPoint, SystemValueParamMap()); + return param; } - auto systemValueParamMap = m_entryPointMap.tryGetValue(entryPoint); + + createFunctionParams(parentFunc, implicitSysVal); + return tryGetParam(parentFunc, systemValueName); + } + + // void replaceCalls() {} + + IRParam* getOrCreateSystemValueParam2(IRFunc* parentFunc, IRImplicitSystemValue* implicitSysVal) + { + if (!m_functionMap.containsKey(parentFunc)) + { + m_functionMap.add(parentFunc, SystemValueParamMap()); + } + auto systemValueParamMap = m_functionMap.tryGetValue(parentFunc); const auto systemValueName = convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); @@ -49,8 +252,8 @@ class ImplicitSystemValueLegalizationContext const auto paramPtr = systemValueParamMap->tryGetValue(systemValueName); if (paramPtr == nullptr) { - IRBuilder builder(entryPoint); - builder.setInsertBefore(entryPoint->getFirstBlock()->getFirstOrdinaryInst()); + IRBuilder builder(parentFunc); + builder.setInsertBefore(parentFunc->getFirstBlock()->getFirstOrdinaryInst()); // The new system value parameter's type is harcoded. 
// @@ -59,10 +262,43 @@ class ImplicitSystemValueLegalizationContext SLANG_ASSERT( (systemValueName == SystemValueSemanticName::WaveLaneCount) || (systemValueName == SystemValueSemanticName::WaveLaneIndex)); - param = builder.emitParam(builder.getUIntType()); - builder.addSemanticDecoration(param, implicitSysVal->getSystemValueName()); + auto paramType = builder.getUIntType(); + param = builder.emitParam(paramType); + + + if (parentFunc->findDecoration()) + { + builder.addSemanticDecoration(param, implicitSysVal->getSystemValueName()); + } systemValueParamMap->add(systemValueName, param); + + if (m_callReferenceGraph.containsKey(parentFunc)) + { + for (auto call : m_callReferenceGraph.getValue(parentFunc)) + { + for (auto callParentFunc : m_functionReferenceGraph.getValue(call)) + { + SLANG_ASSERT(parentFunc != callParentFunc); + auto newCallParam = + getOrCreateSystemValueParam(callParentFunc, implicitSysVal); + + List newCallParams; + for (auto arg : call->getArgsList()) + { + newCallParams.add(arg); + } + newCallParams.add(newCallParam); + + builder.setInsertAfter(call); + auto newCall = builder.emitCallInst(paramType, parentFunc, newCallParams); + + call->replaceUsesWith(newCall); + call->transferDecorationsTo(newCall); + call->removeAndDeallocate(); + } + } + } } else { @@ -72,19 +308,25 @@ class ImplicitSystemValueLegalizationContext return param; } - using SystemValueParamMap = Dictionary; const Dictionary>& m_entryPointReferenceGraph; + const Dictionary>& m_functionReferenceGraph; + const Dictionary>& m_callReferenceGraph; const List& m_implicitSystemValueInstructions; - Dictionary m_entryPointMap; + + Dictionary m_functionMap; }; void legalizeImplicitSystemValues( const Dictionary>& entryPointReferenceGraph, + const Dictionary>& functionReferenceGraph, + const Dictionary>& callReferenceGraph, const List& implicitSystemValueInstructions) { ImplicitSystemValueLegalizationContext( entryPointReferenceGraph, + functionReferenceGraph, + callReferenceGraph, 
implicitSystemValueInstructions) .legalize(); } diff --git a/source/slang/slang-ir-legalize-system-values.h b/source/slang/slang-ir-legalize-system-values.h index 617556c40f..f0abedef43 100644 --- a/source/slang/slang-ir-legalize-system-values.h +++ b/source/slang/slang-ir-legalize-system-values.h @@ -7,6 +7,8 @@ namespace Slang void legalizeImplicitSystemValues( const Dictionary>& entryPointReferenceGraph, + const Dictionary>& functionReferenceGraph, + const Dictionary>& callReferenceGraph, const List& implicitSystemValueInstructions); } // namespace Slang diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 1530844b5c..1c0f96c509 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -236,9 +236,18 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) if (instContext.implicitSystemValueInstructions.getCount() != 0) { Dictionary> entryPointReferenceGraph; - buildEntryPointReferenceGraph(entryPointReferenceGraph, module); + // Dictionary> functionReferenceGraph; + Dictionary> functionReferenceGraph; + Dictionary> callReferenceGraph; + buildEntryPointReferenceGraph( + entryPointReferenceGraph, + module, + &functionReferenceGraph, + &callReferenceGraph); legalizeImplicitSystemValues( entryPointReferenceGraph, + functionReferenceGraph, + callReferenceGraph, instContext.implicitSystemValueInstructions); } From 4938e5d366a57effb50bbdfd8181784272b05238 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Mon, 10 Feb 2025 02:01:58 -0500 Subject: [PATCH 04/13] clean legalize sys val --- .../slang/slang-ir-legalize-system-values.cpp | 260 ++++++------------ .../slang/slang-ir-legalize-system-values.h | 2 +- source/slang/slang-ir-wgsl-legalize.cpp | 2 +- 3 files changed, 87 insertions(+), 177 deletions(-) diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp index 53f676c5cc..3f540474d9 100644 --- 
a/source/slang/slang-ir-legalize-system-values.cpp +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -2,7 +2,6 @@ #include "core/slang-dictionary.h" #include "core/slang-string.h" -#include "slang-diagnostics.h" #include "slang-ir-insts.h" #include "slang-ir-legalize-varying-params.h" @@ -13,14 +12,15 @@ class ImplicitSystemValueLegalizationContext { public: ImplicitSystemValueLegalizationContext( - const Dictionary>& entryPointReferenceGraph, + IRModule* module, const Dictionary>& functionReferenceGraph, const Dictionary>& callReferenceGraph, const List& implicitSystemValueInstructions) - : m_entryPointReferenceGraph(entryPointReferenceGraph) - , m_functionReferenceGraph(functionReferenceGraph) + : m_functionReferenceGraph(functionReferenceGraph) , m_callReferenceGraph(callReferenceGraph) , m_implicitSystemValueInstructions(implicitSystemValueInstructions) + , m_builder(module) + , m_paramType(m_builder.getUIntType()) { } @@ -31,7 +31,7 @@ class ImplicitSystemValueLegalizationContext // for (auto entryPoint : *m_entryPointReferenceGraph.tryGetValue(implicitSysVal)) for (auto entryPoint : *m_functionReferenceGraph.tryGetValue(implicitSysVal)) { - auto param = getOrCreateSystemValueParam(entryPoint, implicitSysVal); + auto param = getOrCreateSystemValueVariable(entryPoint, implicitSysVal); implicitSysVal->replaceUsesWith(param); implicitSysVal->removeAndDeallocate(); } @@ -39,8 +39,14 @@ class ImplicitSystemValueLegalizationContext } private: + // + // A function (including entry points) must have at most one parameter for each system value + // semantic type. + // + // This map tracks the association between system value semantics and their corresponding + // function parameters. 
+ // using SystemValueParamMap = Dictionary; - SystemValueParamMap& getParamMap(IRFunc* func) { if (auto map = m_functionMap.tryGetValue(func)) @@ -56,6 +62,7 @@ class ImplicitSystemValueLegalizationContext // // Attempt to retrieve a parameter for a specific function and system value type combination. + // Returns nullptr if parameter has not been created. // IRParam* tryGetParam(IRFunc* func, SystemValueSemanticName systemValueName) { @@ -69,6 +76,12 @@ class ImplicitSystemValueLegalizationContext } } + struct ModifyCallWorkItem + { + IRCall* call; + IRFunc* caller; + }; + // // Implicit system values are "global variables" and can be used anywhere within the source // code. The implementation target(i.e WGSL) however requires system values, aka built-in @@ -78,123 +91,88 @@ class ImplicitSystemValueLegalizationContext // they are explicitly passed as parameters from the entry point to the relevant functions. This // means adding new parameters to the function signatures to include the required system values. // - struct CreateParamWorkItem - { - /// Function to add a new param to. - IRFunc* func; - - /// System value semantic info. - SystemValueSemanticName systemValueName; - UnownedStringSlice systemValueString; - - Index paramIndex; - }; - - // - // In addition to modifying the function signatures, we need to ensure the calls to the modified - // function include the system value variable. - // - struct ModifyCallWorkItem - { - /// Call to modify. - IRCall* call; - - /// Parameter to add to the call. - // IRParam* paramToAdd; - Index paramIndex; - }; - - // - // Creates a new parameter for functions that require the implicit system value. - // Returns work items of calls that need to be modified as a result of adding the parameters. + // This function traverses the call graph of a function that contains an implicit system value + // instruction, and adds necessary parameters to pass in the system value variable up to the + // entry point function. 
Returns work items of calls that need to be modified as a result of + // adding the parameters. // List createFunctionParams( - IRFunc* parentFunc, - IRImplicitSystemValue* implicitSysVal) + IRFunc* func, + SystemValueSemanticName systemValueName, + UnownedStringSlice systemValueString) { - // Work list for parameter creation work. - List createParamWorkList; - - // Work list for call replacement work. + List functionWorkList; List modifyCallWorkList; - // List to store new parameters that represent system value variables - used to propagate - // the param variables from parameter creation work to call replacement work. - List addedParams; - - const auto systemValueName = - convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); - SLANG_ASSERT(systemValueName != SystemValueSemanticName::Unknown); - - createParamWorkList.add( - {parentFunc, - systemValueName, - implicitSysVal->getSystemValueName(), - addedParams.getCount()}); - addedParams.add(nullptr); - - const auto addWorkItems = [&](const HashSet& calls, CreateParamWorkItem workItem) + const auto addWorkItems = [&](const HashSet& calls) { for (auto call : calls) { - for (auto callerFunc : m_functionReferenceGraph.getValue(call)) + for (auto caller : m_functionReferenceGraph.getValue(call)) { // The caller(of a function that was added a parameter) also requires a // new parameter to pass in the system value variable to the callee. - workItem.func = callerFunc; - workItem.paramIndex = addedParams.getCount(); - createParamWorkList.add(workItem); - addedParams.add(nullptr); + functionWorkList.add(caller); // The call needs to be modified to add the new parameter. - modifyCallWorkList.add({call, workItem.paramIndex}); + modifyCallWorkList.add({call, caller}); } } }; - IRBuilder builder(parentFunc); - - // The new system value parameter's type is harcoded. 
- // // Implicit system values are currently only being used for subgroup size and - // subgroup invocation id, both of which are 32-bit unsigned. + // subgroup invocation id. SLANG_ASSERT( (systemValueName == SystemValueSemanticName::WaveLaneCount) || (systemValueName == SystemValueSemanticName::WaveLaneIndex)); - auto paramType = builder.getUIntType(); - const auto createParamWork = [&](CreateParamWorkItem workItem) + const auto createParamWork = [&](IRFunc* func) { // If the parameter for system value type has not been created, create it. - if (!tryGetParam(workItem.func, workItem.systemValueName)) + if (!tryGetParam(func, systemValueName)) { - builder.setInsertBefore(workItem.func->getFirstBlock()->getFirstOrdinaryInst()); + m_builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto param = builder.emitParam(paramType); + auto param = m_builder.emitParam(m_paramType); // Add system value semantic decoration if adding to entry point. - if (parentFunc->findDecoration()) + if (func->findDecoration()) { - builder.addSemanticDecoration(param, workItem.systemValueString); + m_builder.addSemanticDecoration(param, systemValueString); } - getParamMap(workItem.func).add(systemValueName, param); - addedParams[workItem.paramIndex] = param; + getParamMap(func).add(systemValueName, param); - if (auto calls = m_callReferenceGraph.tryGetValue(workItem.func)) + if (auto calls = m_callReferenceGraph.tryGetValue(func)) { - addWorkItems(*calls, workItem); + addWorkItems(*calls); } } }; - for (Index i = 0; i < createParamWorkList.getCount(); i++) + functionWorkList.add(func); + for (Index i = 0; i < functionWorkList.getCount(); i++) { - createParamWork(createParamWorkList[i]); + createParamWork(functionWorkList[i]); } - const auto modifyCallWork = [&](IRCall* call, IRParam* param) + return modifyCallWorkList; + } + + // + // In addition to modifying the function signatures, we need to ensure the calls to the modified + // function include the system value 
variable. + // + void modifyCalls( + const List& workItems, + SystemValueSemanticName systemValueName) + { + for (const auto workItem : workItems) { + auto call = workItem.call; + auto param = tryGetParam(workItem.caller, systemValueName); + SLANG_ASSERT(param); + List newCallParams; for (auto arg : call->getArgsList()) { @@ -202,129 +180,61 @@ class ImplicitSystemValueLegalizationContext } newCallParams.add(param); - IRBuilder newBuilder(call); - newBuilder.setInsertAfter(call); - auto newCall = newBuilder.emitCallInst(paramType, call->getCallee(), newCallParams); - - // call->replaceUsesWith(newCall); - // call->transferDecorationsTo(newCall); - // call->removeAndDeallocate(); - }; + m_builder.setInsertAfter(call); + auto newCall = m_builder.emitCallInst(m_paramType, call->getCallee(), newCallParams); - for (const auto workItem : modifyCallWorkList) - { - modifyCallWork(workItem.call, addedParams[workItem.paramIndex]); + call->replaceUsesWith(newCall); + call->transferDecorationsTo(newCall); + call->removeAndDeallocate(); } - - return modifyCallWorkList; } - IRParam* getOrCreateSystemValueParam(IRFunc* parentFunc, IRImplicitSystemValue* implicitSysVal) + IRParam* getOrCreateSystemValueVariable( + IRFunc* parentFunc, + IRImplicitSystemValue* implicitSysVal) { - const auto systemValueName = + auto systemValueName = convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); SLANG_ASSERT(systemValueName != SystemValueSemanticName::Unknown); + // If parameter for the specific function and system value type combination was already + // created, return it directly. 
if (auto param = tryGetParam(parentFunc, systemValueName)) { return param; } - createFunctionParams(parentFunc, implicitSysVal); - return tryGetParam(parentFunc, systemValueName); - } - - // void replaceCalls() {} - - IRParam* getOrCreateSystemValueParam2(IRFunc* parentFunc, IRImplicitSystemValue* implicitSysVal) - { - if (!m_functionMap.containsKey(parentFunc)) - { - m_functionMap.add(parentFunc, SystemValueParamMap()); - } - auto systemValueParamMap = m_functionMap.tryGetValue(parentFunc); - - const auto systemValueName = - convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); - SLANG_ASSERT(systemValueName != SystemValueSemanticName::Unknown); - - IRParam* param; - const auto paramPtr = systemValueParamMap->tryGetValue(systemValueName); - if (paramPtr == nullptr) - { - IRBuilder builder(parentFunc); - builder.setInsertBefore(parentFunc->getFirstBlock()->getFirstOrdinaryInst()); - - // The new system value parameter's type is harcoded. - // - // Implicit system values are currently only being used for subgroup size and subgroup - // invocation id, both of which are 32-bit unsigned. - SLANG_ASSERT( - (systemValueName == SystemValueSemanticName::WaveLaneCount) || - (systemValueName == SystemValueSemanticName::WaveLaneIndex)); - auto paramType = builder.getUIntType(); - param = builder.emitParam(paramType); - - - if (parentFunc->findDecoration()) - { - builder.addSemanticDecoration(param, implicitSysVal->getSystemValueName()); - } - - systemValueParamMap->add(systemValueName, param); + // Create new parameters for the relevant functions up to the entry point function. 
+ const auto callWorkItems = + createFunctionParams(parentFunc, systemValueName, implicitSysVal->getSystemValueName()); - if (m_callReferenceGraph.containsKey(parentFunc)) - { - for (auto call : m_callReferenceGraph.getValue(parentFunc)) - { - for (auto callParentFunc : m_functionReferenceGraph.getValue(call)) - { - SLANG_ASSERT(parentFunc != callParentFunc); - auto newCallParam = - getOrCreateSystemValueParam(callParentFunc, implicitSysVal); - - List newCallParams; - for (auto arg : call->getArgsList()) - { - newCallParams.add(arg); - } - newCallParams.add(newCallParam); - - builder.setInsertAfter(call); - auto newCall = builder.emitCallInst(paramType, parentFunc, newCallParams); - - call->replaceUsesWith(newCall); - call->transferDecorationsTo(newCall); - call->removeAndDeallocate(); - } - } - } - } - else - { - param = *paramPtr; - } + // Modify related function calls to account for the new parameters. + modifyCalls(callWorkItems, systemValueName); - return param; + return tryGetParam(parentFunc, systemValueName); } - - const Dictionary>& m_entryPointReferenceGraph; const Dictionary>& m_functionReferenceGraph; const Dictionary>& m_callReferenceGraph; const List& m_implicitSystemValueInstructions; Dictionary m_functionMap; + + IRBuilder m_builder; + + // Implicit system values are currently only being used for subgroup size and + // subgroup invocation id, both of which are 32-bit unsigned. 
+ IRType* m_paramType; }; void legalizeImplicitSystemValues( - const Dictionary>& entryPointReferenceGraph, + IRModule* module, const Dictionary>& functionReferenceGraph, const Dictionary>& callReferenceGraph, const List& implicitSystemValueInstructions) { ImplicitSystemValueLegalizationContext( - entryPointReferenceGraph, + module, functionReferenceGraph, callReferenceGraph, implicitSystemValueInstructions) diff --git a/source/slang/slang-ir-legalize-system-values.h b/source/slang/slang-ir-legalize-system-values.h index f0abedef43..36cf3e9b55 100644 --- a/source/slang/slang-ir-legalize-system-values.h +++ b/source/slang/slang-ir-legalize-system-values.h @@ -6,7 +6,7 @@ namespace Slang { void legalizeImplicitSystemValues( - const Dictionary>& entryPointReferenceGraph, + IRModule* module, const Dictionary>& functionReferenceGraph, const Dictionary>& callReferenceGraph, const List& implicitSystemValueInstructions); diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 1c0f96c509..96cacdee36 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -245,7 +245,7 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) &functionReferenceGraph, &callReferenceGraph); legalizeImplicitSystemValues( - entryPointReferenceGraph, + module, functionReferenceGraph, callReferenceGraph, instContext.implicitSystemValueInstructions); From dae0e7cf671cba416d17350ee8acf73e8d3fd1de Mon Sep 17 00:00:00 2001 From: fairywreath Date: Mon, 10 Feb 2025 03:33:55 -0500 Subject: [PATCH 05/13] wip --- source/slang/slang-ir-call-graph.h | 15 +++++++++++++++ source/slang/slang-ir-legalize-system-values.cpp | 1 - 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/source/slang/slang-ir-call-graph.h b/source/slang/slang-ir-call-graph.h index 3dc400bb1d..d2baaf340c 100644 --- a/source/slang/slang-ir-call-graph.h +++ b/source/slang/slang-ir-call-graph.h @@ -14,4 +14,19 @@ HashSet* 
getReferencingEntryPoints( Dictionary>& m_referencingEntryPoints, IRInst* inst); + +/* +class FunctionCallGraph +{ +public: + const HashSet* getReferencingFunctions(IRInst* inst) const; + const HashSet* getFunctionCalls(IRFunc* func) const; + +private: + Dictionary> m_referencingFunctions; + Dictionary> m_functionCalls; +}; +*/ + + } // namespace Slang diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp index 3f540474d9..ce51194a56 100644 --- a/source/slang/slang-ir-legalize-system-values.cpp +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -28,7 +28,6 @@ class ImplicitSystemValueLegalizationContext { for (auto implicitSysVal : m_implicitSystemValueInstructions) { - // for (auto entryPoint : *m_entryPointReferenceGraph.tryGetValue(implicitSysVal)) for (auto entryPoint : *m_functionReferenceGraph.tryGetValue(implicitSysVal)) { auto param = getOrCreateSystemValueVariable(entryPoint, implicitSysVal); From 3862743dd48bfd38e5560aeb1859840d0bec1477 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Mon, 10 Feb 2025 14:36:51 -0600 Subject: [PATCH 06/13] properly include extension --- source/slang/hlsl.meta.slang | 5 +++++ source/slang/slang-emit-c-like.cpp | 5 +++++ source/slang/slang-emit-c-like.h | 8 +++++--- source/slang/slang-emit-wgsl.cpp | 5 +++++ source/slang/slang-emit-wgsl.h | 2 ++ source/slang/slang-ir-inst-defs.h | 1 + source/slang/slang-ir-insts.h | 10 ++++++++++ 7 files changed, 33 insertions(+), 3 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 5a562b5d5e..121470d761 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6,6 +6,9 @@ typedef uint UINT; __intrinsic_op($(kIROp_RequireGLSLExtension)) void __requireGLSLExtension(String extensionName); +__intrinsic_op($(kIROp_RequireWGSLExtension)) +void __requireWGSLExtension(String extensionName); + __intrinsic_op($(kIROp_ImplicitSystemValue)) uint __implicitSystemValue(String 
systemValueName); @@ -15055,6 +15058,7 @@ uint WaveGetLaneCount() result:$$uint = OpLoad builtin(SubgroupSize:uint) }; case wgsl: + __requireWGSLExtension("subgroups"); return __implicitSystemValue("SV_WaveLaneCount"); } } @@ -15079,6 +15083,7 @@ uint WaveGetLaneIndex() result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint) }; case wgsl: + __requireWGSLExtension("subgroups"); return __implicitSystemValue("SV_WaveLaneIndex"); } } diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 5651632349..70d83d7125 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3059,6 +3059,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO { break; } + case kIROp_RequireWGSLExtension: + { + emitRequireExtension(inst); + break; + } default: diagnoseUnhandledInst(inst); break; diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index e83b6e5861..f7f1d372f6 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -414,18 +414,18 @@ class CLikeSourceEmitter : public SourceEmitterBase /// Emit type attributes that should appear after, e.g., a `struct` keyword void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); } - virtual void emitMemoryQualifiers(IRInst* /*varInst*/){}; + virtual void emitMemoryQualifiers(IRInst* /*varInst*/) {}; virtual void emitStructFieldAttributes( IRStructType* /* structType */, IRStructField* /* field */ - ){}; + ) {}; void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout); void emitMeshShaderModifiers(IRInst* varInst); virtual void emitPackOffsetModifier( IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/ - ){}; + ) {}; /// Emit modifiers that should apply even for a declaration of an SSA temporary. 
@@ -678,6 +678,8 @@ class CLikeSourceEmitter : public SourceEmitterBase void _emitCallArgList(IRCall* call, int startingOperandIndex = 1); virtual void emitCallArg(IRInst* arg); + virtual void emitRequireExtension(IRInst* inst) { SLANG_UNUSED(inst); } + String _generateUniqueName(const UnownedStringSlice& slice); // Sort witnessTable entries according to the order defined in the witnessed interface type. diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 13c79e9acc..933e750e2a 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1696,4 +1696,9 @@ void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) } } +void WGSLSourceEmitter::emitRequireExtension(IRInst* inst) +{ + _requireExtension(as(inst)->getExtensionName()); +} + } // namespace Slang diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index 441933b570..2392c4d3c3 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -65,6 +65,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; } + virtual void emitRequireExtension(IRInst* inst) SLANG_OVERRIDE; + private: bool maybeEmitSystemSemantic(IRInst* inst); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 3ff1ba3dfc..20c4583bde 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -667,6 +667,7 @@ INST(discard, discard, 0, 0) INST(RequirePrelude, RequirePrelude, 1, 0) INST(RequireGLSLExtension, RequireGLSLExtension, 1, 0) +INST(RequireWGSLExtension, RequireWGSLExtension, 1, 0) INST(RequireComputeDerivative, RequireComputeDerivative, 0, 0) INST(StaticAssert, StaticAssert, 2, 0) INST(Printf, Printf, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index e6d3a8383a..a19f553d6b 100644 --- a/source/slang/slang-ir-insts.h +++ 
b/source/slang/slang-ir-insts.h @@ -3479,6 +3479,16 @@ struct IRRequireGLSLExtension : IRInst } }; +struct IRRequireWGSLExtension : IRInst +{ + IR_LEAF_ISA(RequireWGSLExtension) + UnownedStringSlice getExtensionName() + { + return as(getOperand(0))->getStringSlice(); + } +}; +; + struct IRRequireComputeDerivative : IRInst { IR_LEAF_ISA(RequireComputeDerivative) From f59b7de5f5dcd73b8e37c0aecb5fb3743901b5cc Mon Sep 17 00:00:00 2001 From: fairywreath Date: Mon, 10 Feb 2025 17:30:34 -0500 Subject: [PATCH 07/13] enable test and minor fixes --- source/slang/glsl.meta.slang | 2 ++ .../shader-subgroup-builtin-variables.slang | 7 +++++-- tests/hlsl-intrinsic/wave-get-lane-index.slang | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index 6f0ca1bf34..38e2871bd5 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -6409,6 +6409,7 @@ public property uint gl_SubgroupID public property uint gl_SubgroupSize { + [ForceInline] [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); @@ -6418,6 +6419,7 @@ public property uint gl_SubgroupSize public property uint gl_SubgroupInvocationID { + [ForceInline] [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang index 21b533178e..626a613a4e 100644 --- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang +++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang @@ -10,6 +10,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl //TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly +//TEST(compute, 
vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU #version 430 //TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer @@ -24,15 +25,17 @@ void computeMain() { if (gl_GlobalInvocationID.x == 3) { outputBuffer.data[0] = true - && gl_NumSubgroups == 1 - && gl_SubgroupID == 0 //1 subgroup, 0 based indexing && gl_SubgroupSize == 32 && gl_SubgroupInvocationID == 3 +#if !defined(WGPU) + && gl_SubgroupID == 0 //1 subgroup, 0 based indexing + && gl_NumSubgroups == 1 && gl_SubgroupEqMask == uvec4(0b1000,0,0,0) && gl_SubgroupGeMask == uvec4(0xFFFFFFF8,0,0,0) && gl_SubgroupGtMask == uvec4(0xFFFFFFF0,0,0,0) && gl_SubgroupLeMask == uvec4(0b1111,0,0,0) && gl_SubgroupLtMask == uvec4(0b111,0,0,0) +#endif ; } // CHECK_GLSL: void main( diff --git a/tests/hlsl-intrinsic/wave-get-lane-index.slang b/tests/hlsl-intrinsic/wave-get-lane-index.slang index fb09022c23..e9b917442a 100644 --- a/tests/hlsl-intrinsic/wave-get-lane-index.slang +++ b/tests/hlsl-intrinsic/wave-get-lane-index.slang @@ -4,6 +4,7 @@ //TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj //TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj //TEST:COMPARE_COMPUTE_EX:-cuda -compute -shaderobj +//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer outputBuffer; From 622caefd635ec1695328918fc435f59997050b51 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Mon, 10 Feb 2025 17:51:34 -0500 Subject: [PATCH 08/13] improve comments and minor fix --- .../slang/slang-ir-legalize-system-values.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp index ce51194a56..50a9979e03 100644 --- a/source/slang/slang-ir-legalize-system-values.cpp +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -39,8 +39,8 @@ class 
ImplicitSystemValueLegalizationContext private: // - // A function (including entry points) must have at most one parameter for each system value - // semantic type. + // A function (including entry points) must have at most one parameter for each implicit system + // value semantic type. // // This map tracks the association between system value semantics and their corresponding // function parameters. @@ -100,6 +100,12 @@ class ImplicitSystemValueLegalizationContext SystemValueSemanticName systemValueName, UnownedStringSlice systemValueString) { + // Implicit system values are currently only being used for subgroup size and + // subgroup invocation id. + SLANG_ASSERT( + (systemValueName == SystemValueSemanticName::WaveLaneCount) || + (systemValueName == SystemValueSemanticName::WaveLaneIndex)); + List functionWorkList; List modifyCallWorkList; @@ -113,18 +119,12 @@ class ImplicitSystemValueLegalizationContext // new parameter to pass in the system value variable to the callee. functionWorkList.add(caller); - // The call needs to be modified to add the new parameter. + // The call needs to be modified to account for the new parameter. modifyCallWorkList.add({call, caller}); } } }; - // Implicit system values are currently only being used for subgroup size and - // subgroup invocation id. - SLANG_ASSERT( - (systemValueName == SystemValueSemanticName::WaveLaneCount) || - (systemValueName == SystemValueSemanticName::WaveLaneIndex)); - const auto createParamWork = [&](IRFunc* func) { // If the parameter for system value type has not been created, create it. 
@@ -140,8 +140,8 @@ class ImplicitSystemValueLegalizationContext m_builder.addSemanticDecoration(param, systemValueString); } + fixUpFuncType(func); getParamMap(func).add(systemValueName, param); - if (auto calls = m_callReferenceGraph.tryGetValue(func)) { addWorkItems(*calls); From e84ce8e0d6609bf9528fc43b187c6a08e9603e64 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Wed, 12 Feb 2025 15:56:32 -0600 Subject: [PATCH 09/13] add tests --- ...licit-system-values-dynamic-dispatch.slang | 45 ++++++++ tests/wgsl/implicit-system-values.slang | 103 ++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 tests/wgsl/implicit-system-values-dynamic-dispatch.slang create mode 100644 tests/wgsl/implicit-system-values.slang diff --git a/tests/wgsl/implicit-system-values-dynamic-dispatch.slang b/tests/wgsl/implicit-system-values-dynamic-dispatch.slang new file mode 100644 index 0000000000..c51369d6e0 --- /dev/null +++ b/tests/wgsl/implicit-system-values-dynamic-dispatch.slang @@ -0,0 +1,45 @@ +// Test calling differentiable function through dynamic dispatch. 
+ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + uint getLaneIndex(uint base); +} + +struct Impl1 : IInterface +{ + uint getLaneIndex(uint base) + { + return base * WaveGetLaneIndex(); + } +}; + + +struct Impl2 : IInterface +{ + uint getLaneIndex(uint base) + { + return base; + } +} + +//TEST_INPUT: type_conformance Impl1:IInterface = 0 +//TEST_INPUT: type_conformance Impl2:IInterface = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject(dispatchThreadID.x, 0); // Impl1 + outputBuffer[0] = obj.getLaneIndex(5); + + obj = createDynamicObject(dispatchThreadID.x + 1, 0); // Impl2 + outputBuffer[1] = obj.getLaneIndex(5); + + // BUF: 0 + // BUF-NEXT: 5 +} diff --git a/tests/wgsl/implicit-system-values.slang b/tests/wgsl/implicit-system-values.slang new file mode 100644 index 0000000000..1f2b6bfc00 --- /dev/null +++ b/tests/wgsl/implicit-system-values.slang @@ -0,0 +1,103 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + uint getLaneIndex(); + uint getLaneIndex(uint base); +}; + +struct Impl1 : IInterface +{ + uint getLaneIndex() + { + return WaveGetLaneIndex(); + } + + uint getLaneIndex(uint base) + { + return base + WaveGetLaneIndex(); + } +}; + +struct Impl2 : IInterface +{ + uint getLaneIndex() + { + return 100; + } + + uint getLaneIndex(uint base) + { + return base - 1; + } +}; + +struct Impl3 : IInterface +{ + uint getLaneIndex(uint base = 1) + { + return base + WaveGetLaneIndex() + 1; + } +}; + +uint getLaneIndexGeneric(T interface, uint base = 3) where T : IInterface +{ + return interface.getLaneIndex(base); +} + +struct MyStruct 
where T : IInterface +{ + T interface; + uint getLaneIndex(uint base = 4) + { + return interface.getLaneIndex(base); + } +} + + +void testDynamicDispatch() +{ +} + +[numthreads(1,1,1)] +void computeMain() +{ + Impl1 impl1; + Impl2 impl2; + Impl3 impl3; + + MyStruct s1; + MyStruct s2; + MyStruct s3; + + // BUF: 1 + outputBuffer[0] = uint( + true + + // Interface implementations. + && (0 == impl1.getLaneIndex()) + && (1 == impl1.getLaneIndex(2)) + && (100 == impl2.getLaneIndex()) + && (2 == impl3.getLaneIndex()) + && (3 == impl3.getLaneIndex(2)) + + // Interface as function generic parameter. + && (4 == getLaneIndexGeneric(impl1, 4)) + && (3 == getLaneIndexGeneric(impl2, 4)) + && (5 == getLaneIndexGeneric(impl3, 4)) + && (3 == getLaneIndexGeneric(impl1)) + && (2 == getLaneIndexGeneric(impl2)) + && (4 == getLaneIndexGeneric(impl3)) + + // Interface as struct generic member. + && (5 == s1.getLaneIndex(5)) + && (4 == s2.getLaneIndex(5)) + && (6 == s3.getLaneIndex(5)) + && (4 == s1.getLaneIndex()) + && (3 == s2.getLaneIndex()) + && (5 == s3.getLaneIndex()) + ); +} From c6404a7ae5ebf84bd135be92f26d46d01b71c4c5 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Wed, 12 Feb 2025 17:19:18 -0500 Subject: [PATCH 10/13] update test --- ...licit-system-values-dynamic-dispatch.slang | 53 +++++++++++++------ tests/wgsl/implicit-system-values.slang | 41 ++++++++++---- 2 files changed, 68 insertions(+), 26 deletions(-) diff --git a/tests/wgsl/implicit-system-values-dynamic-dispatch.slang b/tests/wgsl/implicit-system-values-dynamic-dispatch.slang index c51369d6e0..7a8c6a1e59 100644 --- a/tests/wgsl/implicit-system-values-dynamic-dispatch.slang +++ b/tests/wgsl/implicit-system-values-dynamic-dispatch.slang @@ -1,10 +1,14 @@ // Test calling differentiable function through dynamic dispatch. 
-//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -output-using-type -//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; +//TEST_INPUT: type_conformance Impl1:IInterface = 0 +//TEST_INPUT: type_conformance Impl2:IInterface = 1 +//TEST_INPUT: type_conformance Impl3:IInterface = 2 + [anyValueSize(16)] interface IInterface { @@ -15,31 +19,50 @@ struct Impl1 : IInterface { uint getLaneIndex(uint base) { - return base * WaveGetLaneIndex(); + return base; } -}; - +} struct Impl2 : IInterface { uint getLaneIndex(uint base) { - return base; + return base * WaveGetLaneIndex() * 2; } } -//TEST_INPUT: type_conformance Impl1:IInterface = 0 -//TEST_INPUT: type_conformance Impl2:IInterface = 1 +struct Impl3 : IInterface +{ + uint getLaneIndex(uint base) + { + return base + WaveGetLaneIndex(); + } +}; -[numthreads(1, 1, 1)] +[numthreads(2, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - var obj = createDynamicObject(dispatchThreadID.x, 0); // Impl1 - outputBuffer[0] = obj.getLaneIndex(5); + const uint base = 5; - obj = createDynamicObject(dispatchThreadID.x + 1, 0); // Impl2 - outputBuffer[1] = obj.getLaneIndex(5); + if (dispatchThreadID.x == 0) + { + var obj = createDynamicObject(dispatchThreadID.x, 0); // Impl1 + outputBuffer[0] = obj.getLaneIndex(base); - // BUF: 0 - // BUF-NEXT: 5 + obj = createDynamicObject(dispatchThreadID.x + 1, 0); // Impl2 + outputBuffer[1] = obj.getLaneIndex(base); + } + else + { + var obj = createDynamicObject(dispatchThreadID.x, 0); // Impl2 + outputBuffer[2] = obj.getLaneIndex(base); + + obj = createDynamicObject(dispatchThreadID.x + 1, 0); // Impl3 + outputBuffer[3] = obj.getLaneIndex(base); + } + + // BUF: 5 + // BUF-NEXT: 0 + // BUF-NEXT: 10 + // BUF-NEXT: 6 } diff --git
a/tests/wgsl/implicit-system-values.slang b/tests/wgsl/implicit-system-values.slang index 1f2b6bfc00..4d1f3e7658 100644 --- a/tests/wgsl/implicit-system-values.slang +++ b/tests/wgsl/implicit-system-values.slang @@ -43,6 +43,21 @@ struct Impl3 : IInterface } }; +struct Impl4 : IInterface +{ + uint getLaneIndex() + { + return WaveGetLaneIndex() + 2; + } + + uint getLaneIndex(uint base) + { + return base * 2; + } +}; + + + uint getLaneIndexGeneric(T interface, uint base = 3) where T : IInterface { return interface.getLaneIndex(base); @@ -57,47 +72,51 @@ struct MyStruct where T : IInterface } } - -void testDynamicDispatch() -{ -} - [numthreads(1,1,1)] void computeMain() { Impl1 impl1; Impl2 impl2; Impl3 impl3; + Impl4 impl4; MyStruct s1; MyStruct s2; MyStruct s3; + MyStruct s4; // BUF: 1 outputBuffer[0] = uint( - true - + (0 == WaveGetLaneIndex()) + // Interface implementations. && (0 == impl1.getLaneIndex()) - && (1 == impl1.getLaneIndex(2)) && (100 == impl2.getLaneIndex()) && (2 == impl3.getLaneIndex()) + && (2 == impl4.getLaneIndex()) + && (2 == impl1.getLaneIndex(2)) + && (1 == impl2.getLaneIndex(2)) && (3 == impl3.getLaneIndex(2)) + && (12 == impl4.getLaneIndex(6)) // Interface as function generic parameter. - && (4 == getLaneIndexGeneric(impl1, 4)) - && (3 == getLaneIndexGeneric(impl2, 4)) - && (5 == getLaneIndexGeneric(impl3, 4)) && (3 == getLaneIndexGeneric(impl1)) && (2 == getLaneIndexGeneric(impl2)) && (4 == getLaneIndexGeneric(impl3)) + && (6 == getLaneIndexGeneric(impl4)) + && (4 == getLaneIndexGeneric(impl1, 4)) + && (3 == getLaneIndexGeneric(impl2, 4)) + && (5 == getLaneIndexGeneric(impl3, 4)) + && (8 == getLaneIndexGeneric(impl4, 4)) // Interface as struct generic member. 
&& (5 == s1.getLaneIndex(5)) && (4 == s2.getLaneIndex(5)) && (6 == s3.getLaneIndex(5)) + && (10 == s4.getLaneIndex(5)) && (4 == s1.getLaneIndex()) && (3 == s2.getLaneIndex()) && (5 == s3.getLaneIndex()) + && (8 == s4.getLaneIndex()) ); } From 76019c60338dbcd1c631b48bf30d8167681b809f Mon Sep 17 00:00:00 2001 From: fairywreath Date: Thu, 13 Feb 2025 15:16:50 -0500 Subject: [PATCH 11/13] clean up call graph building --- source/slang/slang-emit-glsl.cpp | 4 +- source/slang/slang-emit-glsl.h | 3 +- source/slang/slang-emit-spirv.cpp | 20 +-- source/slang/slang-ir-call-graph.cpp | 149 +++++++++--------- source/slang/slang-ir-call-graph.h | 41 +++-- .../slang/slang-ir-legalize-system-values.cpp | 27 ++-- .../slang/slang-ir-legalize-system-values.h | 4 +- .../slang-ir-specialize-stage-switch.cpp | 5 +- source/slang/slang-ir-spirv-legalize.cpp | 5 +- source/slang/slang-ir-spirv-legalize.h | 6 +- .../slang-ir-translate-glsl-global-var.cpp | 11 +- source/slang/slang-ir-wgsl-legalize.cpp | 14 +- 12 files changed, 142 insertions(+), 147 deletions(-) diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 25dab3fb35..11f7b5ad5a 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -47,7 +47,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction( // // Handle cases where "require" IR operations exist in the function body and are required // as entry point decorations. 
- auto entryPoints = getReferencingEntryPoints(m_referencingEntryPoints, parentFunc); + auto entryPoints = m_callGraph.getReferencingEntryPoints(parentFunc); if (entryPoints == nullptr) return; @@ -81,7 +81,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction( void GLSLSourceEmitter::beforeComputeEmitActions(IRModule* module) { - buildEntryPointReferenceGraph(this->m_referencingEntryPoints, module); + m_callGraph.build(module); IRBuilder builder(module); for (auto globalInst : module->getGlobalInsts()) diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index 8308a9954d..7ea37f81f0 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -4,6 +4,7 @@ #include "slang-emit-c-like.h" #include "slang-extension-tracker.h" +#include "slang-ir-call-graph.h" namespace Slang { @@ -178,7 +179,7 @@ class GLSLSourceEmitter : public CLikeSourceEmitter void _beforeComputeEmitProcessInstruction(IRInst* parentFunc, IRInst* inst, IRBuilder& builder); - Dictionary> m_referencingEntryPoints; + CallGraph m_callGraph; RefPtr m_glslExtensionTracker; }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index ef015df7f9..5a65e2d61b 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -3553,8 +3553,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto parentFunc = getParentFunc(inst); - HashSet* entryPointsUsingInst = - getReferencingEntryPoints(m_referencingEntryPoints, parentFunc); + auto entryPointsUsingInst = m_callGraph.getReferencingEntryPoints(parentFunc); for (IRFunc* entryPoint : *entryPointsUsingInst) { bool isQuad = true; @@ -3608,8 +3607,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } case kIROp_RequireMaximallyReconverges: - if (auto entryPointsUsingInst = - getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst))) + if ( + // auto 
entryPointsUsingInst = + // getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst)) + auto entryPointsUsingInst = + m_callGraph.getReferencingEntryPoints(getParentFunc(inst))) { ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_maximal_reconvergence")); for (IRFunc* entryPoint : *entryPointsUsingInst) @@ -3623,7 +3625,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex break; case kIROp_RequireQuadDerivatives: if (auto entryPointsUsingInst = - getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst))) + m_callGraph.getReferencingEntryPoints(getParentFunc(inst))) { ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_quad_control")); requireSPIRVCapability(SpvCapabilityQuadControlKHR); @@ -4326,7 +4328,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (m_mapIRInstToSpvInst.tryGetValue(globalInst, spvGlobalInst)) { // Is this globalInst referenced by this entry point? - auto refSet = m_referencingEntryPoints.tryGetValue(globalInst); + auto refSet = m_callGraph.getReferencingEntryPoints(globalInst); if (refSet && refSet->contains(entryPoint)) { if (!isSpirv14OrLater()) @@ -5129,7 +5131,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex bool isInstUsedInStage(IRInst* inst, Stage s) { - auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst); + auto referencingEntryPoints = m_callGraph.getReferencingEntryPoints(inst); if (!referencingEntryPoints) return false; for (auto entryPoint : *referencingEntryPoints) @@ -5329,7 +5331,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } else if (semanticName == "sv_primitiveid") { - auto entryPoints = m_referencingEntryPoints.tryGetValue(inst); + auto entryPoints = m_callGraph.getReferencingEntryPoints(inst); // SPIRV requires `Geometry` capability being declared for a fragment // shader, if that shader uses sv_primitiveid. 
// We will check if this builtin is used by non-ray-tracing, non-geometry or @@ -8029,7 +8031,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case SpvOpExecutionMode: { if (auto refEntryPointSet = - m_referencingEntryPoints.tryGetValue(getParentFunc(inst))) + m_callGraph.getReferencingEntryPoints(getParentFunc(inst))) { for (auto entryPoint : *refEntryPointSet) { diff --git a/source/slang/slang-ir-call-graph.cpp b/source/slang/slang-ir-call-graph.cpp index ca68a5f4d7..08343879a2 100644 --- a/source/slang/slang-ir-call-graph.cpp +++ b/source/slang/slang-ir-call-graph.cpp @@ -7,16 +7,44 @@ namespace Slang { -void buildEntryPointReferenceGraph( - Dictionary>& referencingEntryPoints, - IRModule* module, - Dictionary>* referencingFunctions, - Dictionary>* referencingCalls) +CallGraph::CallGraph(IRModule* module) +{ + build(module); +} + +template +static void addToReferenceMap(Dictionary>& map, T key, U value) +{ + if (auto set = map.tryGetValue(key)) + { + set->add(value); + } + else + { + HashSet newSet; + newSet.add(value); + map.add(key, _Move(newSet)); + } +} + + +void CallGraph::registerInstructionReference(IRInst* inst, IRFunc* entryPoint, IRFunc* parentFunc) +{ + addToReferenceMap(m_referencingEntryPoints, inst, entryPoint); + addToReferenceMap(m_referencingFunctions, inst, parentFunc); +} + +void CallGraph::registerCallReference(IRFunc* func, IRCall* call) +{ + addToReferenceMap(m_referencingCalls, func, call); +} + +void CallGraph::build(IRModule* module) { struct WorkItem { - IRFunc* entryPoint; IRInst* inst; + IRFunc* entryPoint; IRFunc* parentFunc; HashCode getHashCode() const @@ -36,54 +64,11 @@ void buildEntryPointReferenceGraph( workList.add(item); }; - auto registerEntryPointReference = [&](IRFunc* entryPoint, IRInst* inst) - { - if (auto set = referencingEntryPoints.tryGetValue(inst)) - set->add(entryPoint); - else - { - HashSet newSet; - newSet.add(entryPoint); - referencingEntryPoints.add(inst, _Move(newSet)); - 
} - }; - - const auto registerFunctionReference = [&](IRFunc* func, IRInst* inst) - { - if (referencingFunctions && func) - { - if (auto set = referencingFunctions->tryGetValue(inst)) - set->add(func); - else - { - HashSet newSet; - newSet.add(func); - referencingFunctions->add(inst, _Move(newSet)); - } - } - }; - - auto registerCallReference = [&](IRCall* call, IRFunc* func) - { - if (referencingCalls) - { - if (auto set = referencingCalls->tryGetValue(func)) - set->add(call); - else - { - HashSet newSet; - newSet.add(call); - referencingCalls->add(func, _Move(newSet)); - } - } - }; - - auto visit = [&](IRFunc* entryPoint, IRInst* inst, IRFunc* parentFunc) + auto visit = [&](IRInst* inst, IRFunc* entryPoint, IRFunc* parentFunc) { if (auto code = as(inst)) { - registerEntryPointReference(entryPoint, inst); - registerFunctionReference(parentFunc, inst); + registerInstructionReference(inst, entryPoint, parentFunc); if (auto func = as(code)) { @@ -92,40 +77,39 @@ void buildEntryPointReferenceGraph( for (auto child : code->getChildren()) { - addToWorkList({entryPoint, child, parentFunc}); + addToWorkList({child, entryPoint, parentFunc}); } return; } + switch (inst->getOp()) { - // Only these instruction types are registered to the entry point reference graph. + // Only these instruction types and `IRGlobalValueWithCode` instructions are registered to + // the reference graph. 
+ case kIROp_Call: + { + auto call = as(inst); + registerCallReference(as(call->getCallee()), call); + addToWorkList({call->getCallee(), entryPoint, parentFunc}); + } + [[fallthrough]]; case kIROp_GlobalParam: case kIROp_SPIRVAsmOperandBuiltinVar: case kIROp_ImplicitSystemValue: - registerEntryPointReference(entryPoint, inst); - registerFunctionReference(parentFunc, inst); + registerInstructionReference(inst, entryPoint, parentFunc); break; case kIROp_Block: case kIROp_SPIRVAsm: for (auto child : inst->getChildren()) { - addToWorkList({entryPoint, child, parentFunc}); - } - break; - case kIROp_Call: - registerEntryPointReference(entryPoint, inst); - registerFunctionReference(parentFunc, inst); - { - auto call = as(inst); - registerCallReference(call, as(call->getCallee())); - addToWorkList({entryPoint, call->getCallee(), parentFunc}); + addToWorkList({child, entryPoint, parentFunc}); } break; case kIROp_SPIRVAsmOperandInst: { auto operand = as(inst); - addToWorkList({entryPoint, operand->getValue(), parentFunc}); + addToWorkList({operand->getValue(), entryPoint, parentFunc}); } break; } @@ -137,7 +121,7 @@ void buildEntryPointReferenceGraph( case kIROp_GlobalParam: case kIROp_GlobalVar: case kIROp_SPIRVAsmOperandBuiltinVar: - addToWorkList({entryPoint, operand, parentFunc}); + addToWorkList({operand, entryPoint, parentFunc}); break; } } @@ -149,21 +133,40 @@ void buildEntryPointReferenceGraph( globalInst->findDecoration()) { auto entryPointFunc = as(globalInst); - visit(entryPointFunc, globalInst, nullptr); + visit(globalInst, entryPointFunc, nullptr); } } for (Index i = 0; i < workList.getCount(); i++) - visit(workList[i].entryPoint, workList[i].inst, workList[i].parentFunc); + visit(workList[i].inst, workList[i].entryPoint, workList[i].parentFunc); } -HashSet* getReferencingEntryPoints( - Dictionary>& m_referencingEntryPoints, - IRInst* inst) +const HashSet* CallGraph::getReferencingEntryPoints(IRInst* inst) const { - auto* referencingEntryPoints = 
m_referencingEntryPoints.tryGetValue(inst); + const auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst); if (!referencingEntryPoints) return nullptr; return referencingEntryPoints; } +const HashSet* CallGraph::getReferencingFunctions(IRInst* inst) const +{ + const auto* referencingFunctions = m_referencingFunctions.tryGetValue(inst); + if (!referencingFunctions) + return nullptr; + return referencingFunctions; +} + +const HashSet* CallGraph::getReferencingCalls(IRFunc* func) const +{ + const auto* referencingCalls = m_referencingCalls.tryGetValue(func); + if (!referencingCalls) + return nullptr; + return referencingCalls; +} + +const Dictionary>& CallGraph::getReferencingEntryPointsMap() const +{ + return m_referencingEntryPoints; +} + } // namespace Slang diff --git a/source/slang/slang-ir-call-graph.h b/source/slang/slang-ir-call-graph.h index d2baaf340c..dfc2f2415f 100644 --- a/source/slang/slang-ir-call-graph.h +++ b/source/slang/slang-ir-call-graph.h @@ -1,32 +1,41 @@ +#pragma once + #include "slang-ir-clone.h" #include "slang-ir-insts.h" namespace Slang { -void buildEntryPointReferenceGraph( - Dictionary>& referencingEntryPoints, - IRModule* module, - Dictionary>* referencingFunctions = nullptr, - Dictionary>* referencingCalls = nullptr); +struct CallGraph +{ +public: + CallGraph() = default; + explicit CallGraph(IRModule* module); -HashSet* getReferencingEntryPoints( - Dictionary>& m_referencingEntryPoints, - IRInst* inst); + void build(IRModule* module); + /// Retrieves the set of entry points that invoke the given instruction in its call graph. + /// Returns nullptr if the instruction has no referencing entry points. + const HashSet* getReferencingEntryPoints(IRInst* inst) const; -/* -class FunctionCallGraph -{ -public: + /// Retrieves the set of functions that directly contain the given instruction in their body. + /// Returns nullptr if the instruction is not referenced by any function. 
const HashSet* getReferencingFunctions(IRInst* inst) const; - const HashSet* getFunctionCalls(IRFunc* func) const; + + /// Retrieves the set of calls that invoke the given function. + /// Returns nullptr if the function is never called. + const HashSet* getReferencingCalls(IRFunc* func) const; + + + const Dictionary>& getReferencingEntryPointsMap() const; private: + void registerInstructionReference(IRInst* inst, IRFunc* entryPoint, IRFunc* parentFunc); + void registerCallReference(IRFunc* func, IRCall* call); + + Dictionary> m_referencingEntryPoints; Dictionary> m_referencingFunctions; - Dictionary> m_functionCalls; + Dictionary> m_referencingCalls; }; -*/ - } // namespace Slang diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp index 50a9979e03..ccb3b19aaa 100644 --- a/source/slang/slang-ir-legalize-system-values.cpp +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -2,6 +2,7 @@ #include "core/slang-dictionary.h" #include "core/slang-string.h" +#include "slang-ir-call-graph.h" #include "slang-ir-insts.h" #include "slang-ir-legalize-varying-params.h" @@ -13,11 +14,9 @@ class ImplicitSystemValueLegalizationContext public: ImplicitSystemValueLegalizationContext( IRModule* module, - const Dictionary>& functionReferenceGraph, - const Dictionary>& callReferenceGraph, + const CallGraph& callGraph, const List& implicitSystemValueInstructions) - : m_functionReferenceGraph(functionReferenceGraph) - , m_callReferenceGraph(callReferenceGraph) + : m_callGraph(callGraph) , m_implicitSystemValueInstructions(implicitSystemValueInstructions) , m_builder(module) , m_paramType(m_builder.getUIntType()) @@ -28,9 +27,9 @@ class ImplicitSystemValueLegalizationContext { for (auto implicitSysVal : m_implicitSystemValueInstructions) { - for (auto entryPoint : *m_functionReferenceGraph.tryGetValue(implicitSysVal)) + for (auto parentFunc : *m_callGraph.getReferencingFunctions(implicitSysVal)) { - auto param = 
getOrCreateSystemValueVariable(entryPoint, implicitSysVal); + auto param = getOrCreateSystemValueVariable(parentFunc, implicitSysVal); implicitSysVal->replaceUsesWith(param); implicitSysVal->removeAndDeallocate(); } @@ -113,7 +112,7 @@ class ImplicitSystemValueLegalizationContext { for (auto call : calls) { - for (auto caller : m_functionReferenceGraph.getValue(call)) + for (auto caller : *m_callGraph.getReferencingFunctions(call)) { // The caller(of a function that was added a parameter) also requires a // new parameter to pass in the system value variable to the callee. @@ -142,7 +141,7 @@ class ImplicitSystemValueLegalizationContext fixUpFuncType(func); getParamMap(func).add(systemValueName, param); - if (auto calls = m_callReferenceGraph.tryGetValue(func)) + if (auto calls = m_callGraph.getReferencingCalls(func)) { addWorkItems(*calls); } @@ -213,8 +212,7 @@ class ImplicitSystemValueLegalizationContext return tryGetParam(parentFunc, systemValueName); } - const Dictionary>& m_functionReferenceGraph; - const Dictionary>& m_callReferenceGraph; + const CallGraph& m_callGraph; const List& m_implicitSystemValueInstructions; Dictionary m_functionMap; @@ -228,15 +226,10 @@ class ImplicitSystemValueLegalizationContext void legalizeImplicitSystemValues( IRModule* module, - const Dictionary>& functionReferenceGraph, - const Dictionary>& callReferenceGraph, + const CallGraph& callGraph, const List& implicitSystemValueInstructions) { - ImplicitSystemValueLegalizationContext( - module, - functionReferenceGraph, - callReferenceGraph, - implicitSystemValueInstructions) + ImplicitSystemValueLegalizationContext(module, callGraph, implicitSystemValueInstructions) .legalize(); } diff --git a/source/slang/slang-ir-legalize-system-values.h b/source/slang/slang-ir-legalize-system-values.h index 36cf3e9b55..4a142da140 100644 --- a/source/slang/slang-ir-legalize-system-values.h +++ b/source/slang/slang-ir-legalize-system-values.h @@ -1,5 +1,6 @@ // slang-ir-legalize-system-values.h 
#pragma once +#include "slang-ir-call-graph.h" #include "slang-ir-insts.h" namespace Slang @@ -7,8 +8,7 @@ namespace Slang void legalizeImplicitSystemValues( IRModule* module, - const Dictionary>& functionReferenceGraph, - const Dictionary>& callReferenceGraph, + const CallGraph& callGraph, const List& implicitSystemValueInstructions); } // namespace Slang diff --git a/source/slang/slang-ir-specialize-stage-switch.cpp b/source/slang/slang-ir-specialize-stage-switch.cpp index f65aa4d4cd..f984be9c1e 100644 --- a/source/slang/slang-ir-specialize-stage-switch.cpp +++ b/source/slang/slang-ir-specialize-stage-switch.cpp @@ -123,8 +123,7 @@ void specializeFuncToStage( void specializeStageSwitch(IRModule* module) { - Dictionary> mapInstToReferencingEntryPoints; - buildEntryPointReferenceGraph(mapInstToReferencingEntryPoints, module); + const auto callGraph = CallGraph(module); HashSet stageSpecificFunctions; discoverStageSpecificFunctions(stageSpecificFunctions, module); @@ -133,7 +132,7 @@ void specializeStageSwitch(IRModule* module) Dictionary> mapFuncToStageSpecializedFunc; for (auto func : stageSpecificFunctions) { - auto referencingEntryPoints = mapInstToReferencingEntryPoints.tryGetValue(func); + auto referencingEntryPoints = callGraph.getReferencingEntryPoints(func); if (!referencingEntryPoints) continue; if (func->findDecoration()) diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index c672180b70..e670f4a730 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -2,7 +2,6 @@ #include "slang-ir-spirv-legalize.h" #include "slang-emit-base.h" -#include "slang-ir-call-graph.h" #include "slang-ir-clone.h" #include "slang-ir-composite-reg-to-mem.h" #include "slang-ir-dce.h" @@ -2102,7 +2101,7 @@ static bool hasExplicitInterlockInst(IRFunc* func) void insertFragmentShaderInterlock(SPIRVEmitSharedContext* context, IRModule* module) { HashSet fragmentShaders; - for (auto& [inst, 
entryPoints] : context->m_referencingEntryPoints) + for (const auto& [inst, entryPoints] : context->m_callGraph.getReferencingEntryPointsMap()) { if (isRasterOrderedResource(inst)) { @@ -2154,7 +2153,7 @@ void legalizeIRForSPIRV( SLANG_UNUSED(entryPoints); legalizeSPIRV(context, module, codeGenContext->getSink()); simplifyIRForSpirvLegalization(context->m_targetProgram, codeGenContext->getSink(), module); - buildEntryPointReferenceGraph(context->m_referencingEntryPoints, module); + context->m_callGraph.build(module); insertFragmentShaderInterlock(context, module); } diff --git a/source/slang/slang-ir-spirv-legalize.h b/source/slang/slang-ir-spirv-legalize.h index 3c9bdf26a5..01adc7b604 100644 --- a/source/slang/slang-ir-spirv-legalize.h +++ b/source/slang/slang-ir-spirv-legalize.h @@ -1,6 +1,7 @@ // slang-ir-spirv-legalize.h #pragma once #include "../core/slang-basic.h" +#include "slang-ir-call-graph.h" #include "slang-ir-insts.h" #include "slang-ir-spirv-snippet.h" @@ -20,9 +21,8 @@ struct SPIRVEmitSharedContext TargetProgram* m_targetProgram; Dictionary> m_parsedSpvSnippets; - Dictionary> - m_referencingEntryPoints; // The entry-points that directly or transitively reference this - // global inst. + // Track entry-points that directly or transitively reference this global inst. 
+ CallGraph m_callGraph; DiagnosticSink* m_sink; const SPIRVCoreGrammarInfo* m_grammarInfo; diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 7b2b8d1ee2..5912ccef96 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -13,8 +13,7 @@ struct GlobalVarTranslationContext void processModule(IRModule* module) { - Dictionary> referencingEntryPoints; - buildEntryPointReferenceGraph(referencingEntryPoints, module); + const auto callGraph = CallGraph(module); List entryPoints; List getWorkGroupSizeInsts; @@ -30,7 +29,7 @@ struct GlobalVarTranslationContext getWorkGroupSizeInsts.add(inst); } for (auto inst : getWorkGroupSizeInsts) - materializeGetWorkGroupSize(module, referencingEntryPoints, inst); + materializeGetWorkGroupSize(module, callGraph, inst); IRBuilder builder(module); for (auto entryPoint : entryPoints) @@ -39,7 +38,7 @@ struct GlobalVarTranslationContext List inputVars; for (auto inst : module->getGlobalInsts()) { - if (auto referencingEntryPointSet = referencingEntryPoints.tryGetValue(inst)) + if (auto referencingEntryPointSet = callGraph.getReferencingEntryPoints(inst)) { if (referencingEntryPointSet->contains((IRFunc*)entryPoint)) { @@ -266,7 +265,7 @@ struct GlobalVarTranslationContext // void materializeGetWorkGroupSize( IRModule* module, - Dictionary>& referenceGraph, + const CallGraph& callGraph, IRInst* workgroupSizeInst) { IRBuilder builder(workgroupSizeInst); @@ -276,7 +275,7 @@ struct GlobalVarTranslationContext { if (auto parentFunc = getParentFunc(use->getUser())) { - auto referenceSet = referenceGraph.tryGetValue(parentFunc); + auto referenceSet = callGraph.getReferencingEntryPoints(parentFunc); if (!referenceSet) return; if (referenceSet->getCount() == 1) diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 96cacdee36..75b49f013e 100644 --- 
a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1,6 +1,5 @@ #include "slang-ir-wgsl-legalize.h" -#include "slang-ir-call-graph.h" #include "slang-ir-insts.h" #include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" @@ -235,19 +234,10 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) // Legalize implicit system values to entry point parameters. if (instContext.implicitSystemValueInstructions.getCount() != 0) { - Dictionary> entryPointReferenceGraph; - // Dictionary> functionReferenceGraph; - Dictionary> functionReferenceGraph; - Dictionary> callReferenceGraph; - buildEntryPointReferenceGraph( - entryPointReferenceGraph, - module, - &functionReferenceGraph, - &callReferenceGraph); + const auto callGraph = CallGraph(module); legalizeImplicitSystemValues( module, - functionReferenceGraph, - callReferenceGraph, + callGraph, instContext.implicitSystemValueInstructions); } From 972a7839119815c5839ecb9ab01c7c83951ed0a5 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Thu, 13 Feb 2025 16:05:38 -0500 Subject: [PATCH 12/13] cleanup --- source/slang/hlsl.meta.slang | 2 +- source/slang/slang-emit-c-like.cpp | 4 --- source/slang/slang-emit-c-like.h | 6 ++-- source/slang/slang-ir-call-graph.cpp | 1 - source/slang/slang-ir-call-graph.h | 3 +- .../slang/slang-ir-legalize-system-values.cpp | 34 +++++++++++-------- 6 files changed, 24 insertions(+), 26 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 121470d761..7fbc616d03 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -20568,7 +20568,7 @@ ${ // In order to support this approach, we need intrinsics that // can magically fetch the binding information for a resource. 
// -// TODO: These operations are kind of *screaming* for us tohlsl.m +// TODO: These operations are kind of *screaming* for us to // have a built-in `interface` that all of the opaque resource // types conform to, so that we can define builtins that work // for any resource type. diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 70d83d7125..102159cf68 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3055,10 +3055,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO emitOperand(as(inst)->getOperand(0), getInfo(EmitOp::General)); break; } - case kIROp_ImplicitSystemValue: - { - break; - } case kIROp_RequireWGSLExtension: { emitRequireExtension(inst); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index f7f1d372f6..6040f3b302 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -414,18 +414,18 @@ class CLikeSourceEmitter : public SourceEmitterBase /// Emit type attributes that should appear after, e.g., a `struct` keyword void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); } - virtual void emitMemoryQualifiers(IRInst* /*varInst*/) {}; + virtual void emitMemoryQualifiers(IRInst* /*varInst*/){}; virtual void emitStructFieldAttributes( IRStructType* /* structType */, IRStructField* /* field */ - ) {}; + ){}; void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout); void emitMeshShaderModifiers(IRInst* varInst); virtual void emitPackOffsetModifier( IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/ - ) {}; + ){}; /// Emit modifiers that should apply even for a declaration of an SSA temporary. 
diff --git a/source/slang/slang-ir-call-graph.cpp b/source/slang/slang-ir-call-graph.cpp index 08343879a2..9677e55386 100644 --- a/source/slang/slang-ir-call-graph.cpp +++ b/source/slang/slang-ir-call-graph.cpp @@ -2,7 +2,6 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" -#include "slang-ir.h" namespace Slang { diff --git a/source/slang/slang-ir-call-graph.h b/source/slang/slang-ir-call-graph.h index dfc2f2415f..e6de6af121 100644 --- a/source/slang/slang-ir-call-graph.h +++ b/source/slang/slang-ir-call-graph.h @@ -14,7 +14,7 @@ struct CallGraph void build(IRModule* module); - /// Retrieves the set of entry points that invoke the given instruction in its call graph. + /// Retrieves the set of entry points that transitively invoke the given instruction. /// Returns nullptr if the instruction has no referencing entry points. const HashSet* getReferencingEntryPoints(IRInst* inst) const; @@ -26,7 +26,6 @@ struct CallGraph /// Returns nullptr if the function is never called. const HashSet* getReferencingCalls(IRFunc* func) const; - const Dictionary>& getReferencingEntryPointsMap() const; private: diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp index ccb3b19aaa..2e4dfb7136 100644 --- a/source/slang/slang-ir-legalize-system-values.cpp +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -27,6 +27,8 @@ class ImplicitSystemValueLegalizationContext { for (auto implicitSysVal : m_implicitSystemValueInstructions) { + // Call graph is guaranteed to return valid referencing functions(non nullptr) as + // instructions processed here are all valid/non-dead instructions. 
for (auto parentFunc : *m_callGraph.getReferencingFunctions(implicitSysVal)) { auto param = getOrCreateSystemValueVariable(parentFunc, implicitSysVal); @@ -105,7 +107,7 @@ class ImplicitSystemValueLegalizationContext (systemValueName == SystemValueSemanticName::WaveLaneCount) || (systemValueName == SystemValueSemanticName::WaveLaneIndex)); - List functionWorkList; + List createParamWorkList; List modifyCallWorkList; const auto addWorkItems = [&](const HashSet& calls) @@ -116,7 +118,7 @@ class ImplicitSystemValueLegalizationContext { // The caller(of a function that was added a parameter) also requires a // new parameter to pass in the system value variable to the callee. - functionWorkList.add(caller); + createParamWorkList.add(caller); // The call needs to be modified to account for the new parameter. modifyCallWorkList.add({call, caller}); @@ -141,6 +143,8 @@ class ImplicitSystemValueLegalizationContext fixUpFuncType(func); getParamMap(func).add(systemValueName, param); + + // Update all functions that call this function. if (auto calls = m_callGraph.getReferencingCalls(func)) { addWorkItems(*calls); @@ -148,24 +152,23 @@ class ImplicitSystemValueLegalizationContext } }; - functionWorkList.add(func); - for (Index i = 0; i < functionWorkList.getCount(); i++) + createParamWorkList.add(func); + for (Index i = 0; i < createParamWorkList.getCount(); i++) { - createParamWork(functionWorkList[i]); + createParamWork(createParamWorkList[i]); } return modifyCallWorkList; } // - // In addition to modifying the function signatures, we need to ensure the calls to the modified - // function include the system value variable. + // The function calls need to be modified to account for the change in function signature. 
// void modifyCalls( - const List& workItems, + const List& workList, SystemValueSemanticName systemValueName) { - for (const auto workItem : workItems) + for (const auto workItem : workList) { auto call = workItem.call; auto param = tryGetParam(workItem.caller, systemValueName); @@ -197,10 +200,8 @@ class ImplicitSystemValueLegalizationContext // If parameter for the specific function and system value type combination was already // created, return it directly. - if (auto param = tryGetParam(parentFunc, systemValueName)) - { - return param; - } + if (auto existingParam = tryGetParam(parentFunc, systemValueName)) + return existingParam; // Create new parameters for the relevant functions up to the entry point function. const auto callWorkItems = @@ -209,16 +210,19 @@ class ImplicitSystemValueLegalizationContext // Modify related function calls to account for the new parameters. modifyCalls(callWorkItems, systemValueName); - return tryGetParam(parentFunc, systemValueName); + auto newParam = tryGetParam(parentFunc, systemValueName); + SLANG_ASSERT(newParam); + return newParam; } const CallGraph& m_callGraph; const List& m_implicitSystemValueInstructions; Dictionary m_functionMap; - IRBuilder m_builder; + // Type of system value. + // // Implicit system values are currently only being used for subgroup size and // subgroup invocation id, both of which are 32-bit unsigned. 
IRType* m_paramType; From 624d7c74cd211417b7ffb76da5aaccf3ce35d352 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Thu, 13 Feb 2025 17:02:24 -0500 Subject: [PATCH 13/13] assert cleanup --- source/slang/slang-ir-legalize-system-values.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/source/slang/slang-ir-legalize-system-values.cpp b/source/slang/slang-ir-legalize-system-values.cpp index 2e4dfb7136..92b4aa1530 100644 --- a/source/slang/slang-ir-legalize-system-values.cpp +++ b/source/slang/slang-ir-legalize-system-values.cpp @@ -101,12 +101,6 @@ class ImplicitSystemValueLegalizationContext SystemValueSemanticName systemValueName, UnownedStringSlice systemValueString) { - // Implicit system values are currently only being used for subgroup size and - // subgroup invocation id. - SLANG_ASSERT( - (systemValueName == SystemValueSemanticName::WaveLaneCount) || - (systemValueName == SystemValueSemanticName::WaveLaneIndex)); - List createParamWorkList; List modifyCallWorkList; @@ -196,7 +190,12 @@ class ImplicitSystemValueLegalizationContext { auto systemValueName = convertSystemValueSemanticNameToEnum(implicitSysVal->getSystemValueName()); - SLANG_ASSERT(systemValueName != SystemValueSemanticName::Unknown); + + // Implicit system values are currently only being used for subgroup size and + // subgroup invocation id. + SLANG_ASSERT( + (systemValueName == SystemValueSemanticName::WaveLaneCount) || + (systemValueName == SystemValueSemanticName::WaveLaneIndex)); // If parameter for the specific function and system value type combination was already // created, return it directly.