Skip to content

Commit

Permalink
Split spvLoadVertex call across multiple lines
Browse files Browse the repository at this point in the history
  • Loading branch information
etang-cw committed Oct 4, 2023
1 parent 7d58bbd commit 3a1cac2
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const uchar* spvVertexBuffer0 [[buffer(0)]], device const uchar* spvVertexBuffer1 [[buffer(1)]], device const uchar* spvVertexBuffer2 [[buffer(2)]], device const uchar* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]], const device uint* spvVertexStrides [[buffer(19)]])
{
main0_out out = {};
main0_in in = spvLoadVertex(*reinterpret_cast<device const spvVertexData0*>(spvVertexBuffer0 + spvVertexStrides[0] * gl_InstanceIndex), *reinterpret_cast<device const spvVertexData1*>(spvVertexBuffer1 + spvVertexStrides[1] * gl_VertexIndex), *reinterpret_cast<device const spvVertexData2*>(spvVertexBuffer2 + spvVertexStrides[2] * gl_BaseInstance), *reinterpret_cast<device const spvVertexData3*>(spvVertexBuffer3 + spvVertexStrides[3] * (gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4)));
main0_in in = spvLoadVertex(*reinterpret_cast<device const spvVertexData0*>(spvVertexBuffer0 + spvVertexStrides[0] * gl_InstanceIndex),
*reinterpret_cast<device const spvVertexData1*>(spvVertexBuffer1 + spvVertexStrides[1] * gl_VertexIndex),
*reinterpret_cast<device const spvVertexData2*>(spvVertexBuffer2 + spvVertexStrides[2] * gl_BaseInstance),
*reinterpret_cast<device const spvVertexData3*>(spvVertexBuffer3 + spvVertexStrides[3] * (gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4)));
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ kernel void main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], d
uint gl_BaseVertex = spvDispatchBase.x;
uint gl_InstanceIndex = gl_GlobalInvocationID.y + spvDispatchBase.y;
uint gl_BaseInstance = spvDispatchBase.y;
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
}

Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], device const spvVertexData1* spvVertexBuffer1 [[buffer(1)]], device const spvVertexData2* spvVertexBuffer2 [[buffer(2)]], device const spvVertexData3* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]])
{
main0_out out = {};
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], device const spvVertexData1* spvVertexBuffer1 [[buffer(1)]], device const spvVertexData2* spvVertexBuffer2 [[buffer(2)]], device const spvVertexData3* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]])
{
main0_out out = {};
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
5 changes: 4 additions & 1 deletion reference/opt/shaders-msl/vert/attrs.vertex-loader.vert
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], device const spvVertexData1* spvVertexBuffer1 [[buffer(1)]], device const spvVertexData2* spvVertexBuffer2 [[buffer(2)]], device const spvVertexData3* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]])
{
main0_out out = {};
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const uchar* spvVertexBuffer0 [[buffer(0)]], device const uchar* spvVertexBuffer1 [[buffer(1)]], device const uchar* spvVertexBuffer2 [[buffer(2)]], device const uchar* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]], const device uint* spvVertexStrides [[buffer(19)]])
{
main0_out out = {};
main0_in in = spvLoadVertex(*reinterpret_cast<device const spvVertexData0*>(spvVertexBuffer0 + spvVertexStrides[0] * gl_InstanceIndex), *reinterpret_cast<device const spvVertexData1*>(spvVertexBuffer1 + spvVertexStrides[1] * gl_VertexIndex), *reinterpret_cast<device const spvVertexData2*>(spvVertexBuffer2 + spvVertexStrides[2] * gl_BaseInstance), *reinterpret_cast<device const spvVertexData3*>(spvVertexBuffer3 + spvVertexStrides[3] * (gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4)));
main0_in in = spvLoadVertex(*reinterpret_cast<device const spvVertexData0*>(spvVertexBuffer0 + spvVertexStrides[0] * gl_InstanceIndex),
*reinterpret_cast<device const spvVertexData1*>(spvVertexBuffer1 + spvVertexStrides[1] * gl_VertexIndex),
*reinterpret_cast<device const spvVertexData2*>(spvVertexBuffer2 + spvVertexStrides[2] * gl_BaseInstance),
*reinterpret_cast<device const spvVertexData3*>(spvVertexBuffer3 + spvVertexStrides[3] * (gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4)));
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ kernel void main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], d
uint gl_BaseVertex = spvDispatchBase.x;
uint gl_InstanceIndex = gl_GlobalInvocationID.y + spvDispatchBase.y;
uint gl_BaseInstance = spvDispatchBase.y;
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
}

Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], device const spvVertexData1* spvVertexBuffer1 [[buffer(1)]], device const spvVertexData2* spvVertexBuffer2 [[buffer(2)]], device const spvVertexData3* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]])
{
main0_out out = {};
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
5 changes: 4 additions & 1 deletion reference/shaders-msl/vert/attrs.vertex-loader.msl23.vert
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], device const spvVertexData1* spvVertexBuffer1 [[buffer(1)]], device const spvVertexData2* spvVertexBuffer2 [[buffer(2)]], device const spvVertexData3* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]])
{
main0_out out = {};
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
5 changes: 4 additions & 1 deletion reference/shaders-msl/vert/attrs.vertex-loader.vert
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVerte
vertex main0_out main0(device const spvVertexData0* spvVertexBuffer0 [[buffer(0)]], device const spvVertexData1* spvVertexBuffer1 [[buffer(1)]], device const spvVertexData2* spvVertexBuffer2 [[buffer(2)]], device const spvVertexData3* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]])
{
main0_out out = {};
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex], spvVertexBuffer1[gl_VertexIndex], spvVertexBuffer2[gl_BaseInstance], spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
main0_in in = spvLoadVertex(spvVertexBuffer0[gl_InstanceIndex],
spvVertexBuffer1[gl_VertexIndex],
spvVertexBuffer2[gl_BaseInstance],
spvVertexBuffer3[gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4]);
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}
Expand Down
64 changes: 38 additions & 26 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7527,27 +7527,37 @@ void CompilerMSL::prepare_shader_vertex_loader()
spv_function_implementations.insert(fn);
}

