Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WGSL support for WaveGetLaneIndex and WaveGetLaneCount #6357

Closed
2 changes: 2 additions & 0 deletions source/slang/glsl.meta.slang
Original file line number Diff line number Diff line change
@@ -6409,6 +6409,7 @@ public property uint gl_SubgroupID

public property uint gl_SubgroupSize
{
[ForceInline]
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
@@ -6418,6 +6419,7 @@ public property uint gl_SubgroupSize

public property uint gl_SubgroupInvocationID
{
[ForceInline]
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
18 changes: 16 additions & 2 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
@@ -6,6 +6,12 @@ typedef uint UINT;
__intrinsic_op($(kIROp_RequireGLSLExtension))
void __requireGLSLExtension(String extensionName);

// Internal intrinsic bound to the `kIROp_RequireWGSLExtension` IR op.
// `extensionName` is the name of the WGSL extension the calling code depends on
// (for example "subgroups").
__intrinsic_op($(kIROp_RequireWGSLExtension))
void __requireWGSLExtension(String extensionName);

// Internal intrinsic bound to the `kIROp_ImplicitSystemValue` IR op.
// `systemValueName` is a system-value semantic name such as "SV_WaveLaneCount".
// NOTE(review): presumably yields that system value without the caller declaring
// an explicit entry-point parameter for it -- confirm against the lowering pass.
__intrinsic_op($(kIROp_ImplicitSystemValue))
uint __implicitSystemValue(String systemValueName);

//@public:
/// Represents an interface for buffer data layout.
/// This interface is used as a base for defining specific data layouts for buffers.
@@ -15037,7 +15043,8 @@ uint WaveActiveCountBits(bool value)
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
[NonUniformReturn]
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
[ForceInline]
[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
uint WaveGetLaneCount()
{
__target_switch
@@ -15051,14 +15058,18 @@ uint WaveGetLaneCount()
OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupSize:uint)
};
case wgsl:
__requireWGSLExtension("subgroups");
return __implicitSystemValue("SV_WaveLaneCount");
}
}

/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
[NonUniformReturn]
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
[ForceInline]
[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
uint WaveGetLaneIndex()
{
__target_switch
@@ -15072,6 +15083,9 @@ uint WaveGetLaneIndex()
OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint)
};
case wgsl:
__requireWGSLExtension("subgroups");
return __implicitSystemValue("SV_WaveLaneIndex");
}
}

5 changes: 5 additions & 0 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
@@ -3055,6 +3055,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
emitOperand(as<IRGlobalValueRef>(inst)->getOperand(0), getInfo(EmitOp::General));
break;
}
case kIROp_RequireWGSLExtension:
{
emitRequireExtension(inst);
break;
}
default:
diagnoseUnhandledInst(inst);
break;
2 changes: 2 additions & 0 deletions source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
@@ -678,6 +678,8 @@ class CLikeSourceEmitter : public SourceEmitterBase
void _emitCallArgList(IRCall* call, int startingOperandIndex = 1);
virtual void emitCallArg(IRInst* arg);

/// Hook invoked for "require extension" instructions; the default implementation
/// ignores `inst` and does nothing, so targets that need it must override.
virtual void emitRequireExtension(IRInst* inst) { SLANG_UNUSED(inst); }

String _generateUniqueName(const UnownedStringSlice& slice);

