Skip to content

Commit b47671a

Browse files
WgpuContextBuilder
1 parent 63a018f commit b47671a

File tree

8 files changed

+148
-97
lines changed

8 files changed

+148
-97
lines changed

desktop/src/gpu.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
use graphite_desktop_wrapper::{WgpuContext, WgpuContextBuilder, WgpuFeatures};
2+
3+
pub(super) async fn create_wgpu_context() -> WgpuContext {
4+
let wgpu_context_builder = WgpuContextBuilder::new().with_features(WgpuFeatures::PUSH_CONSTANTS);
5+
6+
// TODO: add a cli flag to list adapters and exit instead of always printing
7+
println!("\nAvailable WGPU adapters:\n{}", wgpu_context_builder.available_adapters_fmt().await);
8+
9+
// TODO: make this configurable via cli flags instead
10+
let wgpu_context = match std::env::var("GRAPHITE_WGPU_ADAPTER").ok().and_then(|s| s.parse().ok()) {
11+
None => wgpu_context_builder.build().await,
12+
Some(adapter_index) => {
13+
tracing::info!("Overriding WGPU adapter selection with adapter index {adapter_index}");
14+
wgpu_context_builder
15+
.build_with_adapter_selection(|adapters: &mut Vec<_>| if adapter_index < adapters.len() { Some(adapters.remove(adapter_index)) } else { None })
16+
.await
17+
}
18+
}
19+
.expect("Failed to create WGPU context");
20+
21+
// TODO: add a cli flag to list adapters and exit instead of always printing
22+
println!("Using WGPU adapter: {:?}", wgpu_context.adapter.get_info());
23+
24+
wgpu_context
25+
}

desktop/src/main.rs

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ use std::process::exit;
22
use tracing_subscriber::EnvFilter;
33
use winit::event_loop::EventLoop;
44

5-
use graphite_desktop_wrapper::WgpuContext;
6-
75
pub(crate) mod consts;
86

97
mod app;
@@ -14,6 +12,8 @@ mod native_window;
1412
mod persist;
1513
mod render;
1614

15+
mod gpu;
16+
1717
use app::App;
1818
use cef::CefHandler;
1919
use event::CreateAppEventSchedulerEventLoopExt;
@@ -31,7 +31,7 @@ fn main() {
3131
return;
3232
}
3333

34-
let wgpu_context = futures::executor::block_on(init_wgpu_context());
34+
let wgpu_context = futures::executor::block_on(gpu::create_wgpu_context());
3535

3636
let event_loop = EventLoop::new().unwrap();
3737
let (app_event_sender, app_event_receiver) = std::sync::mpsc::channel();
@@ -67,52 +67,3 @@ fn main() {
6767

6868
event_loop.run_app(&mut app).unwrap();
6969
}
70-
71-
async fn init_wgpu_context() -> WgpuContext {
72-
// TODO: make this configurable via cli flags instead
73-
let adapter_override = std::env::var("GRAPHITE_WGPU_ADAPTER").ok().map(|s| usize::from_str_radix(&s, 10).ok()).flatten();
74-
75-
let instance_descriptor = wgpu::InstanceDescriptor {
76-
backends: wgpu::Backends::all(),
77-
..Default::default()
78-
};
79-
let instance = wgpu::Instance::new(&instance_descriptor);
80-
81-
let mut adapters = instance.enumerate_adapters(wgpu::Backends::all());
82-
83-
// TODO: add a cli flag to list adapters and exit instead of always printing
84-
let adapters_fmt = adapters
85-
.iter()
86-
.enumerate()
87-
.map(|(i, a)| {
88-
let info = a.get_info();
89-
format!(
90-
"\nAdapter {}:\n Name: {}\n Backend: {:?}\n Driver: {}\n Device ID: {}\n Vendor ID: {}",
91-
i, info.name, info.backend, info.driver, info.device, info.vendor
92-
)
93-
})
94-
.collect::<Vec<_>>()
95-
.join("\n");
96-
println!("\nAvailable wgpu adapters:\n {}\n", adapters_fmt);
97-
98-
let adapter_index = if let Some(index) = adapter_override
99-
&& index < adapters.len()
100-
{
101-
index
102-
} else if cfg!(target_os = "windows") {
103-
match adapters.iter().enumerate().find(|(_, a)| a.get_info().backend == wgpu::Backend::Dx12) {
104-
Some((index, _)) => index,
105-
None => 0,
106-
}
107-
} else {
108-
0 // Same behavior as requests adapter
109-
};
110-
111-
tracing::info!("Using WGPU adapter {adapter_index}");
112-
113-
let adapter = adapters.remove(adapter_index);
114-
115-
WgpuContext::new_with_instance_and_adapter(instance, adapter)
116-
.await
117-
.expect("Failed to create WGPU context with specified adapter")
118-
}

