Skip to content

Commit 249f48d

Browse files
author
Tim Foley
authored
CUDA/CPU varying compute inputs as IR pass (shader-slang#1438)
The main change here is that the CPU and CUDA C++ emit paths now rely on an earlier IR pass to legalize the varying parameter list of a kernel and translate references to varying parameters with semantics like `SV_DispatchThreadID`. Doing so removes a lot of special-case logic from the emit passes. This work moves us even closer to being able to eliminate `KernelContext` from the CPU/CUDA emit logic, because it removes the issue of state related to varying inputs being stored in `KernelContext`. The new pass that handles the legalization is in `slang-ir-legalize-varying-params.cpp`, and it borrows heavily from the existing `slang-ir-glsl-legalize.cpp` pass. The new pass factors out the target-independent and target-dependent logic, so that both CPU and CUDA can share much of the same code despite having very different rules for how the system-value parameters are being provided. An eventual goal is to have the new pass also handle the GLSL case, but doing so requires copying even more logic out of the GLSL-specific pass, and doing so seemed like a step to far for what was meant to be a stepping-stone change as part of other work. As a result of the incomplete nature of the pass, certain cases don't work for compute shader inputs for CPU/CUDA (e.g., wrapping your varying inputs in a `struct` type parameter), but those were cases that also didn't work in the existing `emit`-based logic. One major consequence of this change is that the logic for emitting the various different functions that represent an entry point for our CPU back-end has been streamlined and simplified. The original logic had a fair bit of cleverness built in to try and avoid unnecessary math ops when computing the various IDs/indices, while the new logic is much more simplistic (the main dispatch function loops over threadgroups with a triply-nested `for` and then delegates to the group-level function loops over threads with its own nested `for`s). Longer term, it will be important to simplify the CPU functions we emit further, by eliminating things like the `_Thread` function that should never really be exposed to users (the minimum granularity of invoking a CPU compute kernel should be a single threadgroup). We may eventually decide to synthesize all of the extra code that is being generated in the `emit` pass as IR instead.
1 parent 6aad38a commit 249f48d

14 files changed

+1539
-214
lines changed

source/core/core.vcxproj.filters

+3-3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@
9393
<ClInclude Include="slang-shared-library.h">
9494
<Filter>Header Files</Filter>
9595
</ClInclude>
96+
<ClInclude Include="slang-short-list.h">
97+
<Filter>Header Files</Filter>
98+
</ClInclude>
9699
<ClInclude Include="slang-smart-pointer.h">
97100
<Filter>Header Files</Filter>
98101
</ClInclude>
@@ -138,9 +141,6 @@
138141
<ClInclude Include="windows\slang-win-visual-studio-util.h">
139142
<Filter>Header Files</Filter>
140143
</ClInclude>
141-
<ClInclude Include="slang-short-list.h">
142-
<Filter>Header Files</Filter>
143-
</ClInclude>
144144
</ItemGroup>
145145
<ItemGroup>
146146
<ClCompile Include="slang-blob.cpp">

source/slang/slang-emit-cpp.cpp

+85-156
Original file line numberDiff line numberDiff line change
@@ -1955,38 +1955,48 @@ void CPPSourceEmitter::emitSimpleFuncImpl(IRFunc* func)
19551955
// Deal with decorations that need
19561956
// to be emitted as attributes
19571957

1958-
// We are going to ignore the parameters passed and just pass in the Context
19591958

1959+
// We start by emitting the result type and function name.
1960+
//
19601961
if (IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>())
19611962
{
1963+
// Note: we currently emit multiple functions to represent an entry point
1964+
// on CPU/CUDA, and these all bottleneck through the actual `IRFunc`
1965+
// here as a workhorse.
1966+
//
1967+
// Because the workhorse function is currently emitted as a member of
1968+
// `KernelContext`, and doesn't have the right signature to service
1969+
// general-purpose calls, it is being emitted with a `_` prefix.
1970+
//
19621971
StringBuilder prefixName;
19631972
prefixName << "_" << name;
19641973
emitType(resultType, prefixName);
1965-
m_writer->emit("()\n");
19661974
}
19671975
else
19681976
{
19691977
emitType(resultType, name);
1978+
}
19701979

1971-
m_writer->emit("(");
1972-
auto firstParam = func->getFirstParam();
1973-
for (auto pp = firstParam; pp; pp = pp->getNextParam())
1974-
{
1975-
// Ingore TypeType-typed parameters for now.
1976-
// In the future we will pass around runtime type info
1977-
// for TypeType parameters.
1978-
if (as<IRTypeType>(pp->getFullType()))
1979-
continue;
1980-
1981-
if (pp != firstParam)
1982-
m_writer->emit(", ");
1980+
// Next we emit the parameter list of the function.
1981+
//
1982+
m_writer->emit("(");
1983+
auto firstParam = func->getFirstParam();
1984+
for (auto pp = firstParam; pp; pp = pp->getNextParam())
1985+
{
1986+
// Ingore TypeType-typed parameters for now.
1987+
// In the future we will pass around runtime type info
1988+
// for TypeType parameters.
1989+
if (as<IRTypeType>(pp->getFullType()))
1990+
continue;
19831991

1984-
emitSimpleFuncParamImpl(pp);
1985-
}
1986-
m_writer->emit(")");
1992+
if (pp != firstParam)
1993+
m_writer->emit(", ");
19871994

1988-
emitSemantics(func);
1995+
emitSimpleFuncParamImpl(pp);
19891996
}
1997+
m_writer->emit(")");
1998+
1999+
emitSemantics(func);
19902000

19912001
// TODO: encode declaration vs. definition
19922002
if (isDefinition(func))
@@ -2431,40 +2441,6 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre
24312441

24322442
switch (inst->op)
24332443
{
2434-
case kIROp_Param:
2435-
{
2436-
auto varLayout = getVarLayout(inst);
2437-
2438-
if (varLayout)
2439-
{
2440-
if(auto systemValueSemantic = varLayout->findSystemValueSemanticAttr())
2441-
{
2442-
String semanticNameSpelling = systemValueSemantic->getName();
2443-
semanticNameSpelling = semanticNameSpelling.toLower();
2444-
2445-
if (semanticNameSpelling == "sv_dispatchthreadid")
2446-
{
2447-
m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID;
2448-
m_writer->emit("dispatchThreadID");
2449-
return;
2450-
}
2451-
else if (semanticNameSpelling == "sv_groupid")
2452-
{
2453-
m_semanticUsedFlags |= SemanticUsedFlag::GroupID;
2454-
m_writer->emit("groupID");
2455-
return;
2456-
}
2457-
else if (semanticNameSpelling == "sv_groupthreadid")
2458-
{
2459-
m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID;
2460-
m_writer->emit("calcGroupThreadID()");
2461-
return;
2462-
}
2463-
}
2464-
}
2465-
m_writer->emit(getName(inst));
2466-
break;
2467-
}
24682444
case kIROp_Var:
24692445
case kIROp_GlobalVar:
24702446
emitVarExpr(inst, outerPrec);
@@ -2591,19 +2567,19 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup
25912567
const auto& axis = axes[i];
25922568
builder.Clear();
25932569
const char elem[2] = { s_elemNames[axis.axis], 0 };
2594-
builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < start." << elem << " + " << axis.size << "; ++" << elem << ")\n{\n";
2570+
builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n";
25952571
m_writer->emit(builder);
25962572
m_writer->indent();
25972573

25982574
builder.Clear();
2599-
builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
2575+
builder << "threadInput.groupThreadID." << elem << " = " << elem << ";\n";
26002576
m_writer->emit(builder);
26012577
}
26022578

26032579
// just call at inner loop point
26042580
m_writer->emit("context._");
26052581
m_writer->emit(funcName);
2606-
m_writer->emit("();\n");
2582+
m_writer->emit("(&threadInput);\n");
26072583

26082584
// Close all the loops
26092585
for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
@@ -2626,57 +2602,20 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread
26262602
builder.Clear();
26272603
const char elem[2] = { s_elemNames[axis.axis], 0 };
26282604

2629-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2630-
{
2631-
builder << "context.groupDispatchThreadID." << elem << " = start." << elem << ";\n";
2632-
}
2633-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2634-
{
2635-
builder << "context.groupID." << elem << " += varyingInput->startGroupID." << elem << ";\n";
2636-
}
2637-
2638-
builder << "for (uint32_t " << elem << " = start." << elem << "; " << elem << " < end." << elem << "; ++" << elem << ")\n{\n";
2605+
builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n";
26392606
m_writer->emit(builder);
26402607
m_writer->indent();
26412608

2642-
builder.Clear();
2643-
builder << "context.dispatchThreadID." << elem << " = " << elem << ";\n";
2644-
2645-
if (m_semanticUsedFlags & (SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::GroupID))
2646-
{
2647-
if (sizeAlongAxis[axis.axis] > 1)
2648-
{
2649-
builder << "const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size <<";\n";
2650-
2651-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2652-
{
2653-
builder << "context.groupID." << elem << " += uint32_t(next == " << elem << ");\n";
2654-
}
2655-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2656-
{
2657-
builder << "context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << ";\n";
2658-
}
2659-
}
2660-
else
2661-
{
2662-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2663-
{
2664-
builder << "context.groupDispatchThreadID." << elem << " = " << elem << ";\n";
2665-
}
2666-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2667-
{
2668-
builder << "context.groupID." << elem << " = " << elem << ";\n";
2669-
}
2670-
}
2671-
}
2672-
2673-
m_writer->emit(builder);
2609+
m_writer->emit("groupVaryingInput.startGroupID.");
2610+
m_writer->emit(elem);
2611+
m_writer->emit(" = ");
2612+
m_writer->emit(elem);
2613+
m_writer->emit(";\n");
26742614
}
26752615

