@@ -1955,38 +1955,48 @@ void CPPSourceEmitter::emitSimpleFuncImpl(IRFunc* func)
1955
1955
// Deal with decorations that need
1956
1956
// to be emitted as attributes
1957
1957
1958
- // We are going to ignore the parameters passed and just pass in the Context
1959
1958
1959
+ // We start by emitting the result type and function name.
1960
+ //
1960
1961
if (IREntryPointDecoration* entryPointDecor = func->findDecoration <IREntryPointDecoration>())
1961
1962
{
1963
+ // Note: we currently emit multiple functions to represent an entry point
1964
+ // on CPU/CUDA, and these all bottleneck through the actual `IRFunc`
1965
+ // here as a workhorse.
1966
+ //
1967
+ // Because the workhorse function is currently emitted as a member of
1968
+ // `KernelContext`, and doesn't have the right signature to service
1969
+ // general-purpose calls, it is being emitted with a `_` prefix.
1970
+ //
1962
1971
StringBuilder prefixName;
1963
1972
prefixName << " _" << name;
1964
1973
emitType (resultType, prefixName);
1965
- m_writer->emit (" ()\n " );
1966
1974
}
1967
1975
else
1968
1976
{
1969
1977
emitType (resultType, name);
1978
+ }
1970
1979
1971
- m_writer->emit (" (" );
1972
- auto firstParam = func->getFirstParam ();
1973
- for (auto pp = firstParam; pp; pp = pp->getNextParam ())
1974
- {
1975
- // Ingore TypeType-typed parameters for now.
1976
- // In the future we will pass around runtime type info
1977
- // for TypeType parameters.
1978
- if (as<IRTypeType>(pp->getFullType ()))
1979
- continue ;
1980
-
1981
- if (pp != firstParam)
1982
- m_writer->emit (" , " );
1980
+ // Next we emit the parameter list of the function.
1981
+ //
1982
+ m_writer->emit (" (" );
1983
+ auto firstParam = func->getFirstParam ();
1984
+ for (auto pp = firstParam; pp; pp = pp->getNextParam ())
1985
+ {
1986
+ // Ingore TypeType-typed parameters for now.
1987
+ // In the future we will pass around runtime type info
1988
+ // for TypeType parameters.
1989
+ if (as<IRTypeType>(pp->getFullType ()))
1990
+ continue ;
1983
1991
1984
- emitSimpleFuncParamImpl (pp);
1985
- }
1986
- m_writer->emit (" )" );
1992
+ if (pp != firstParam)
1993
+ m_writer->emit (" , " );
1987
1994
1988
- emitSemantics (func );
1995
+ emitSimpleFuncParamImpl (pp );
1989
1996
}
1997
+ m_writer->emit (" )" );
1998
+
1999
+ emitSemantics (func);
1990
2000
1991
2001
// TODO: encode declaration vs. definition
1992
2002
if (isDefinition (func))
@@ -2431,40 +2441,6 @@ void CPPSourceEmitter::emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPre
2431
2441
2432
2442
switch (inst->op )
2433
2443
{
2434
- case kIROp_Param :
2435
- {
2436
- auto varLayout = getVarLayout (inst);
2437
-
2438
- if (varLayout)
2439
- {
2440
- if (auto systemValueSemantic = varLayout->findSystemValueSemanticAttr ())
2441
- {
2442
- String semanticNameSpelling = systemValueSemantic->getName ();
2443
- semanticNameSpelling = semanticNameSpelling.toLower ();
2444
-
2445
- if (semanticNameSpelling == " sv_dispatchthreadid" )
2446
- {
2447
- m_semanticUsedFlags |= SemanticUsedFlag::DispatchThreadID;
2448
- m_writer->emit (" dispatchThreadID" );
2449
- return ;
2450
- }
2451
- else if (semanticNameSpelling == " sv_groupid" )
2452
- {
2453
- m_semanticUsedFlags |= SemanticUsedFlag::GroupID;
2454
- m_writer->emit (" groupID" );
2455
- return ;
2456
- }
2457
- else if (semanticNameSpelling == " sv_groupthreadid" )
2458
- {
2459
- m_semanticUsedFlags |= SemanticUsedFlag::GroupThreadID;
2460
- m_writer->emit (" calcGroupThreadID()" );
2461
- return ;
2462
- }
2463
- }
2464
- }
2465
- m_writer->emit (getName (inst));
2466
- break ;
2467
- }
2468
2444
case kIROp_Var :
2469
2445
case kIROp_GlobalVar :
2470
2446
emitVarExpr (inst, outerPrec);
@@ -2591,19 +2567,19 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup
2591
2567
const auto & axis = axes[i];
2592
2568
builder.Clear ();
2593
2569
const char elem[2 ] = { s_elemNames[axis.axis ], 0 };
2594
- builder << " for (uint32_t " << elem << " = start. " << elem << " ; " << elem << " < start. " << elem << " + " << axis.size << " ; ++" << elem << " )\n {\n " ;
2570
+ builder << " for (uint32_t " << elem << " = 0 ; " << elem << " < " << axis.size << " ; ++" << elem << " )\n {\n " ;
2595
2571
m_writer->emit (builder);
2596
2572
m_writer->indent ();
2597
2573
2598
2574
builder.Clear ();
2599
- builder << " context.dispatchThreadID ." << elem << " = " << elem << " ;\n " ;
2575
+ builder << " threadInput.groupThreadID ." << elem << " = " << elem << " ;\n " ;
2600
2576
m_writer->emit (builder);
2601
2577
}
2602
2578
2603
2579
// just call at inner loop point
2604
2580
m_writer->emit (" context._" );
2605
2581
m_writer->emit (funcName);
2606
- m_writer->emit (" ();\n " );
2582
+ m_writer->emit (" (&threadInput );\n " );
2607
2583
2608
2584
// Close all the loops
2609
2585
for (Index i = Index (axes.getCount () - 1 ); i >= 0 ; --i)
@@ -2626,57 +2602,20 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread
2626
2602
builder.Clear ();
2627
2603
const char elem[2 ] = { s_elemNames[axis.axis ], 0 };
2628
2604
2629
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2630
- {
2631
- builder << " context.groupDispatchThreadID." << elem << " = start." << elem << " ;\n " ;
2632
- }
2633
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2634
- {
2635
- builder << " context.groupID." << elem << " += varyingInput->startGroupID." << elem << " ;\n " ;
2636
- }
2637
-
2638
- builder << " for (uint32_t " << elem << " = start." << elem << " ; " << elem << " < end." << elem << " ; ++" << elem << " )\n {\n " ;
2605
+ builder << " for (uint32_t " << elem << " = vi.startGroupID." << elem << " ; " << elem << " < vi.endGroupID." << elem << " ; ++" << elem << " )\n {\n " ;
2639
2606
m_writer->emit (builder);
2640
2607
m_writer->indent ();
2641
2608
2642
- builder.Clear ();
2643
- builder << " context.dispatchThreadID." << elem << " = " << elem << " ;\n " ;
2644
-
2645
- if (m_semanticUsedFlags & (SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::GroupID))
2646
- {
2647
- if (sizeAlongAxis[axis.axis ] > 1 )
2648
- {
2649
- builder << " const uint32_t next = context.groupDispatchThreadID." << elem << " + " << axis.size <<" ;\n " ;
2650
-
2651
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2652
- {
2653
- builder << " context.groupID." << elem << " += uint32_t(next == " << elem << " );\n " ;
2654
- }
2655
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2656
- {
2657
- builder << " context.groupDispatchThreadID." << elem << " = (" << elem << " == next) ? next : context.groupDispatchThreadID." << elem << " ;\n " ;
2658
- }
2659
- }
2660
- else
2661
- {
2662
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2663
- {
2664
- builder << " context.groupDispatchThreadID." << elem << " = " << elem << " ;\n " ;
2665
- }
2666
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2667
- {
2668
- builder << " context.groupID." << elem << " = " << elem << " ;\n " ;
2669
- }
2670
- }
2671
- }
2672
-
2673
- m_writer->emit (builder);
2609
+ m_writer->emit (" groupVaryingInput.startGroupID." );
2610
+ m_writer->emit (elem);
2611
+ m_writer->emit (" = " );
2612
+ m_writer->emit (elem);
2613
+ m_writer->emit (" ;\n " );
2674
2614
}
2675
2615
2676
2616
// just call at inner loop point
2677
- m_writer->emit (" context._" );
2678
2617
m_writer->emit (funcName);
2679
- m_writer->emit (" ( );\n " );
2618
+ m_writer->emit (" _Group(&groupVaryingInput, entryPointParams, globalParams );\n " );
2680
2619
2681
2620
// Close all the loops
2682
2621
for (Index i = Index (axes.getCount () - 1 ); i >= 0 ; --i)
@@ -2736,6 +2675,34 @@ void CPPSourceEmitter::_emitForwardDeclarations(const List<EmitAction>& actions)
2736
2675
}
2737
2676
}
2738
2677
2678
+ static bool isVaryingResourceKind (LayoutResourceKind kind)
2679
+ {
2680
+ switch (kind)
2681
+ {
2682
+ default :
2683
+ return false ;
2684
+
2685
+ case LayoutResourceKind::VaryingInput:
2686
+ case LayoutResourceKind::VaryingOutput:
2687
+ return true ;
2688
+ }
2689
+ }
2690
+
2691
+ static bool isVaryingParameter (IRTypeLayout* typeLayout)
2692
+ {
2693
+ for (auto sizeAttr : typeLayout->getSizeAttrs ())
2694
+ {
2695
+ if (!isVaryingResourceKind (sizeAttr->getResourceKind ()))
2696
+ return false ;
2697
+ }
2698
+ return true ;
2699
+ }
2700
+
2701
+ static bool isVaryingParameter (IRVarLayout* varLayout)
2702
+ {
2703
+ return isVaryingParameter (varLayout->getTypeLayout ());
2704
+ }
2705
+
2739
2706
void CPPSourceEmitter::_findShaderParams (
2740
2707
IRGlobalParam** outEntryPointParam,
2741
2708
IRGlobalParam** outGlobalParam)
@@ -2752,6 +2719,20 @@ void CPPSourceEmitter::_findShaderParams(
2752
2719
if (!param)
2753
2720
continue ;
2754
2721
2722
+ if (auto layoutDecor = param->findDecoration <IRLayoutDecoration>())
2723
+ {
2724
+ if (auto varLayout = as<IRVarLayout>(layoutDecor->getLayout ()))
2725
+ {
2726
+ if (isVaryingParameter (varLayout))
2727
+ continue ;
2728
+ auto typeLayout = varLayout->getTypeLayout ();
2729
+ if (typeLayout->findSizeAttr (LayoutResourceKind::VaryingInput))
2730
+ continue ;
2731
+ if (typeLayout->findSizeAttr (LayoutResourceKind::VaryingOutput))
2732
+ continue ;
2733
+ }
2734
+ }
2735
+
2755
2736
// Currently, the entry-point parameters
2756
2737
// are represented as a single parameter
2757
2738
// at the global scope, and the same is
@@ -2806,28 +2787,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
2806
2787
m_writer->emit (" struct KernelContext\n {\n " );
2807
2788
m_writer->indent ();
2808
2789
2809
- m_writer->emit (" uint3 dispatchThreadID;\n " );
2810
-
2811
- // if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2812
- {
2813
- // Note not always set!
2814
- m_writer->emit (" uint3 groupID;\n " );
2815
- }
2816
-
2817
- // if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2818
- {
2819
- m_writer->emit (" uint3 groupDispatchThreadID;\n " );
2820
-
2821
- m_writer->emit (" uint3 calcGroupThreadID() const \n {\n " );
2822
- m_writer->indent ();
2823
- // groupThreadID = dispatchThreadID - groupDispatchThreadID
2824
- m_writer->emit (" uint3 v = { dispatchThreadID.x - groupDispatchThreadID.x, dispatchThreadID.y - groupDispatchThreadID.y, dispatchThreadID.z - groupDispatchThreadID.z };\n " );
2825
- m_writer->emit (" return v;\n " );
2826
- m_writer->dedent ();
2827
- m_writer->emit (" }\n " );
2828
- }
2829
-
2830
-
2831
2790
if (globalParams)
2832
2791
{
2833
2792
emitGlobalInst (globalParams);
@@ -2886,9 +2845,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
2886
2845
2887
2846
if (entryPointDecor && entryPointDecor->getProfile ().getStage () == Stage::Compute)
2888
2847
{
2889
- // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
2890
- // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
2891
-
2892
2848
Int groupThreadSize[kThreadGroupAxisCount ];
2893
2849
getComputeThreadGroupSize (func, groupThreadSize);
2894
2850
@@ -2902,23 +2858,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
2902
2858
2903
2859
_emitEntryPointDefinitionStart (func, entryPointParams, globalParams, threadFuncName, UnownedStringSlice::fromLiteral (" ComputeThreadVaryingInput" ));
2904
2860
2905
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2906
- {
2907
- m_writer->emit (" context.groupDispatchThreadID = " );
2908
- _emitInitAxisValues (groupThreadSize, UnownedStringSlice::fromLiteral (" varyingInput->groupID" ), UnownedStringSlice ());
2909
- }
2910
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2911
- {
2912
- m_writer->emit (" context.groupID = varyingInput->groupID;\n " );
2913
- }
2914
-
2915
- // Emit dispatchThreadID
2916
- m_writer->emit (" context.dispatchThreadID = " );
2917
- _emitInitAxisValues (groupThreadSize, UnownedStringSlice::fromLiteral (" varyingInput->groupID" ), UnownedStringSlice::fromLiteral (" varyingInput->groupThreadID" ));
2918
-
2919
2861
m_writer->emit (" context._" );
2920
2862
m_writer->emit (funcName);
2921
- m_writer->emit (" ();\n " );
2863
+ m_writer->emit (" (varyingInput );\n " );
2922
2864
2923
2865
_emitEntryPointDefinitionEnd (func);
2924
2866
}
@@ -2933,19 +2875,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
2933
2875
2934
2876
_emitEntryPointDefinitionStart (func, entryPointParams, globalParams, groupFuncName, UnownedStringSlice::fromLiteral (" ComputeVaryingInput" ));
2935
2877
2936
- m_writer->emit (" const uint3 start = " );
2937
- _emitInitAxisValues (groupThreadSize, UnownedStringSlice::fromLiteral (" varyingInput->startGroupID" ), UnownedStringSlice ());
2938
-
2939
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
2940
- {
2941
- m_writer->emit (" context.groupDispatchThreadID = start;\n " );
2942
- }
2943
-
2944
- if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
2945
- {
2946
- m_writer->emit (" context.groupID = varyingInput->startGroupID;\n " );
2947
- }
2948
- m_writer->emit (" context.dispatchThreadID = start;\n " );
2878
+ m_writer->emit (" ComputeThreadVaryingInput threadInput = {};\n " );
2879
+ m_writer->emit (" threadInput.groupID = varyingInput->startGroupID;\n " );
2949
2880
2950
2881
_emitEntryPointGroup (groupThreadSize, funcName);
2951
2882
_emitEntryPointDefinitionEnd (func);
@@ -2955,10 +2886,8 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
2955
2886
{
2956
2887
_emitEntryPointDefinitionStart (func, entryPointParams, globalParams, funcName, UnownedStringSlice::fromLiteral (" ComputeVaryingInput" ));
2957
2888
2958
- m_writer->emit (" const uint3 start = " );
2959
- _emitInitAxisValues (groupThreadSize, UnownedStringSlice::fromLiteral (" varyingInput->startGroupID" ), UnownedStringSlice ());
2960
- m_writer->emit (" const uint3 end = " );
2961
- _emitInitAxisValues (groupThreadSize, UnownedStringSlice::fromLiteral (" varyingInput->endGroupID" ), UnownedStringSlice ());
2889
+ m_writer->emit (" ComputeVaryingInput vi = *varyingInput;\n " );
2890
+ m_writer->emit (" ComputeVaryingInput groupVaryingInput = {};\n " );
2962
2891
2963
2892
_emitEntryPointGroupRange (groupThreadSize, funcName);
2964
2893
_emitEntryPointDefinitionEnd (func);
0 commit comments