Skip to content

Commit 6c24f1f

Browse files
committed
shader-rt: correct arg buffer handling
1 parent ef565f9 commit 6c24f1f

File tree

5 files changed

+152
-62
lines changed

5 files changed

+152
-62
lines changed

node-graph/gcore-shaders/src/blending.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ impl AlphaBlending {
6666
}
6767

6868
#[repr(i32)]
69-
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash)]
69+
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, bytemuck::NoUninit)]
7070
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
7171
pub enum BlendMode {
7272
// Basic group

node-graph/graster-nodes/src/adjustments.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use num_traits::float::Float;
3030
// https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=%27clrL%27%20%3D%20Color%20Lookup
3131
// https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=Color%20Lookup%20(Photoshop%20CS6
3232

33-
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, node_macro::ChoiceType)]
33+
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, node_macro::ChoiceType, bytemuck::NoUninit)]
3434
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
3535
#[widget(Dropdown)]
3636
#[repr(u32)]
@@ -482,7 +482,7 @@ pub enum RedGreenBlue {
482482
}
483483

484484
/// Color Channel
485-
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, node_macro::ChoiceType)]
485+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, node_macro::ChoiceType, bytemuck::NoUninit)]
486486
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
487487
#[widget(Radio)]
488488
#[repr(u32)]

node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ impl ShaderCodegen for PerPixelAdjust {
2222
fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> {
2323
let fn_name = &parsed.fn_name;
2424

25-
// categorize params and assign image bindings
26-
// bindings for images start at 1
27-
let params = {
28-
let mut binding_cnt = 0;
29-
parsed
25+
let mut params;
26+
let has_uniform;
27+
{
28+
// categorize params
29+
params = parsed
3030
.fields
3131
.iter()
3232
.map(|f| {
@@ -39,30 +39,50 @@ impl ShaderCodegen for PerPixelAdjust {
3939
param_type: ParamType::Uniform,
4040
}),
4141
ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => {
42-
binding_cnt += 1;
43-
Ok(Param {
42+
let param = Param {
4443
ident: Cow::Owned(format_ident!("image_{}", &ident.ident)),
4544
ty: quote!(Image2d),
46-
param_type: ParamType::Image { binding: binding_cnt },
47-
})
45+
param_type: ParamType::Image { binding: 0 },
46+
};
47+
Ok(param)
4848
}
4949
}
5050
})
51-
.collect::<syn::Result<Vec<_>>>()?
52-
};
51+
.collect::<syn::Result<Vec<_>>>()?;
52+
53+
has_uniform = params.iter().any(|p| matches!(p.param_type, ParamType::Uniform));
54+
55+
// assign image bindings
56+
// if an arg_buffer exists, bindings for images start at 1 to leave binding 0 for the arg buffer
57+
let mut binding_cnt = if has_uniform { 1 } else { 0 };
58+
for p in params.iter_mut() {
59+
match &mut p.param_type {
60+
ParamType::Image { binding } => {
61+
*binding = binding_cnt;
62+
binding_cnt += 1;
63+
}
64+
ParamType::Uniform => {}
65+
}
66+
}
67+
}
5368