desktop/wrapper/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ use graphite_editor::messages::prelude::{FrontendMessage, Message};
55
// TODO: Remove usage of this reexport in desktop create and remove this line
66
pub use graphene_std::Color;
77

8-
pub use wgpu_executor::Context as WgpuContext;
8+
pub use wgpu_executor::WgpuContext;
9+
pub use wgpu_executor::WgpuContextBuilder;
910
pub use wgpu_executor::WgpuExecutor;
11+
pub use wgpu_executor::WgpuFeatures;
1012

1113
pub mod messages;
1214
use messages::{DesktopFrontendMessage, DesktopWrapperMessage};

node-graph/graph-craft/src/wasm_application_io.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ impl WasmApplicationIo {
143143
io
144144
}
145145
#[cfg(all(not(target_family = "wasm"), feature = "wgpu"))]
146-
pub fn new_with_context(context: wgpu_executor::Context) -> Self {
146+
pub fn new_with_context(context: wgpu_executor::WgpuContext) -> Self {
147147
#[cfg(feature = "wgpu")]
148148
let executor = WgpuExecutor::with_context(context);
149149

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,123 @@
11
use std::sync::Arc;
2-
use wgpu::{Device, Instance, Queue};
2+
use wgpu::{Adapter, Backends, Device, Features, Instance, Queue};
33

44
#[derive(Debug, Clone)]
55
pub struct Context {
66
pub device: Arc<Device>,
77
pub queue: Arc<Queue>,
88
pub instance: Arc<Instance>,
9-
pub adapter: Arc<wgpu::Adapter>,
9+
pub adapter: Arc<Adapter>,
1010
}
1111

1212
impl Context {
1313
pub async fn new() -> Option<Self> {
14-
let instance_descriptor = wgpu::InstanceDescriptor {
15-
backends: wgpu::Backends::all(),
16-
..Default::default()
17-
};
18-
let instance = Instance::new(&instance_descriptor);
14+
ContextBuilder::new().build().await
15+
}
16+
}
1917

20-
let adapter_options = wgpu::RequestAdapterOptions {
21-
power_preference: wgpu::PowerPreference::HighPerformance,
22-
compatible_surface: None,
23-
force_fallback_adapter: false,
18+
#[derive(Default)]
19+
pub struct ContextBuilder {
20+
backends: Backends,
21+
features: Features,
22+
}
23+
impl ContextBuilder {
24+
pub fn new() -> Self {
25+
Self {
26+
backends: Backends::all(),
27+
features: Features::empty(),
28+
}
29+
}
30+
pub fn with_backends(mut self, backends: Backends) -> Self {
31+
self.backends = backends;
32+
self
33+
}
34+
pub fn with_features(mut self, features: Features) -> Self {
35+
self.features = features;
36+
self
37+
}
38+
pub async fn build(self) -> Option<Context> {
39+
self.build_with_adapter_selection_inner(None::<WgpuAdapterSelectorFn>).await
40+
}
41+
pub async fn build_with_adapter_selection<S: WgpuAdapterSelector>(self, select: S) -> Option<Context> {
42+
self.build_with_adapter_selection_inner(Some(select)).await
43+
}
44+
async fn build_with_adapter_selection_inner<S: WgpuAdapterSelector>(self, select: Option<S>) -> Option<Context> {
45+
let instance = self.build_instance();
46+
47+
let selected_adapter = if let Some(select) = select {
48+
self.select_adapter(&instance, select)
49+
} else if cfg!(target_os = "windows") {
50+
self.select_adapter(&instance, |adapters: &mut Vec<Adapter>| {
51+
adapters.iter().position(|a| a.get_info().backend == wgpu::Backend::Dx12).map(|i| adapters.remove(i))
52+
})
53+
} else {
54+
None
2455
};
2556

26-
let adapter = instance.request_adapter(&adapter_options).await.ok()?;
57+
let adapter = if let Some(adapter) = selected_adapter { adapter } else { self.request_adapter(&instance).await? };
2758

28-
Self::new_with_instance_and_adapter(instance, adapter).await
59+
let (device, queue) = self.request_device(&adapter).await?;
60+
Some(Context {
61+
device: Arc::new(device),
62+
queue: Arc::new(queue),
63+
adapter: Arc::new(adapter),
64+
instance: Arc::new(instance),
65+
})
66+
}
67+
pub async fn available_adapters_fmt(&self) -> impl std::fmt::Display {
68+
let instance = self.build_instance();
69+
let adapters = instance.enumerate_adapters(self.backends);
70+
AvailableAdaptersFormatter(adapters)
2971
}
3072

31-
pub async fn new_with_instance_and_adapter(instance: wgpu::Instance, adapter: wgpu::Adapter) -> Option<Self> {
32-
let required_limits = adapter.limits();
33-
let (device, queue) = adapter
73+
fn build_instance(&self) -> Instance {
74+
Instance::new(&wgpu::InstanceDescriptor {
75+
backends: self.backends,
76+
..Default::default()
77+
})
78+
}
79+
async fn request_adapter(&self, instance: &Instance) -> Option<Adapter> {
80+
instance
81+
.request_adapter(&wgpu::RequestAdapterOptions {
82+
power_preference: wgpu::PowerPreference::HighPerformance,
83+
compatible_surface: None,
84+
force_fallback_adapter: false,
85+
})
86+
.await
87+
.ok()
88+
}
89+
fn select_adapter<S: WgpuAdapterSelector>(&self, instance: &Instance, select: S) -> Option<Adapter> {
90+
select(&mut instance.enumerate_adapters(self.backends))
91+
}
92+
async fn request_device(&self, adapter: &Adapter) -> Option<(Device, Queue)> {
93+
adapter
3494
.request_device(&wgpu::DeviceDescriptor {
3595
label: None,
36-
#[cfg(target_family = "wasm")]
37-
required_features: wgpu::Features::empty(),
38-
#[cfg(not(target_family = "wasm"))]
39-
required_features: wgpu::Features::PUSH_CONSTANTS,
40-
required_limits,
96+
required_features: self.features,
97+
required_limits: adapter.limits(),
4198
memory_hints: Default::default(),
4299
trace: wgpu::Trace::Off,
43100
})
44101
.await
45-
.ok()?;
102+
.ok()
103+
}
104+
}
46105

47-
Some(Self {
48-
device: Arc::new(device),
49-
queue: Arc::new(queue),
50-
adapter: Arc::new(adapter),
51-
instance: Arc::new(instance),
52-
})
106+
pub trait WgpuAdapterSelector: FnOnce(&mut Vec<Adapter>) -> Option<Adapter> {}
107+
impl<F> WgpuAdapterSelector for F where F: FnOnce(&mut Vec<Adapter>) -> Option<Adapter> {}
108+
type WgpuAdapterSelectorFn = fn(&mut Vec<Adapter>) -> Option<Adapter>;
109+
110+
struct AvailableAdaptersFormatter(Vec<Adapter>);
111+
impl std::fmt::Display for AvailableAdaptersFormatter {
112+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113+
for (i, adapter) in self.0.iter().enumerate() {
114+
let info = adapter.get_info();
115+
writeln!(
116+
f,
117+
"[{}] {:?} {:?} (Name: {}, Driver: {}, Device: {})",
118+
i, info.backend, info.device_type, info.name, info.driver, info.device,
119+
)?;
120+
}
121+
Ok(())
53122
}
54123
}

node-graph/wgpu-executor/src/lib.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ pub mod texture_upload;
44

55
use crate::shader_runtime::ShaderRuntime;
66
use anyhow::Result;
7-
pub use context::Context;
87
use dyn_any::StaticType;
98
use futures::lock::Mutex;
109
use glam::UVec2;
@@ -16,9 +15,14 @@ use vello::{AaConfig, AaSupport, RenderParams, Renderer, RendererOptions, Scene}
1615
use wgpu::util::TextureBlitter;
1716
use wgpu::{Origin3d, SurfaceConfiguration, TextureAspect};
1817

18+
pub use context::Context as WgpuContext;
19+
pub use context::ContextBuilder as WgpuContextBuilder;
20+
pub use wgpu::Backends as WgpuBackends;
21+
pub use wgpu::Features as WgpuFeatures;
22+
1923
#[derive(dyn_any::DynAny)]
2024
pub struct WgpuExecutor {
21-
pub context: Context,
25+
pub context: WgpuContext,
2226
vello_renderer: Mutex<Renderer>,
2327
pub shader_runtime: ShaderRuntime,
2428
}
@@ -182,10 +186,10 @@ impl WgpuExecutor {
182186

183187
impl WgpuExecutor {
184188
pub async fn new() -> Option<Self> {
185-
Self::with_context(Context::new().await?)
189+
Self::with_context(WgpuContext::new().await?)
186190
}
187191

188-
pub fn with_context(context: Context) -> Option<Self> {
192+
pub fn with_context(context: WgpuContext) -> Option<Self> {
189193
let vello_renderer = Renderer::new(
190194
&context.device,
191195
RendererOptions {

node-graph/wgpu-executor/src/shader_runtime/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
use crate::Context;
1+
use crate::WgpuContext;
22
use crate::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustShaderRuntime;
33

44
pub mod per_pixel_adjust_runtime;
55

66
pub const FULLSCREEN_VERTEX_SHADER_NAME: &str = "fullscreen_vertexfullscreen_vertex";
77

88
pub struct ShaderRuntime {
9-
context: Context,
9+
context: WgpuContext,
1010
per_pixel_adjust: PerPixelAdjustShaderRuntime,
1111
}
1212

1313
impl ShaderRuntime {
14-
pub fn new(context: &Context) -> Self {
14+
pub fn new(context: &WgpuContext) -> Self {
1515
Self {
1616
context: context.clone(),
1717
per_pixel_adjust: PerPixelAdjustShaderRuntime::new(),

node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::Context;
1+
use crate::WgpuContext;
22
use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime};
33
use futures::lock::Mutex;
44
use graphene_core::raster_types::{GPU, Raster};
@@ -31,7 +31,7 @@ impl ShaderRuntime {
3131
let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await;
3232
let pipeline = cache
3333
.entry(shaders.fragment_shader_name.to_owned())
34-
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders));
34+
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, shaders));
3535

3636
let arg_buffer = args.map(|args| {
3737
let device = &self.context.device;
@@ -58,7 +58,7 @@ pub struct PerPixelAdjustGraphicsPipeline {
5858
}
5959

6060
impl PerPixelAdjustGraphicsPipeline {
61-
pub fn new(context: &Context, info: &Shaders) -> Self {
61+
pub fn new(context: &WgpuContext, info: &Shaders) -> Self {
6262
let device = &context.device;
6363
let name = info.fragment_shader_name.to_owned();
6464

@@ -67,7 +67,7 @@ impl PerPixelAdjustGraphicsPipeline {
6767
// TODO workaround to naga removing `:`
6868
let fragment_name = fragment_name.replace(":", "");
6969
let shader_module = device.create_shader_module(ShaderModuleDescriptor {
70-
label: Some(&format!("PerPixelAdjust {} wgsl shader", name)),
70+
label: Some(&format!("PerPixelAdjust {name} wgsl shader")),
7171
source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)),
7272
});
7373

@@ -107,16 +107,16 @@ impl PerPixelAdjustGraphicsPipeline {
107107
}]
108108
};
109109
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
110-
label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)),
110+
label: Some(&format!("PerPixelAdjust {name} PipelineLayout")),
111111
bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor {
112-
label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)),
112+
label: Some(&format!("PerPixelAdjust {name} BindGroupLayout 0")),
113113
entries,
114114
})],
115115
push_constant_ranges: &[],
116116
});
117117

118118
let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
119-
label: Some(&format!("PerPixelAdjust {} Pipeline", name)),
119+
label: Some(&format!("PerPixelAdjust {name} Pipeline")),
120120
layout: Some(&pipeline_layout),
121121
vertex: VertexState {
122122
module: &shader_module,
@@ -155,7 +155,7 @@ impl PerPixelAdjustGraphicsPipeline {
155155
}
156156
}
157157

158-
pub fn dispatch(&self, context: &Context, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
158+
pub fn dispatch(&self, context: &WgpuContext, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
159159
assert_eq!(self.has_uniform, arg_buffer.is_some());
160160
let device = &context.device;
161161
let name = self.name.as_str();

0 commit comments

Comments
 (0)