Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WGSL support for WaveGetLaneIndex and WaveGetLaneCount #6357

Closed
2 changes: 2 additions & 0 deletions source/slang/glsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -6409,6 +6409,7 @@ public property uint gl_SubgroupID

public property uint gl_SubgroupSize
{
[ForceInline]
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
Expand All @@ -6418,6 +6419,7 @@ public property uint gl_SubgroupSize

public property uint gl_SubgroupInvocationID
{
[ForceInline]
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
Expand Down
18 changes: 16 additions & 2 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ typedef uint UINT;
__intrinsic_op($(kIROp_RequireGLSLExtension))
void __requireGLSLExtension(String extensionName);

__intrinsic_op($(kIROp_RequireWGSLExtension))
void __requireWGSLExtension(String extensionName);

__intrinsic_op($(kIROp_ImplicitSystemValue))
uint __implicitSystemValue(String systemValueName);

//@public:
/// Represents an interface for buffer data layout.
/// This interface is used as a base for defining specific data layouts for buffers.
Expand Down Expand Up @@ -15037,7 +15043,8 @@ uint WaveActiveCountBits(bool value)
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
[NonUniformReturn]
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
[ForceInline]
[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
uint WaveGetLaneCount()
{
__target_switch
Expand All @@ -15051,14 +15058,18 @@ uint WaveGetLaneCount()
OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupSize:uint)
};
case wgsl:
__requireWGSLExtension("subgroups");
return __implicitSystemValue("SV_WaveLaneCount");
}
}

/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
[NonUniformReturn]
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
[ForceInline]
[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
uint WaveGetLaneIndex()
{
__target_switch
Expand All @@ -15072,6 +15083,9 @@ uint WaveGetLaneIndex()
OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint)
};
case wgsl:
__requireWGSLExtension("subgroups");
return __implicitSystemValue("SV_WaveLaneIndex");
}
}

Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3055,6 +3055,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
emitOperand(as<IRGlobalValueRef>(inst)->getOperand(0), getInfo(EmitOp::General));
break;
}
case kIROp_RequireWGSLExtension:
{
emitRequireExtension(inst);
break;
}
default:
diagnoseUnhandledInst(inst);
break;
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,8 @@ class CLikeSourceEmitter : public SourceEmitterBase
void _emitCallArgList(IRCall* call, int startingOperandIndex = 1);
virtual void emitCallArg(IRInst* arg);

// Hook for emitting a "require extension" instruction. This base
// implementation deliberately ignores `inst` and emits nothing; it is
// virtual so that a target-specific emitter can override it.
virtual void emitRequireExtension(IRInst* inst) { SLANG_UNUSED(inst); }

String _generateUniqueName(const UnownedStringSlice& slice);