std::string load;
SmallVector<std::string> lines;
lines.push_back(get_name(get<SPIRVariable>(stage_in_var_id).basetype));
lines.back().push_back(' ');
lines.back().append(get_name(stage_in_var_id));
lines.back().append(" = spvLoadVertex(");
size_t indent_len = lines.back().size();
std::string *line = nullptr;

for (uint32_t i = 0; i < MSLVertexLoaderWriter::MaxBindings; i++)
{
const MSLVertexLoaderWriter::Binding &binding = vertex_loader_writer.get_binding(i);
if (!binding.used)
continue;
if (!load.empty())
load.append(", ");
if (line)
{
line->push_back(',');
lines.emplace_back(indent_len, ' ');
}
line = &lines.back();
std::string istr = std::to_string(i);
if (msl_options.vertex_loader_dynamic_stride)
{
load.append("*reinterpret_cast<device const spvVertexData");
load.append(istr);
load.append("*>(");
line->append("*reinterpret_cast<device const spvVertexData");
line->append(istr);
line->append("*>(");
}
load.append("spvVertexBuffer");
load.append(istr);
line->append("spvVertexBuffer");
line->append(istr);
if (binding.stride == 0 && !msl_options.vertex_loader_dynamic_stride)
{
load.append("[0]");
line->append("[0]");
}
else
{
Expand All @@ -7568,39 +7578,41 @@ void CompilerMSL::prepare_shader_vertex_loader()
}
if (msl_options.vertex_loader_dynamic_stride)
{
load.append(" + spvVertexStrides[");
load.append(istr);
load.append("] * ");
line->append(" + spvVertexStrides[");
line->append(istr);
line->append("] * ");
}
else
{
load.push_back('[');
line->push_back('[');
}
if (binding.divisor <= 1)
{
load.append(binding.divisor == 0 ? base : index);
line->append(binding.divisor == 0 ? base : index);
}
else
{
if (msl_options.vertex_loader_dynamic_stride)
load.push_back('(');
load.append(base);
load.append(" + (");
load.append(index);
load.append(" - ");
load.append(base);
load.append(") / ");
load.append(std::to_string(binding.divisor));
line->push_back('(');
line->append(base);
line->append(" + (");
line->append(index);
line->append(" - ");
line->append(base);
line->append(") / ");
line->append(std::to_string(binding.divisor));
if (msl_options.vertex_loader_dynamic_stride)
load.push_back(')');
line->push_back(')');
}
load.push_back(msl_options.vertex_loader_dynamic_stride ? ')' : ']');
line->push_back(msl_options.vertex_loader_dynamic_stride ? ')' : ']');
}
}
lines.back().append(");");

auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
entry_func.add_fixup_hook_in([this, load]{
statement(get_name(this->get<SPIRVariable>(stage_in_var_id).basetype), " ", get_name(stage_in_var_id), " = spvLoadVertex(", load, ");");
entry_func.add_fixup_hook_in([this, lines]{
for (const std::string& l : lines)
statement(l);
}, SPIRFunction::FixupInPriority::VertexLoad);
}

Expand Down

0 comments on commit 3a1cac2

Please sign in to comment.