26762616
// just call at inner loop point
2677-
m_writer->emit("context._");
26782617
m_writer->emit(funcName);
2679-
m_writer->emit("();\n");
2618+
m_writer->emit("_Group(&groupVaryingInput, entryPointParams, globalParams);\n");
26802619

26812620
// Close all the loops
26822621
for (Index i = Index(axes.getCount() - 1); i >= 0; --i)
@@ -2736,6 +2675,34 @@ void CPPSourceEmitter::_emitForwardDeclarations(const List<EmitAction>& actions)
27362675
}
27372676
}
27382677

2678+
static bool isVaryingResourceKind(LayoutResourceKind kind)
2679+
{
2680+
switch(kind)
2681+
{
2682+
default:
2683+
return false;
2684+
2685+
case LayoutResourceKind::VaryingInput:
2686+
case LayoutResourceKind::VaryingOutput:
2687+
return true;
2688+
}
2689+
}
2690+
2691+
static bool isVaryingParameter(IRTypeLayout* typeLayout)
2692+
{
2693+
for(auto sizeAttr : typeLayout->getSizeAttrs())
2694+
{
2695+
if(!isVaryingResourceKind(sizeAttr->getResourceKind()))
2696+
return false;
2697+
}
2698+
return true;
2699+
}
2700+
2701+
static bool isVaryingParameter(IRVarLayout* varLayout)
2702+
{
2703+
return isVaryingParameter(varLayout->getTypeLayout());
2704+
}
2705+
27392706
void CPPSourceEmitter::_findShaderParams(
27402707
IRGlobalParam** outEntryPointParam,
27412708
IRGlobalParam** outGlobalParam)
@@ -2752,6 +2719,20 @@ void CPPSourceEmitter::_findShaderParams(
27522719
if(!param)
27532720
continue;
27542721

2722+
if(auto layoutDecor = param->findDecoration<IRLayoutDecoration>())
2723+
{
2724+
if(auto varLayout = as<IRVarLayout>(layoutDecor->getLayout()))
2725+
{
2726+
if(isVaryingParameter(varLayout))
2727+
continue;
2728+
auto typeLayout = varLayout->getTypeLayout();
2729+
if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingInput))
2730+
continue;
2731+
if(typeLayout->findSizeAttr(LayoutResourceKind::VaryingOutput))
2732+
continue;
2733+
}
2734+
}
2735+
27552736
// Currently, the entry-point parameters
27562737
// are represented as a single parameter
27572738
// at the global scope, and the same is
@@ -2806,28 +2787,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
28062787
m_writer->emit("struct KernelContext\n{\n");
28072788
m_writer->indent();
28082789