5469
let entry_point_mod = format_ident!("{}_gpu_entry_point", fn_name);
5570
let entry_point_name_ident = format_ident!("ENTRY_POINT_NAME");
5671
let entry_point_name = quote!(#entry_point_mod::#entry_point_name_ident);
72+
let uniform_struct_ident = format_ident!("Uniform");
73+
let uniform_struct = quote!(#entry_point_mod::#uniform_struct_ident);
5774
let gpu_node_mod = format_ident!("{}_gpu", fn_name);
5875

5976
let codegen = PerPixelAdjustCodegen {
6077
parsed,
6178
node_cfg,
6279
params,
80+
has_uniform,
6381
entry_point_mod,
6482
entry_point_name_ident,
6583
entry_point_name,
84+
uniform_struct_ident,
85+
uniform_struct,
6686
gpu_node_mod,
6787
};
6888

@@ -77,9 +97,12 @@ pub struct PerPixelAdjustCodegen<'a> {
7797
parsed: &'a ParsedNodeFn,
7898
node_cfg: &'a TokenStream,
7999
params: Vec<Param<'a>>,
100+
has_uniform: bool,
80101
entry_point_mod: Ident,
81102
entry_point_name_ident: Ident,
82103
entry_point_name: TokenStream,
104+
uniform_struct_ident: Ident,
105+
uniform_struct: TokenStream,
83106
gpu_node_mod: Ident,
84107
}
85108

@@ -114,6 +137,7 @@ impl PerPixelAdjustCodegen<'_> {
114137

115138
let entry_point_mod = &self.entry_point_mod;
116139
let entry_point_name = &self.entry_point_name_ident;
140+
let uniform_struct_ident = &self.uniform_struct_ident;
117141
Ok(quote! {
118142
pub mod #entry_point_mod {
119143
use super::*;
@@ -125,8 +149,10 @@ impl PerPixelAdjustCodegen<'_> {
125149

126150
pub const #entry_point_name: &str = core::concat!(core::module_path!(), "::entry_point");
127151

128-
pub struct Uniform {
129-
#(#uniform_members),*
152+
#[repr(C)]
153+
#[derive(Copy, Clone, bytemuck::NoUninit)]
154+
pub struct #uniform_struct_ident {
155+
#(pub #uniform_members),*
130156
}
131157

132158
#[spirv(fragment)]
@@ -158,14 +184,26 @@ impl PerPixelAdjustCodegen<'_> {
158184
.iter()
159185
.map(|f| match &f.ty {
160186
ParsedFieldType::Regular(reg @ RegularParsedField { gpu_image: true, .. }) => Ok(ParsedField {
187+
pat_ident: PatIdent {
188+
mutability: None,
189+
by_ref: None,
190+
..f.pat_ident.clone()
191+
},
161192
ty: ParsedFieldType::Regular(RegularParsedField {
162193
ty: raster_gpu.clone(),
163194
implementations: Punctuated::default(),
164195
..reg.clone()
165196
}),
166197
..f.clone()
167198
}),
168-
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(f.clone()),
199+
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(ParsedField {
200+
pat_ident: PatIdent {
201+
mutability: None,
202+
by_ref: None,
203+
..f.pat_ident.clone()
204+
},
205+
..f.clone()
206+
}),
169207
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
170208
})
171209
.collect::<syn::Result<Vec<_>>>()?;
@@ -211,14 +249,35 @@ impl PerPixelAdjustCodegen<'_> {
211249
};
212250
let gpu_image = &gpu_image_field.pat_ident.ident;
213251

252+
// uniform buffer struct construction
253+
let has_uniform = self.has_uniform;
254+
let uniform_buffer = if has_uniform {
255+
let uniform_struct = &self.uniform_struct;
256+
let uniform_members = self
257+
.params
258+
.iter()
259+
.filter_map(|p| match p.param_type {
260+
ParamType::Image { .. } => None,
261+
ParamType::Uniform => Some(p.ident.as_ref()),
262+
})
263+
.collect::<Vec<_>>();
264+
quote!(Some(&super::#uniform_struct {
265+
#(#uniform_members),*
266+
}))
267+
} else {
268+
// explicit generics placed here because it's easier than explicitly writing `run_per_pixel_adjust::<()>`
269+
quote!(Option::<&()>::None)
270+
};
271+
214272
// node function body
215273
let entry_point_name = &self.entry_point_name;
216274
let body = quote! {
217275
{
218-
#wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::Shaders {
276+
#wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::per_pixel_adjust_runtime::Shaders {
219277
wgsl_shader: crate::WGSL_SHADER,
220278
fragment_shader_name: super::#entry_point_name,
221-
}, #gpu_image, &1u32).await
279+
has_uniform: #has_uniform,
280+
}, #gpu_image, #uniform_buffer).await
222281
}
223282
};
224283

node-graph/wgpu-executor/src/shader_runtime/mod.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,3 @@ impl ShaderRuntime {
1818
}
1919
}
2020
}
21-
22-
pub struct Shaders<'a> {
23-
pub wgsl_shader: &'a str,
24-
pub fragment_shader_name: &'a str,
25-
}

node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::Context;
2-
use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime, Shaders};
2+
use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime};
33
use bytemuck::NoUninit;
44
use futures::lock::Mutex;
55
use graphene_core::raster_types::{GPU, Raster};
@@ -27,24 +27,33 @@ impl PerPixelAdjustShaderRuntime {
2727
}
2828

2929
impl ShaderRuntime {
30-
pub async fn run_per_pixel_adjust<T: NoUninit>(&self, shaders: &Shaders<'_>, textures: Table<Raster<GPU>>, args: &T) -> Table<Raster<GPU>> {
30+
pub async fn run_per_pixel_adjust<T: NoUninit>(&self, shaders: &Shaders<'_>, textures: Table<Raster<GPU>>, args: Option<&T>) -> Table<Raster<GPU>> {
3131
let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await;
3232
let pipeline = cache
3333
.entry(shaders.fragment_shader_name.to_owned())
3434
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders));
3535

36-
let device = &self.context.device;
37-
let arg_buffer = device.create_buffer_init(&BufferInitDescriptor {
38-
label: Some(&format!("{} arg buffer", pipeline.name.as_str())),
39-
usage: BufferUsages::STORAGE,
40-
contents: bytemuck::bytes_of(args),
36+
let arg_buffer = args.map(|args| {
37+
let device = &self.context.device;
38+
device.create_buffer_init(&BufferInitDescriptor {
39+
label: Some(&format!("{} arg buffer", pipeline.name.as_str())),
40+
usage: BufferUsages::STORAGE,
41+
contents: bytemuck::bytes_of(args),
42+
})
4143
});
42-
pipeline.dispatch(&self.context, textures, &arg_buffer)
44+
pipeline.dispatch(&self.context, textures, arg_buffer)
4345
}
4446
}
4547

48+
pub struct Shaders<'a> {
49+
pub wgsl_shader: &'a str,
50+
pub fragment_shader_name: &'a str,
51+
pub has_uniform: bool,
52+
}
53+
4654
pub struct PerPixelAdjustGraphicsPipeline {
4755
name: String,
56+
has_uniform: bool,
4857
pipeline: wgpu::RenderPipeline,
4958
}
5059

