Skip to content

Commit f060d12

Browse files
committed
WIP12: connect shader runtime
1 parent 469c346 commit f060d12

File tree

6 files changed

+82
-18
lines changed

6 files changed

+82
-18
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

node-graph/graster-nodes/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ workspace = true
1616
default = ["std"]
1717
std = [
1818
"dep:graphene-core",
19+
"dep:graphene-raster-nodes-shaders",
20+
"dep:wgpu-executor",
1921
"dep:dyn-any",
2022
"dep:image",
2123
"dep:ndarray",
@@ -24,8 +26,7 @@ std = [
2426
"dep:fastnoise-lite",
2527
"dep:serde",
2628
"dep:specta",
27-
"dep:graphene-raster-nodes-shaders",
28-
"dep:kurbo"
29+
"dep:kurbo",
2930
]
3031

3132
[dependencies]
@@ -36,6 +37,7 @@ node-macro = { workspace = true }
3637
# Local std dependencies
3738
dyn-any = { workspace = true, optional = true }
3839
graphene-core = { workspace = true, optional = true }
40+
wgpu-executor = { workspace = true, optional = true }
3941
graphene-raster-nodes-shaders = { path = "./shaders", optional = true }
4042

4143
# Workspace dependencies

node-graph/graster-nodes/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ pub mod blending_nodes;
88
pub mod cubic_spline;
99
pub mod fullscreen_vertex;
1010

11+
/// required by shader macro
12+
#[cfg(feature = "std")]
13+
pub use graphene_raster_nodes_shaders::WGSL_SHADER;
14+
1115
#[cfg(feature = "std")]
1216
pub mod curve;
1317
#[cfg(feature = "std")]

node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use quote::{ToTokens, format_ident, quote};
77
use std::borrow::Cow;
88
use syn::parse::{Parse, ParseStream};
99
use syn::punctuated::Punctuated;
10-
use syn::{Type, parse_quote};
10+
use syn::{PatIdent, Type, parse_quote};
1111

1212
#[derive(Debug, Clone)]
1313
pub struct PerPixelAdjust {}
@@ -20,15 +20,14 @@ impl Parse for PerPixelAdjust {
2020

2121
impl ShaderCodegen for PerPixelAdjust {
2222
fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> {
23-
Ok(ShaderTokens {
24-
shader_entry_point: self.codegen_shader_entry_point(parsed)?,
25-
gpu_node: self.codegen_gpu_node(parsed, node_cfg)?,
26-
})
23+
let (shader_entry_point, entry_point_name) = self.codegen_shader_entry_point(parsed)?;
24+
let gpu_node = self.codegen_gpu_node(parsed, node_cfg, &entry_point_name)?;
25+
Ok(ShaderTokens { shader_entry_point, gpu_node })
2726
}
2827
}
2928

3029
impl PerPixelAdjust {
31-
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream> {
30+
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<(TokenStream, TokenStream)> {
3231
let fn_name = &parsed.fn_name;
3332
let gpu_mod = format_ident!("{}_gpu_entry_point", fn_name);
3433
let spirv_image_ty = quote!(Image2d);
@@ -82,7 +81,10 @@ impl PerPixelAdjust {
8281
.collect::<Vec<_>>();
8382
let context = quote!(());
8483

85-
Ok(quote! {
84+
let entry_point_name = format_ident!("ENTRY_POINT_NAME");
85+
let entry_point_sym = quote!(#gpu_mod::#entry_point_name);
86+
87+
let shader_entry_point = quote! {
8688
pub mod #gpu_mod {
8789
use super::*;
8890
use graphene_core_shaders::color::Color;
@@ -91,6 +93,8 @@ impl PerPixelAdjust {
9193
use spirv_std::image::{Image2d, ImageWithMethods};
9294
use spirv_std::image::sample_with::lod;
9395

96+
pub const #entry_point_name: &str = core::concat!(core::module_path!(), "::entry_point");
97+
9498
pub struct Uniform {
9599
#(#uniform_members),*
96100
}
@@ -107,10 +111,11 @@ impl PerPixelAdjust {
107111
*color_out = color.to_vec4();
108112
}
109113
}
110-
})
114+
};
115+
Ok((shader_entry_point, entry_point_sym))
111116
}
112117

113-
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<TokenStream> {
118+
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream, entry_point_name: &TokenStream) -> syn::Result<TokenStream> {
114119
let fn_name = format_ident!("{}_gpu", parsed.fn_name);
115120
let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
116121
let mod_name = fn_name.clone();
@@ -121,7 +126,8 @@ impl PerPixelAdjust {
121126
};
122127
let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>);
123128

124-
let fields = parsed
129+
// adapt fields for gpu node
130+
let mut fields = parsed
125131
.fields
126132
.iter()
127133
.map(|f| match &f.ty {
@@ -136,11 +142,55 @@ impl PerPixelAdjust {
136142
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(f.clone()),
137143
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
138144
})
139-
.collect::<syn::Result<_>>()?;
145+
.collect::<syn::Result<Vec<_>>>()?;
146+
147+
// wgpu_executor field
148+
let wgpu_executor = format_ident!("__wgpu_executor");
149+
fields.push(ParsedField {
150+
pat_ident: PatIdent {
151+
attrs: vec![],
152+
by_ref: None,
153+
mutability: None,
154+
ident: parse_quote!(#wgpu_executor),
155+
subpat: None,
156+
},
157+
name: None,
158+
description: "".to_string(),
159+
widget_override: Default::default(),
160+
ty: ParsedFieldType::Regular(RegularParsedField {
161+
ty: parse_quote!(WgpuExecutor),
162+
exposed: false,
163+
value_source: Default::default(),
164+
number_soft_min: None,
165+
number_soft_max: None,
166+
number_hard_min: None,
167+
number_hard_max: None,
168+
number_mode_range: None,
169+
implementations: Default::default(),
170+
gpu_image: false,
171+
}),
172+
number_display_decimal_places: None,
173+
number_step: None,
174+
unit: None,
175+
});
176+
177+
// exactly one gpu_image field, may be expanded later
178+
let gpu_image_field = {
179+
let mut iter = fields.iter().filter(|f| matches!(f.ty, ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. })));
180+
match (iter.next(), iter.next()) {
181+
(Some(v), None) => Ok(v),
182+
(Some(_), Some(more)) => Err(syn::Error::new_spanned(&more.pat_ident, "No more than one parameter must be annotated with `#[gpu_image]`")),
183+
(None, _) => Err(syn::Error::new_spanned(&parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")),
184+
}?
185+
};
186+
let gpu_image = &gpu_image_field.pat_ident.ident;
140187

141188
let body = quote! {
142189
{
143-
190+
#wgpu_executor.shader_runtime.run_per_pixel_adjust(#gpu_image, &::wgpu_executor::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustInfo {
191+
wgsl_shader: crate::WGSL_SHADER,
192+
fragment_shader_name: super::#entry_point_name,
193+
}).await
144194
}
145195
};
146196

@@ -174,6 +224,7 @@ impl PerPixelAdjust {
174224
#node_cfg
175225
mod #mod_name {
176226
use super::*;
227+
use wgpu_executor::WgpuExecutor;
177228

178229
#gpu_node
179230
}

node-graph/wgpu-executor/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod context;
22
pub mod shader_runtime;
33
pub mod texture_upload;
44

5+
use crate::shader_runtime::ShaderRuntime;
56
use anyhow::Result;
67
pub use context::Context;
78
use dyn_any::StaticType;
@@ -19,6 +20,7 @@ use wgpu::{Origin3d, SurfaceConfiguration, TextureAspect};
1920
pub struct WgpuExecutor {
2021
pub context: Context,
2122
vello_renderer: Mutex<Renderer>,
23+
pub shader_runtime: ShaderRuntime,
2224
}
2325

2426
impl std::fmt::Debug for WgpuExecutor {
@@ -196,6 +198,7 @@ impl WgpuExecutor {
196198
.ok()?;
197199

198200
Some(Self {
201+
shader_runtime: ShaderRuntime::new(&context),
199202
context,
200203
vello_renderer: vello_renderer.into(),
201204
})

node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ impl ShaderRuntime {
3535
}
3636

3737
pub struct PerPixelAdjustInfo<'a> {
38-
shader_wgsl: &'a str,
39-
fragment_shader_name: &'a str,
38+
pub wgsl_shader: &'a str,
39+
pub fragment_shader_name: &'a str,
4040
}
4141

4242
pub struct PerPixelAdjustGraphicsPipeline {
@@ -48,9 +48,12 @@ impl PerPixelAdjustGraphicsPipeline {
4848
pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self {
4949
let device = &context.device;
5050
let name = info.fragment_shader_name.to_owned();
51+
// TODO workaround to naga removing `:`
52+
let fragment_name = name.replace(":", "");
53+
5154
let shader_module = device.create_shader_module(ShaderModuleDescriptor {
5255
label: Some(&format!("PerPixelAdjust {} wgsl shader", name)),
53-
source: ShaderSource::Wgsl(Cow::Borrowed(info.shader_wgsl)),
56+
source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)),
5457
});
5558
let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
5659
label: Some(&format!("PerPixelAdjust {} Pipeline", name)),
@@ -74,7 +77,7 @@ impl PerPixelAdjustGraphicsPipeline {
7477
multisample: Default::default(),
7578
fragment: Some(FragmentState {
7679
module: &shader_module,
77-
entry_point: Some(&name),
80+
entry_point: Some(&fragment_name),
7881
compilation_options: Default::default(),
7982
targets: &[Some(ColorTargetState {
8083
format: TextureFormat::Rgba32Float,

0 commit comments

Comments
 (0)