2809-
m_writer->emit("uint3 dispatchThreadID;\n");
2810-
2811-
//if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2812-
{
2813-
// Note not always set!
2814-
m_writer->emit("uint3 groupID;\n");
2815-
}
2816-
2817-
//if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2818-
{
2819-
m_writer->emit("uint3 groupDispatchThreadID;\n");
2820-
2821-
m_writer->emit("uint3 calcGroupThreadID() const \n{\n");
2822-
m_writer->indent();
2823-
// groupThreadID = dispatchThreadID - groupDispatchThreadID
2824-
m_writer->emit("uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z };\n");
2825-
m_writer->emit("return v;\n");
2826-
m_writer->dedent();
2827-
m_writer->emit("}\n");
2828-
}
2829-
2830-
28312790
if (globalParams)
28322791
{
28332792
emitGlobalInst(globalParams);
@@ -2886,9 +2845,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
28862845

28872846
if (entryPointDecor && entryPointDecor->getProfile().getStage() == Stage::Compute)
28882847
{
2889-
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
2890-
// SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
2891-
28922848
Int groupThreadSize[kThreadGroupAxisCount];
28932849
getComputeThreadGroupSize(func, groupThreadSize);
28942850

@@ -2902,23 +2858,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
29022858

29032859
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, threadFuncName, UnownedStringSlice::fromLiteral("ComputeThreadVaryingInput"));
29042860

2905-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2906-
{
2907-
m_writer->emit("context.groupDispatchThreadID = ");
2908-
_emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
2909-
}
2910-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2911-
{
2912-
m_writer->emit("context.groupID = varyingInput->groupID;\n");
2913-
}
2914-
2915-
// Emit dispatchThreadID
2916-
m_writer->emit("context.dispatchThreadID = ");
2917-
_emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
2918-
29192861
m_writer->emit("context._");
29202862
m_writer->emit(funcName);
2921-
m_writer->emit("();\n");
2863+
m_writer->emit("(varyingInput);\n");
29222864

29232865
_emitEntryPointDefinitionEnd(func);
29242866
}
@@ -2933,19 +2875,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
29332875