@@ -62,32 +71,46 @@ impl PerPixelAdjustGraphicsPipeline {
6271
source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)),
6372
});
6473

74+
let entries: &[_] = if info.has_uniform {
75+
&[
76+
BindGroupLayoutEntry {
77+
binding: 0,
78+
visibility: ShaderStages::FRAGMENT,
79+
ty: BindingType::Buffer {
80+
ty: BufferBindingType::Storage { read_only: true },
81+
has_dynamic_offset: false,
82+
min_binding_size: None,
83+
},
84+
count: None,
85+
},
86+
BindGroupLayoutEntry {
87+
binding: 1,
88+
visibility: ShaderStages::FRAGMENT,
89+
ty: BindingType::Texture {
90+
sample_type: TextureSampleType::Float { filterable: false },
91+
view_dimension: TextureViewDimension::D2,
92+
multisampled: false,
93+
},
94+
count: None,
95+
},
96+
]
97+
} else {
98+
&[BindGroupLayoutEntry {
99+
binding: 0,
100+
visibility: ShaderStages::FRAGMENT,
101+
ty: BindingType::Texture {
102+
sample_type: TextureSampleType::Float { filterable: false },
103+
view_dimension: TextureViewDimension::D2,
104+
multisampled: false,
105+
},
106+
count: None,
107+
}]
108+
};
65109
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
66110
label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)),
67111
bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor {
68112
label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)),
69-
entries: &[
70-
BindGroupLayoutEntry {
71-
binding: 0,
72-
visibility: ShaderStages::FRAGMENT,
73-
ty: BindingType::Buffer {
74-
ty: BufferBindingType::Storage { read_only: true },
75-
has_dynamic_offset: false,
76-
min_binding_size: None,
77-
},
78-
count: None,
79-
},
80-
BindGroupLayoutEntry {
81-
binding: 1,
82-
visibility: ShaderStages::FRAGMENT,
83-
ty: BindingType::Texture {
84-
sample_type: TextureSampleType::Float { filterable: false },
85-
view_dimension: TextureViewDimension::D2,
86-
multisampled: false,
87-
},
88-
count: None,
89-
},
90-
],
113+
entries,
91114
})],
92115
push_constant_ranges: &[],
93116
});
@@ -125,10 +148,15 @@ impl PerPixelAdjustGraphicsPipeline {
125148
multiview: None,
126149
cache: None,
127150
});
128-
Self { pipeline, name }
151+
Self {
152+
pipeline,
153+
name,
154+
has_uniform: info.has_uniform,
155+
}
129156
}
130157

131-
pub fn dispatch(&self, context: &Context, textures: Table<Raster<GPU>>, arg_buffer: &Buffer) -> Table<Raster<GPU>> {
158+
pub fn dispatch(&self, context: &Context, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
159+
assert_eq!(self.has_uniform, arg_buffer.is_some());
132160
let device = &context.device;
133161
let name = self.name.as_str();
134162

@@ -140,11 +168,8 @@ impl PerPixelAdjustGraphicsPipeline {
140168
let view_in = tex_in.create_view(&TextureViewDescriptor::default());
141169
let format = tex_in.format();
142170

143-
let bind_group = device.create_bind_group(&BindGroupDescriptor {
144-
label: Some(&format!("{name} bind group")),
145-
// `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that
146-
layout: &self.pipeline.get_bind_group_layout(0),
147-
entries: &[
171+
let entries: &[_] = if let Some(arg_buffer) = arg_buffer.as_ref() {
172+
&[
148173
BindGroupEntry {
149174
binding: 0,
150175
resource: BindingResource::Buffer(BufferBinding {
@@ -157,7 +182,18 @@ impl PerPixelAdjustGraphicsPipeline {
157182
binding: 1,
158183
resource: BindingResource::TextureView(&view_in),
159184
},
160-
],
185+
]
186+
} else {
187+
&[BindGroupEntry {
188+
binding: 0,
189+
resource: BindingResource::TextureView(&view_in),
190+
}]
191+
};
192+
let bind_group = device.create_bind_group(&BindGroupDescriptor {
193+
label: Some(&format!("{name} bind group")),
194+
// `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that
195+
layout: &self.pipeline.get_bind_group_layout(0),
196+
entries,
161197
});
162198

163199
let tex_out = device.create_texture(&TextureDescriptor {

0 commit comments

Comments
 (0)