// Sort witnessTable entries according to the order defined in the witnessed interface type.
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-emit-glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction(
//
// Handle cases where "require" IR operations exist in the function body and are required
// as entry point decorations.
auto entryPoints = getReferencingEntryPoints(m_referencingEntryPoints, parentFunc);
auto entryPoints = m_callGraph.getReferencingEntryPoints(parentFunc);
if (entryPoints == nullptr)
return;

Expand Down Expand Up @@ -81,7 +81,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction(

void GLSLSourceEmitter::beforeComputeEmitActions(IRModule* module)
{
buildEntryPointReferenceGraph(this->m_referencingEntryPoints, module);
m_callGraph.build(module);

IRBuilder builder(module);
for (auto globalInst : module->getGlobalInsts())
Expand Down
3 changes: 2 additions & 1 deletion source/slang/slang-emit-glsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "slang-emit-c-like.h"
#include "slang-extension-tracker.h"
#include "slang-ir-call-graph.h"

namespace Slang
{
Expand Down Expand Up @@ -178,7 +179,7 @@ class GLSLSourceEmitter : public CLikeSourceEmitter

void _beforeComputeEmitProcessInstruction(IRInst* parentFunc, IRInst* inst, IRBuilder& builder);

Dictionary<IRInst*, HashSet<IRFunc*>> m_referencingEntryPoints;
CallGraph m_callGraph;

RefPtr<ShaderExtensionTracker> m_glslExtensionTracker;
};
Expand Down
20 changes: 11 additions & 9 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3553,8 +3553,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto parentFunc = getParentFunc(inst);

HashSet<IRFunc*>* entryPointsUsingInst =
getReferencingEntryPoints(m_referencingEntryPoints, parentFunc);
auto entryPointsUsingInst = m_callGraph.getReferencingEntryPoints(parentFunc);
for (IRFunc* entryPoint : *entryPointsUsingInst)
{
bool isQuad = true;
Expand Down Expand Up @@ -3608,8 +3607,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}

case kIROp_RequireMaximallyReconverges:
if (auto entryPointsUsingInst =
getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst)))
if (
// auto entryPointsUsingInst =
// getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst))
auto entryPointsUsingInst =
m_callGraph.getReferencingEntryPoints(getParentFunc(inst)))
{
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_maximal_reconvergence"));
for (IRFunc* entryPoint : *entryPointsUsingInst)
Expand All @@ -3623,7 +3625,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
break;
case kIROp_RequireQuadDerivatives:
if (auto entryPointsUsingInst =
getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst)))
m_callGraph.getReferencingEntryPoints(getParentFunc(inst)))
{
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_quad_control"));
requireSPIRVCapability(SpvCapabilityQuadControlKHR);
Expand Down Expand Up @@ -4326,7 +4328,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
if (m_mapIRInstToSpvInst.tryGetValue(globalInst, spvGlobalInst))
{
// Is this globalInst referenced by this entry point?
auto refSet = m_referencingEntryPoints.tryGetValue(globalInst);
auto refSet = m_callGraph.getReferencingEntryPoints(globalInst);
if (refSet && refSet->contains(entryPoint))
{
if (!isSpirv14OrLater())
Expand Down Expand Up @@ -5129,7 +5131,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex

bool isInstUsedInStage(IRInst* inst, Stage s)
{
auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst);
auto referencingEntryPoints = m_callGraph.getReferencingEntryPoints(inst);
if (!referencingEntryPoints)
return false;
for (auto entryPoint : *referencingEntryPoints)
Expand Down Expand Up @@ -5329,7 +5331,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
else if (semanticName == "sv_primitiveid")
{
auto entryPoints = m_referencingEntryPoints.tryGetValue(inst);
auto entryPoints = m_callGraph.getReferencingEntryPoints(inst);
// SPIRV requires `Geometry` capability being declared for a fragment
// shader, if that shader uses sv_primitiveid.
// We will check if this builtin is used by non-ray-tracing, non-geometry or
Expand Down Expand Up @@ -8029,7 +8031,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case SpvOpExecutionMode:
{
if (auto refEntryPointSet =
m_referencingEntryPoints.tryGetValue(getParentFunc(inst)))
m_callGraph.getReferencingEntryPoints(getParentFunc(inst)))
{
for (auto entryPoint : *refEntryPointSet)
{
Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1696,4 +1696,9 @@ void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst)
}
}

// Handles a WGSL "require extension" instruction: views `inst` as an
// IRRequireWGSLExtension and forwards its extension name to
// _requireExtension().
void WGSLSourceEmitter::emitRequireExtension(IRInst* inst)
{
    auto requireInst = as<IRRequireWGSLExtension>(inst);
    _requireExtension(requireInst->getExtensionName());
}

} // namespace Slang
2 changes: 2 additions & 0 deletions source/slang/slang-emit-wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter

// Returns this emitter's extension tracker (m_extensionTracker) as a RefObject.
virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }

virtual void emitRequireExtension(IRInst* inst) SLANG_OVERRIDE;

private:
bool maybeEmitSystemSemantic(IRInst* inst);

Expand Down
120 changes: 86 additions & 34 deletions source/slang/slang-ir-call-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,45 @@
namespace Slang
{

void buildEntryPointReferenceGraph(
Dictionary<IRInst*, HashSet<IRFunc*>>& referencingEntryPoints,
IRModule* module)
// Constructs a call graph and immediately populates it from `module`
// by delegating to build().
CallGraph::CallGraph(IRModule* module)
{
build(module);
}

// Inserts `value` into the set stored under `key` in `map`.
// If `key` has no set yet, a new single-element set is created and added.
template<typename T, typename U>
static void addToReferenceMap(Dictionary<T, HashSet<U>>& map, T key, U value)
{
    // Fast path: the key already has a set, so extend it in place.
    if (auto existing = map.tryGetValue(key))
    {
        existing->add(value);
        return;
    }

    // First reference for this key: build the set, then move it into the map.
    HashSet<U> fresh;
    fresh.add(value);
    map.add(key, _Move(fresh));
}


// Records two facts about `inst`:
//  - `entryPoint` is added to the set kept for `inst` in m_referencingEntryPoints;
//  - `parentFunc` is added to the set kept for `inst` in m_referencingFunctions.
// NOTE(review): `parentFunc` is stored as given, with no null check here —
// confirm whether callers can pass a null parent and whether readers of
// m_referencingFunctions tolerate a null entry.
void CallGraph::registerInstructionReference(IRInst* inst, IRFunc* entryPoint, IRFunc* parentFunc)
{
addToReferenceMap(m_referencingEntryPoints, inst, entryPoint);
addToReferenceMap(m_referencingFunctions, inst, parentFunc);
}

// Records `call` in the set kept for `func` in m_referencingCalls.
// NOTE(review): `func` is used as the map key without a null check —
// confirm callers never pass a null function (e.g. from a failed cast
// of a call's callee).
void CallGraph::registerCallReference(IRFunc* func, IRCall* call)
{
addToReferenceMap(m_referencingCalls, func, call);
}

void CallGraph::build(IRModule* module)
{
struct WorkItem
{
IRFunc* entryPoint;
IRInst* inst;
IRFunc* entryPoint;
IRFunc* parentFunc;

HashCode getHashCode() const
{
Expand All @@ -32,51 +63,52 @@ void buildEntryPointReferenceGraph(
workList.add(item);
};

auto registerEntryPointReference = [&](IRFunc* entryPoint, IRInst* inst)
{
if (auto set = referencingEntryPoints.tryGetValue(inst))
set->add(entryPoint);
else
{
HashSet<IRFunc*> newSet;
newSet.add(entryPoint);
referencingEntryPoints.add(inst, _Move(newSet));
}
};
auto visit = [&](IRFunc* entryPoint, IRInst* inst)
auto visit = [&](IRInst* inst, IRFunc* entryPoint, IRFunc* parentFunc)
{
if (auto code = as<IRGlobalValueWithCode>(inst))
{
registerEntryPointReference(entryPoint, inst);
registerInstructionReference(inst, entryPoint, parentFunc);

if (auto func = as<IRFunc>(code))
{
parentFunc = func;
}

for (auto child : code->getChildren())
{
addToWorkList({entryPoint, child});
addToWorkList({child, entryPoint, parentFunc});
}
return;
}

switch (inst->getOp())
{
// Only these instruction types and `IRGlobalValueWithCode` instructions are registered to
// the reference graph.
case kIROp_Call:
{
auto call = as<IRCall>(inst);
registerCallReference(as<IRFunc>(call->getCallee()), call);
addToWorkList({call->getCallee(), entryPoint, parentFunc});
}
[[fallthrough]];
case kIROp_GlobalParam:
case kIROp_SPIRVAsmOperandBuiltinVar:
registerEntryPointReference(entryPoint, inst);
case kIROp_ImplicitSystemValue:
registerInstructionReference(inst, entryPoint, parentFunc);
break;

case kIROp_Block:
case kIROp_SPIRVAsm:
for (auto child : inst->getChildren())
{
addToWorkList({entryPoint, child});
}
break;
case kIROp_Call:
{
auto call = as<IRCall>(inst);
addToWorkList({entryPoint, call->getCallee()});
addToWorkList({child, entryPoint, parentFunc});
}
break;
case kIROp_SPIRVAsmOperandInst:
{
auto operand = as<IRSPIRVAsmOperandInst>(inst);
addToWorkList({entryPoint, operand->getValue()});
addToWorkList({operand->getValue(), entryPoint, parentFunc});
}
break;
}
Expand All @@ -88,7 +120,7 @@ void buildEntryPointReferenceGraph(
case kIROp_GlobalParam:
case kIROp_GlobalVar:
case kIROp_SPIRVAsmOperandBuiltinVar:
addToWorkList({entryPoint, operand});
addToWorkList({operand, entryPoint, parentFunc});
break;
}
}
Expand All @@ -99,21 +131,41 @@ void buildEntryPointReferenceGraph(
if (globalInst->getOp() == kIROp_Func &&
globalInst->findDecoration<IREntryPointDecoration>())
{
visit(as<IRFunc>(globalInst), globalInst);
auto entryPointFunc = as<IRFunc>(globalInst);
visit(globalInst, entryPointFunc, nullptr);
}
}
for (Index i = 0; i < workList.getCount(); i++)
visit(workList[i].entryPoint, workList[i].inst);
visit(workList[i].inst, workList[i].entryPoint, workList[i].parentFunc);
}

HashSet<IRFunc*>* getReferencingEntryPoints(
Dictionary<IRInst*, HashSet<IRFunc*>>& m_referencingEntryPoints,
IRInst* inst)
const HashSet<IRFunc*>* CallGraph::getReferencingEntryPoints(IRInst* inst) const
{
auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst);
const auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst);
if (!referencingEntryPoints)
return nullptr;
return referencingEntryPoints;
}

// Looks up the set of functions recorded for `inst` in m_referencingFunctions.
// Returns a pointer to that set, or a null pointer when the lookup yields none.
const HashSet<IRFunc*>* CallGraph::getReferencingFunctions(IRInst* inst) const
{
    // A null lookup result is passed through unchanged, so no separate
    // null branch is needed.
    return m_referencingFunctions.tryGetValue(inst);
}

// Looks up the set of call instructions recorded for `func` in m_referencingCalls.
// Returns a pointer to that set, or a null pointer when the lookup yields none.
const HashSet<IRCall*>* CallGraph::getReferencingCalls(IRFunc* func) const
{
    // A null lookup result is passed through unchanged, so no separate
    // null branch is needed.
    return m_referencingCalls.tryGetValue(func);
}

// Read-only access to the whole m_referencingEntryPoints dictionary.
// The returned reference refers to a member of this CallGraph.
const Dictionary<IRInst*, HashSet<IRFunc*>>& CallGraph::getReferencingEntryPointsMap() const
{
return m_referencingEntryPoints;
}

} // namespace Slang
Loading