29342876
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
29352877

2936-
m_writer->emit("const uint3 start = ");
2937-
_emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
2938-
2939-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2940-
{
2941-
m_writer->emit("context.groupDispatchThreadID = start;\n");
2942-
}
2943-
2944-
if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2945-
{
2946-
m_writer->emit("context.groupID = varyingInput->startGroupID;\n");
2947-
}
2948-
m_writer->emit("context.dispatchThreadID = start;\n");
2878+
m_writer->emit("ComputeThreadVaryingInput threadInput = {};\n");
2879+
m_writer->emit("threadInput.groupID = varyingInput->startGroupID;\n");
29492880

29502881
_emitEntryPointGroup(groupThreadSize, funcName);
29512882
_emitEntryPointDefinitionEnd(func);
@@ -2955,10 +2886,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
29552886
{
29562887
_emitEntryPointDefinitionStart(func, entryPointParams, globalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
29572888

2958-
m_writer->emit("const uint3 start = ");
2959-
_emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
2960-
m_writer->emit("const uint3 end = ");
2961-
_emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
2889+
m_writer->emit("ComputeVaryingInput vi = *varyingInput;\n");
2890+
m_writer->emit("ComputeVaryingInput groupVaryingInput = {};\n");
29622891

29632892
_emitEntryPointGroupRange(groupThreadSize, funcName);
29642893
_emitEntryPointDefinitionEnd(func);

0 commit comments

Comments
 (0)