Skip to content

Commit

Permalink
hoist entry point params for wgsl (#6116)
Browse files Browse the repository at this point in the history
Co-authored-by: Yong He <yonghe@outlook.com>
  • Loading branch information
fairywreath and csyonghe authored Jan 17, 2025
1 parent e743c17 commit 9143087
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 29 deletions.
20 changes: 6 additions & 14 deletions source/slang/slang-ir-legalize-varying-params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1574,8 +1574,8 @@ class LegalizeShaderEntryPointContext
}

protected:
LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, bool hoistParameters)
: m_module(module), m_sink(sink), hoistParameters(hoistParameters)
LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink)
: m_module(module), m_sink(sink)
{
}

Expand Down Expand Up @@ -1758,7 +1758,6 @@ class LegalizeShaderEntryPointContext
}

private:
const bool hoistParameters;
HashSet<IRStructField*> semanticInfoToRemove;

void removeSemanticLayoutsFromLegalizedStructs()
Expand Down Expand Up @@ -2985,16 +2984,9 @@ class LegalizeShaderEntryPointContext
// If the entrypoint is receiving varying inputs as a pointer, turn it into a value.
depointerizeInputParams(entryPoint.entryPointFunc);

// TODO FIXME: Enable these for WGSL and remove the `hoistParemeters` member field.
// WGSL entry point legalization currently only applies attributes to struct parameters,
// apply the same hoisting from Metal to WGSL to fix it.
if (hoistParameters)
{
hoistEntryPointParameterFromStruct(entryPoint);
packStageInParameters(entryPoint);
}

// Input Parameter Legalize
hoistEntryPointParameterFromStruct(entryPoint);
packStageInParameters(entryPoint);
flattenInputParameters(entryPoint);

// System Value Legalize
Expand Down Expand Up @@ -3023,7 +3015,7 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext
{
public:
LegalizeMetalEntryPointContext(IRModule* module, DiagnosticSink* sink)
: LegalizeShaderEntryPointContext(module, sink, true)
: LegalizeShaderEntryPointContext(module, sink)
{
generatePermittedTypes_sv_target();
}
Expand Down Expand Up @@ -3681,7 +3673,7 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
{
public:
LegalizeWGSLEntryPointContext(IRModule* module, DiagnosticSink* sink)
: LegalizeShaderEntryPointContext(module, sink, false)
: LegalizeShaderEntryPointContext(module, sink)
{
}

Expand Down
4 changes: 2 additions & 2 deletions tests/metal/nested-struct-fragment-input.slang
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
// METAL-NOT: [[ATTR3]]
// METAL-DAG: [[ATTR4:COARSEVERTEX_(1|2|3|4)]]

// WGSL:struct FragmentStageInput
// WGSL:struct pixelInput
// WGSL-DAG:@location(0) [[VAR0:[A-Za-z_0-9]+]]
// WGSL-DAG:@location(1) [[VAR1:[A-Za-z_0-9]+]]
// WGSL-DAG:@location(2) [[VAR2:[A-Za-z_0-9]+]]
Expand Down Expand Up @@ -78,7 +78,7 @@ float4 fragmentMain(FragmentStageInput input)
// METAL-DAG: {{.*}}->p3{{.*}}->p2{{.*}}->p1{{.*}}=
// METAL-DAG: {{.*}}->p3{{.*}}->p3{{.*}}->p1{{.*}}=

// WGSL: var [[UnpackedInput:[A-Za-z_0-9]+]] : FragmentStageInput
// WGSL: var [[UnpackedInput:[A-Za-z_0-9]+]] : pixelInput
// WGSL-DAG: [[UnpackedInput]].{{[A-Za-z_0-9]+}}.{{[A-Za-z_0-9]+}} = [[InputVar]].[[VAR7]];

// WGSL-DAG: [[UnpackedInput]].{{[A-Za-z_0-9]+}}.{{[A-Za-z_0-9]+}}.{{[A-Za-z_0-9]+}} = [[InputVar]].[[VAR6]];
Expand Down
6 changes: 2 additions & 4 deletions tests/metal/stage-in-2.slang
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
//WGSLSPIRV: %vertexMain = OpFunction %
//WGSLSPIRV: %fragmentMain = OpFunction %

//WGSL: struct [[CoarseVertex:CoarseVertex[_0-9]*]]
//WGSL-NEXT: {
//WGSL-NEXT: @location(0) color
//WGSL: fn fragmentMain({{.*}}[[CoarseVertex]]
//WGSL: @location(0) output
//WGSL: @location(0) color

// Uniform data to be passed from application -> shader.
cbuffer Uniforms
Expand Down
20 changes: 11 additions & 9 deletions tests/wgsl/nested-varying-input.slang
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ struct NestedVertexOutput

struct VertexOutput
{
//VERT: @builtin(position) position
//FRAG: @builtin(position) position
float4 position : SV_Position;

//VERT: @location(0) uv
//FRAG: @location(0) uv
//FRAG-DAG: @location(0) uv
float2 uv : TEXCOORD0;

//VERT: @location(1) color
//FRAG: @location(1) color
//FRAG-DAG: @location(1) color
NestedVertexOutput nested;

//VERT: @builtin(position) position
//FRAG-DAG: @builtin(position) position
float4 position : SV_Position;
};

VertexOutput vertexMain()
Expand All @@ -44,10 +44,12 @@ VertexOutput vertexMain()
return out;
}

FragmentOutput fragmentMain(VertexOutput input)
//FRAG-DAG: @location(3) color3
//FRAG-DAG: @location(6) color6
FragmentOutput fragmentMain(VertexOutput input, float4 color3: COLOR3, float4 color6: COLOR6)
{
FragmentOutput out;
out.color0 = input.nested.color;
out.color1 = input.nested.color;
out.color0 = input.nested.color + color3;
out.color1 = input.nested.color + color6;
return out;
}

0 comments on commit 9143087

Please sign in to comment.