// Sort witnessTable entries according to the order defined in the witnessed interface type.
4 changes: 2 additions & 2 deletions source/slang/slang-emit-glsl.cpp
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction(
//
// Handle cases where "require" IR operations exist in the function body and are required
// as entry point decorations.
auto entryPoints = getReferencingEntryPoints(m_referencingEntryPoints, parentFunc);
auto entryPoints = m_callGraph.getReferencingEntryPoints(parentFunc);
if (entryPoints == nullptr)
return;

@@ -81,7 +81,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction(

void GLSLSourceEmitter::beforeComputeEmitActions(IRModule* module)
{
buildEntryPointReferenceGraph(this->m_referencingEntryPoints, module);
m_callGraph.build(module);

IRBuilder builder(module);
for (auto globalInst : module->getGlobalInsts())
3 changes: 2 additions & 1 deletion source/slang/slang-emit-glsl.h
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

#include "slang-emit-c-like.h"
#include "slang-extension-tracker.h"
#include "slang-ir-call-graph.h"

namespace Slang
{
@@ -178,7 +179,7 @@ class GLSLSourceEmitter : public CLikeSourceEmitter

void _beforeComputeEmitProcessInstruction(IRInst* parentFunc, IRInst* inst, IRBuilder& builder);

Dictionary<IRInst*, HashSet<IRFunc*>> m_referencingEntryPoints;
CallGraph m_callGraph;

RefPtr<ShaderExtensionTracker> m_glslExtensionTracker;
};
20 changes: 11 additions & 9 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
@@ -3553,8 +3553,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto parentFunc = getParentFunc(inst);

HashSet<IRFunc*>* entryPointsUsingInst =
getReferencingEntryPoints(m_referencingEntryPoints, parentFunc);
auto entryPointsUsingInst = m_callGraph.getReferencingEntryPoints(parentFunc);
for (IRFunc* entryPoint : *entryPointsUsingInst)
{
bool isQuad = true;
@@ -3608,8 +3607,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}

case kIROp_RequireMaximallyReconverges:
if (auto entryPointsUsingInst =
getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst)))
if (
// auto entryPointsUsingInst =
// getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst))
auto entryPointsUsingInst =
m_callGraph.getReferencingEntryPoints(getParentFunc(inst)))
{
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_maximal_reconvergence"));
for (IRFunc* entryPoint : *entryPointsUsingInst)
@@ -3623,7 +3625,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
break;
case kIROp_RequireQuadDerivatives:
if (auto entryPointsUsingInst =
getReferencingEntryPoints(m_referencingEntryPoints, getParentFunc(inst)))
m_callGraph.getReferencingEntryPoints(getParentFunc(inst)))
{
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_quad_control"));
requireSPIRVCapability(SpvCapabilityQuadControlKHR);
@@ -4326,7 +4328,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
if (m_mapIRInstToSpvInst.tryGetValue(globalInst, spvGlobalInst))
{
// Is this globalInst referenced by this entry point?
auto refSet = m_referencingEntryPoints.tryGetValue(globalInst);
auto refSet = m_callGraph.getReferencingEntryPoints(globalInst);
if (refSet && refSet->contains(entryPoint))
{
if (!isSpirv14OrLater())
@@ -5129,7 +5131,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex

bool isInstUsedInStage(IRInst* inst, Stage s)
{
auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst);
auto referencingEntryPoints = m_callGraph.getReferencingEntryPoints(inst);
if (!referencingEntryPoints)
return false;
for (auto entryPoint : *referencingEntryPoints)
@@ -5329,7 +5331,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
else if (semanticName == "sv_primitiveid")
{
auto entryPoints = m_referencingEntryPoints.tryGetValue(inst);
auto entryPoints = m_callGraph.getReferencingEntryPoints(inst);
// SPIRV requires `Geometry` capability being declared for a fragment
// shader, if that shader uses sv_primitiveid.
// We will check if this builtin is used by non-ray-tracing, non-geometry or
@@ -8029,7 +8031,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case SpvOpExecutionMode:
{
if (auto refEntryPointSet =
m_referencingEntryPoints.tryGetValue(getParentFunc(inst)))
m_callGraph.getReferencingEntryPoints(getParentFunc(inst)))
{
for (auto entryPoint : *refEntryPointSet)
{
5 changes: 5 additions & 0 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
@@ -1696,4 +1696,9 @@ void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst)
}
}

/// Forwards the extension name carried by a WGSL "require extension" instruction
/// to `_requireExtension`.
///
/// `inst` is expected to be an `IRRequireWGSLExtension`; the cast result is
/// dereferenced without a null check.
void WGSLSourceEmitter::emitRequireExtension(IRInst* inst)
{
    auto requireInst = as<IRRequireWGSLExtension>(inst);
    _requireExtension(requireInst->getExtensionName());
}

} // namespace Slang
2 changes: 2 additions & 0 deletions source/slang/slang-emit-wgsl.h
Original file line number Diff line number Diff line change
@@ -65,6 +65,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter

virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }

virtual void emitRequireExtension(IRInst* inst) SLANG_OVERRIDE;

private:
bool maybeEmitSystemSemantic(IRInst* inst);

120 changes: 86 additions & 34 deletions source/slang/slang-ir-call-graph.cpp
Original file line number Diff line number Diff line change
@@ -6,14 +6,45 @@
namespace Slang
{

void buildEntryPointReferenceGraph(
Dictionary<IRInst*, HashSet<IRFunc*>>& referencingEntryPoints,
IRModule* module)
/// Constructs a call graph and immediately populates it by running `build`
/// on `module`.
CallGraph::CallGraph(IRModule* module)
{
build(module);
}

/// Inserts `value` into the set stored under `key` in `map`, creating that set
/// first when `key` has no entry yet.
template<typename T, typename U>
static void addToReferenceMap(Dictionary<T, HashSet<U>>& map, T key, U value)
{
    auto existing = map.tryGetValue(key);
    if (!existing)
    {
        // First reference seen for this key: start a fresh one-element set.
        HashSet<U> fresh;
        fresh.add(value);
        map.add(key, _Move(fresh));
        return;
    }

    existing->add(value);
}


/// Records, for `inst`, both the function it was reached through and the entry
/// point whose traversal reached it.
///
/// The two maps are independent, so the insertion order does not matter.
/// Whatever `parentFunc` the caller supplies is stored as-is, including a null
/// pointer.
void CallGraph::registerInstructionReference(IRInst* inst, IRFunc* entryPoint, IRFunc* parentFunc)
{
    // inst -> functions it was reached through
    addToReferenceMap(m_referencingFunctions, inst, parentFunc);
    // inst -> entry points that reach it
    addToReferenceMap(m_referencingEntryPoints, inst, entryPoint);
}

/// Records `call` in the set of call instructions associated with `func`.
///
/// NOTE(review): `func` is stored as the map key without a null check; if a
/// caller derives it from a cast that can fail, a null key would be inserted --
/// confirm whether that is intended.
void CallGraph::registerCallReference(IRFunc* func, IRCall* call)
{
addToReferenceMap(m_referencingCalls, func, call);
}

void CallGraph::build(IRModule* module)
{
struct WorkItem
{
IRFunc* entryPoint;
IRInst* inst;
IRFunc* entryPoint;
IRFunc* parentFunc;

HashCode getHashCode() const
{
@@ -32,51 +63,52 @@ void buildEntryPointReferenceGraph(
workList.add(item);
};

auto registerEntryPointReference = [&](IRFunc* entryPoint, IRInst* inst)
{
if (auto set = referencingEntryPoints.tryGetValue(inst))
set->add(entryPoint);
else
{
HashSet<IRFunc*> newSet;
newSet.add(entryPoint);
referencingEntryPoints.add(inst, _Move(newSet));
}
};
auto visit = [&](IRFunc* entryPoint, IRInst* inst)
auto visit = [&](IRInst* inst, IRFunc* entryPoint, IRFunc* parentFunc)
{
if (auto code = as<IRGlobalValueWithCode>(inst))
{
registerEntryPointReference(entryPoint, inst);
registerInstructionReference(inst, entryPoint, parentFunc);

if (auto func = as<IRFunc>(code))
{
parentFunc = func;
}

for (auto child : code->getChildren())
{
addToWorkList({entryPoint, child});
addToWorkList({child, entryPoint, parentFunc});
}
return;
}

switch (inst->getOp())
{
// Only these instruction types and `IRGlobalValueWithCode` instructions are registered to
// the reference graph.
case kIROp_Call:
{
auto call = as<IRCall>(inst);
registerCallReference(as<IRFunc>(call->getCallee()), call);
addToWorkList({call->getCallee(), entryPoint, parentFunc});
}
[[fallthrough]];
case kIROp_GlobalParam:
case kIROp_SPIRVAsmOperandBuiltinVar:
registerEntryPointReference(entryPoint, inst);
case kIROp_ImplicitSystemValue:
registerInstructionReference(inst, entryPoint, parentFunc);
break;

case kIROp_Block:
case kIROp_SPIRVAsm:
for (auto child : inst->getChildren())
{
addToWorkList({entryPoint, child});
}
break;
case kIROp_Call:
{
auto call = as<IRCall>(inst);
addToWorkList({entryPoint, call->getCallee()});
addToWorkList({child, entryPoint, parentFunc});
}
break;
case kIROp_SPIRVAsmOperandInst:
{
auto operand = as<IRSPIRVAsmOperandInst>(inst);
addToWorkList({entryPoint, operand->getValue()});
addToWorkList({operand->getValue(), entryPoint, parentFunc});
}
break;
}
@@ -88,7 +120,7 @@ void buildEntryPointReferenceGraph(
case kIROp_GlobalParam:
case kIROp_GlobalVar:
case kIROp_SPIRVAsmOperandBuiltinVar:
addToWorkList({entryPoint, operand});
addToWorkList({operand, entryPoint, parentFunc});
break;
}
}
@@ -99,21 +131,41 @@ void buildEntryPointReferenceGraph(
if (globalInst->getOp() == kIROp_Func &&
globalInst->findDecoration<IREntryPointDecoration>())
{
visit(as<IRFunc>(globalInst), globalInst);
auto entryPointFunc = as<IRFunc>(globalInst);
visit(globalInst, entryPointFunc, nullptr);
}
}
for (Index i = 0; i < workList.getCount(); i++)
visit(workList[i].entryPoint, workList[i].inst);
visit(workList[i].inst, workList[i].entryPoint, workList[i].parentFunc);
}

HashSet<IRFunc*>* getReferencingEntryPoints(
Dictionary<IRInst*, HashSet<IRFunc*>>& m_referencingEntryPoints,
IRInst* inst)
const HashSet<IRFunc*>* CallGraph::getReferencingEntryPoints(IRInst* inst) const
{
auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst);
const auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst);
if (!referencingEntryPoints)
return nullptr;
return referencingEntryPoints;
}

/// Looks up the set of functions recorded for `inst`.
///
/// Returns a pointer to the stored set, or a null pointer when `inst` has no
/// entry in the map.
const HashSet<IRFunc*>* CallGraph::getReferencingFunctions(IRInst* inst) const
{
    // A failed lookup already produces a null pointer, so the result can be
    // handed back unchanged.
    return m_referencingFunctions.tryGetValue(inst);
}

/// Looks up the set of call instructions recorded for `func`.
///
/// Returns a pointer to the stored set, or a null pointer when `func` has no
/// entry in the map.
const HashSet<IRCall*>* CallGraph::getReferencingCalls(IRFunc* func) const
{
    // A failed lookup already produces a null pointer, so the result can be
    // handed back unchanged.
    return m_referencingCalls.tryGetValue(func);
}

/// Gives read-only access to the whole instruction -> entry-point-set map.
/// The returned reference refers to a member of this `CallGraph`.
const Dictionary<IRInst*, HashSet<IRFunc*>>& CallGraph::getReferencingEntryPointsMap() const
{
return m_referencingEntryPoints;
}

} // namespace Slang
Loading