Skip to content

Commit d49ef85

Browse files
committed
mesh shader ext: add support for mesh shaders, based on @BeastLe9enD work
1 parent d63b3ed commit d49ef85

File tree

5 files changed

+95
-4
lines changed

5 files changed

+95
-4
lines changed

crates/rustc_codegen_spirv/src/linker/simple_passes.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ pub fn outgoing_edges(block: &Block) -> impl Iterator<Item = Word> + '_ {
9494
| Op::Kill
9595
| Op::Unreachable
9696
| Op::IgnoreIntersectionKHR
97-
| Op::TerminateRayKHR => (0..0).step_by(1),
97+
| Op::TerminateRayKHR
98+
| Op::EmitMeshTasksEXT => (0..0).step_by(1),
9899
_ => panic!("Invalid block terminator: {terminator:?}"),
99100
};
100101
operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref())

crates/rustc_codegen_spirv/src/spirv_type_constraints.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,8 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
955955
}
956956
// SPV_EXT_mesh_shader
957957
Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => {
958-
reserved!(SPV_EXT_mesh_shader)
958+
// NOTE(eddyb) we actually use these despite not being in the standard yet.
959+
// reserved!(SPV_EXT_mesh_shader)
959960
}
960961
// SPV_NV_ray_tracing_motion_blur
961962
Op::TraceMotionNV | Op::TraceRayMotionNV => reserved!(SPV_NV_ray_tracing_motion_blur),

crates/rustc_codegen_spirv/src/symbols.rs

+23-2
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ const BUILTINS: &[(&str, BuiltIn)] = {
122122
("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
123123
("bary_coord", BaryCoordKHR),
124124
("bary_coord_no_persp", BaryCoordNoPerspKHR),
125+
("primitive_point_indices_ext", PrimitivePointIndicesEXT),
126+
("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
127+
(
128+
"primitive_triangle_indices_ext",
129+
PrimitiveTriangleIndicesEXT,
130+
),
131+
("cull_primitive_ext", CullPrimitiveEXT),
125132
("frag_size_ext", FragSizeEXT),
126133
("frag_invocation_count_ext", FragInvocationCountEXT),
127134
("launch_id", BuiltIn::LaunchIdKHR),
@@ -171,6 +178,7 @@ const STORAGE_CLASSES: &[(&str, StorageClass)] = {
171178
("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
172179
("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
173180
("physical_storage_buffer", PhysicalStorageBuffer),
181+
("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
174182
]
175183
};
176184

@@ -185,6 +193,8 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
185193
("compute", GLCompute),
186194
("task_nv", TaskNV),
187195
("mesh_nv", MeshNV),
196+
("task_ext", TaskEXT),
197+
("mesh_ext", MeshEXT),
188198
("ray_generation", ExecutionModel::RayGenerationKHR),
189199
("intersection", ExecutionModel::IntersectionKHR),
190200
("any_hit", ExecutionModel::AnyHitKHR),
@@ -265,6 +275,17 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
265275
("output_primitives_nv", OutputPrimitivesNV, Value),
266276
("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
267277
("output_triangles_nv", OutputTrianglesNV, None),
278+
("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
279+
(
280+
"output_triangles_ext",
281+
ExecutionMode::OutputTrianglesEXT,
282+
None,
283+
),
284+
(
285+
"output_primitives_ext",
286+
ExecutionMode::OutputPrimitivesEXT,
287+
Value,
288+
),
268289
(
269290
"pixel_interlock_ordered_ext",
270291
PixelInterlockOrderedEXT,
@@ -717,7 +738,7 @@ fn parse_entry_attrs(
717738
.execution_modes
718739
.push((origin_mode, ExecutionModeExtra::new([])));
719740
}
720-
GLCompute | MeshNV | TaskNV => {
741+
GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
721742
if let Some(local_size) = local_size {
722743
entry
723744
.execution_modes
@@ -726,7 +747,7 @@ fn parse_entry_attrs(
726747
return Err((
727748
arg.span(),
728749
String::from(
729-
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]` or `#[spirv(task_nv)]`",
750+
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
730751
),
731752
));
732753
}

crates/spirv-std/src/arch.rs

+2
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ mod atomics;
1717
mod barrier;
1818
mod demote_to_helper_invocation_ext;
1919
mod derivative;
20+
mod mesh_shading;
2021
mod primitive;
2122
mod ray_tracing;
2223

2324
pub use atomics::*;
2425
pub use barrier::*;
2526
pub use demote_to_helper_invocation_ext::*;
2627
pub use derivative::*;
28+
pub use mesh_shading::*;
2729
pub use primitive::*;
2830
pub use ray_tracing::*;
2931

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#[cfg(target_arch = "spirv")]
2+
use core::arch::asm;
3+
4+
/// Sets the actual output size of the primitives and vertices that the mesh shader
5+
/// workgroup will emit upon completion.
6+
///
7+
/// 'Vertex Count' must be a 32-bit unsigned integer value.
8+
/// It defines the array size of per-vertex outputs.
9+
///
10+
/// 'Primitive Count' must a 32-bit unsigned integer value.
11+
/// It defines the array size of per-primitive outputs.
12+
///
13+
/// The arguments are taken from the first invocation in each workgroup.
14+
/// Any invocation must execute this instruction no more than once and under
15+
/// uniform control flow.
16+
/// There must not be any control flow path to an output write that is not preceded
17+
/// by this instruction.
18+
///
19+
/// This instruction is only valid in the *MeshEXT* Execution Model.
20+
#[spirv_std_macros::gpu_only]
21+
#[doc(alias = "OpSetMeshOutputsEXT")]
22+
#[inline]
23+
pub unsafe fn set_mesh_outputs_ext(vertex_count: u32, primitive_count: u32) {
24+
asm! {
25+
"OpSetMeshOutputsEXT {vertex_count} {primitive_count}",
26+
vertex_count = in(reg) vertex_count,
27+
primitive_count = in(reg) primitive_count,
28+
}
29+
}
30+
31+
/// Defines the grid size of subsequent mesh shader workgroups to generate
32+
/// upon completion of the task shader workgroup.
33+
///
34+
/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value.
35+
/// They configure the number of local workgroups in each respective dimensions
36+
/// for the launch of child mesh tasks. See Vulkan API specification for more detail.
37+
///
38+
/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations.
39+
/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*.
40+
///
41+
/// The arguments are taken from the first invocation in each workgroup.
42+
/// Any invocation must execute this instruction exactly once and under uniform
43+
/// control flow.
44+
/// This instruction also serves as an *OpControlBarrier* instruction, and also
45+
/// performs and adheres to the description and semantics of an *OpControlBarrier*
46+
/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and
47+
/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and
48+
/// *AcquireRelease*.
49+
/// Ceases all further processing: Only instructions executed before
50+
/// *OpEmitMeshTasksEXT* have observable side effects.
51+
///
52+
/// This instruction must be the last instruction in a block.
53+
///
54+
/// This instruction is only valid in the *TaskEXT* Execution Model.
55+
#[spirv_std_macros::gpu_only]
56+
#[doc(alias = "OpEmitMeshTasksEXT")]
57+
#[inline]
58+
pub unsafe fn emit_mesh_tasks_ext(group_count_x: u32, group_count_y: u32, group_count_z: u32) -> ! {
59+
asm! {
60+
"OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z}",
61+
group_count_x = in(reg) group_count_x,
62+
group_count_y = in(reg) group_count_y,
63+
group_count_z = in(reg) group_count_z,
64+
options(noreturn),
65+
}
66+
}

0 commit comments

Comments
 (0)