diff --git a/desktop/src/gpu_context.rs b/desktop/src/gpu_context.rs new file mode 100644 index 0000000000..787d6cda2f --- /dev/null +++ b/desktop/src/gpu_context.rs @@ -0,0 +1,23 @@ +use graphite_desktop_wrapper::{WgpuContext, WgpuContextBuilder, WgpuFeatures}; + +pub(super) async fn create_wgpu_context() -> WgpuContext { + let wgpu_context_builder = WgpuContextBuilder::new().with_features(WgpuFeatures::PUSH_CONSTANTS); + + // TODO: add a cli flag to list adapters and exit instead of always printing + println!("\nAvailable WGPU adapters:\n{}", wgpu_context_builder.available_adapters_fmt().await); + + // TODO: make this configurable via cli flags instead + let wgpu_context = match std::env::var("GRAPHITE_WGPU_ADAPTER").ok().and_then(|s| s.parse().ok()) { + None => wgpu_context_builder.build().await, + Some(adapter_index) => { + tracing::info!("Overriding WGPU adapter selection with adapter index {adapter_index}"); + wgpu_context_builder.build_with_adapter_selection(|_| Some(adapter_index)).await + } + } + .expect("Failed to create WGPU context"); + + // TODO: add a cli flag to list adapters and exit instead of always printing + println!("Using WGPU adapter: {:?}", wgpu_context.adapter.get_info()); + + wgpu_context +} diff --git a/desktop/src/main.rs b/desktop/src/main.rs index fbe8c362b0..4c15a05d78 100644 --- a/desktop/src/main.rs +++ b/desktop/src/main.rs @@ -2,8 +2,6 @@ use std::process::exit; use tracing_subscriber::EnvFilter; use winit::event_loop::EventLoop; -use graphite_desktop_wrapper::WgpuContext; - pub(crate) mod consts; mod app; @@ -14,6 +12,8 @@ mod native_window; mod persist; mod render; +mod gpu_context; + use app::App; use cef::CefHandler; use event::CreateAppEventSchedulerEventLoopExt; @@ -31,7 +31,7 @@ fn main() { return; } - let wgpu_context = futures::executor::block_on(WgpuContext::new()).unwrap(); + let wgpu_context = futures::executor::block_on(gpu_context::create_wgpu_context()); let event_loop = EventLoop::new().unwrap(); let (app_event_sender, app_event_receiver) = std::sync::mpsc::channel(); diff --git a/desktop/wrapper/src/lib.rs b/desktop/wrapper/src/lib.rs index 0410f153bb..3050c96d2a 100644 --- a/desktop/wrapper/src/lib.rs +++ b/desktop/wrapper/src/lib.rs @@ -5,8 +5,10 @@ use graphite_editor::messages::prelude::{FrontendMessage, Message}; // TODO: Remove usage of this reexport in desktop create and remove this line pub use graphene_std::Color; -pub use wgpu_executor::Context as WgpuContext; +pub use wgpu_executor::WgpuContext; +pub use wgpu_executor::WgpuContextBuilder; pub use wgpu_executor::WgpuExecutor; +pub use wgpu_executor::WgpuFeatures; pub mod messages; use messages::{DesktopFrontendMessage, DesktopWrapperMessage}; diff --git a/node-graph/graph-craft/src/wasm_application_io.rs b/node-graph/graph-craft/src/wasm_application_io.rs index abca58d88b..9a956d52b6 100644 --- a/node-graph/graph-craft/src/wasm_application_io.rs +++ b/node-graph/graph-craft/src/wasm_application_io.rs @@ -143,7 +143,7 @@ impl WasmApplicationIo { io } #[cfg(all(not(target_family = "wasm"), feature = "wgpu"))] - pub fn new_with_context(context: wgpu_executor::Context) -> Self { + pub fn new_with_context(context: wgpu_executor::WgpuContext) -> Self { #[cfg(feature = "wgpu")] let executor = WgpuExecutor::with_context(context); diff --git a/node-graph/wgpu-executor/src/context.rs b/node-graph/wgpu-executor/src/context.rs index 06da16e0ac..ff95686d05 100644 --- a/node-graph/wgpu-executor/src/context.rs +++ b/node-graph/wgpu-executor/src/context.rs @@ -1,53 +1,151 @@ use std::sync::Arc; -use wgpu::{Device, Instance, Queue}; +use wgpu::{Adapter, Backends, Device, Features, Instance, Queue}; #[derive(Debug, Clone)] pub struct Context { pub device: Arc, pub queue: Arc, pub instance: Arc, - pub adapter: Arc, + pub adapter: Arc, } impl Context { pub async fn new() -> Option { - // Instantiates instance of WebGPU - let instance_descriptor = wgpu::InstanceDescriptor { - backends: wgpu::Backends::all(), - ..Default::default() - }; - let instance = Instance::new(&instance_descriptor); + ContextBuilder::new().build().await + } +} - let adapter_options = wgpu::RequestAdapterOptions { +#[derive(Default)] +pub struct ContextBuilder { + backends: Backends, + features: Features, +} +impl ContextBuilder { + pub fn new() -> Self { + Self { + backends: Backends::all(), + features: Features::empty(), + } + } + pub fn with_backends(mut self, backends: Backends) -> Self { + self.backends = backends; + self + } + pub fn with_features(mut self, features: Features) -> Self { + self.features = features; + self + } +} +#[cfg(not(target_family = "wasm"))] +impl ContextBuilder { + pub async fn build(self) -> Option { + self.build_with_adapter_selection_inner(None:: Option>).await + } + pub async fn build_with_adapter_selection(self, select: S) -> Option + where + S: Fn(&[Adapter]) -> Option, + { + self.build_with_adapter_selection_inner(Some(select)).await + } + pub async fn available_adapters_fmt(&self) -> impl std::fmt::Display { + let instance = self.build_instance(); + fmt::AvailableAdaptersFormatter(instance.enumerate_adapters(self.backends)) + } +} +#[cfg(target_family = "wasm")] +impl ContextBuilder { + pub async fn build(self) -> Option { + let instance = self.build_instance(); + let adapter = self.request_adapter(&instance).await?; + let (device, queue) = self.request_device(&adapter).await?; + Some(Context { + device: Arc::new(device), + queue: Arc::new(queue), + adapter: Arc::new(adapter), + instance: Arc::new(instance), + }) + } +} +impl ContextBuilder { + fn build_instance(&self) -> Instance { + Instance::new(&wgpu::InstanceDescriptor { + backends: self.backends, + ..Default::default() + }) + } + async fn request_adapter(&self, instance: &Instance) -> Option { + let request_adapter_options = wgpu::RequestAdapterOptions { power_preference: wgpu::PowerPreference::HighPerformance, compatible_surface: None, force_fallback_adapter: false, }; - // `request_adapter` instantiates the general connection to the GPU - let adapter = instance.request_adapter(&adapter_options).await.ok()?; + instance.request_adapter(&request_adapter_options).await.ok() + } + async fn request_device(&self, adapter: &Adapter) -> Option<(Device, Queue)> { + let device_descriptor = wgpu::DeviceDescriptor { + label: None, + required_features: self.features, + required_limits: adapter.limits(), + memory_hints: Default::default(), + trace: wgpu::Trace::Off, + }; + adapter.request_device(&device_descriptor).await.ok() + } +} +#[cfg(not(target_family = "wasm"))] +impl ContextBuilder { + async fn build_with_adapter_selection_inner(self, select: Option) -> Option + where + S: Fn(&[Adapter]) -> Option, + { + let instance = self.build_instance(); + + let selected_adapter = if let Some(select) = select { + self.select_adapter(&instance, select) + } else if cfg!(target_os = "windows") { + self.select_adapter(&instance, |adapters: &[Adapter]| adapters.iter().position(|a| a.get_info().backend == wgpu::Backend::Dx12)) + } else { + None + }; - let required_limits = adapter.limits(); - // `request_device` instantiates the feature specific connection to the GPU, defining some parameters, - // `features` being the available features. - let (device, queue) = adapter - .request_device(&wgpu::DeviceDescriptor { - label: None, - #[cfg(target_family = "wasm")] - required_features: wgpu::Features::empty(), - #[cfg(not(target_family = "wasm"))] - required_features: wgpu::Features::PUSH_CONSTANTS, - required_limits, - memory_hints: Default::default(), - trace: wgpu::Trace::Off, - }) - .await - .ok()?; + let adapter = if let Some(adapter) = selected_adapter { adapter } else { self.request_adapter(&instance).await? }; - Some(Self { + let (device, queue) = self.request_device(&adapter).await?; + Some(Context { device: Arc::new(device), queue: Arc::new(queue), adapter: Arc::new(adapter), instance: Arc::new(instance), }) } + fn select_adapter(&self, instance: &Instance, select: S) -> Option + where + S: Fn(&[Adapter]) -> Option, + { + let mut adapters = instance.enumerate_adapters(self.backends); + let selected_index = select(&adapters)?; + if selected_index >= adapters.len() { + return None; + } + Some(adapters.remove(selected_index)) + } +} +#[cfg(not(target_family = "wasm"))] +mod fmt { + use super::*; + + pub(super) struct AvailableAdaptersFormatter(pub(super) Vec); + impl std::fmt::Display for AvailableAdaptersFormatter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (i, adapter) in self.0.iter().enumerate() { + let info = adapter.get_info(); + writeln!( + f, + "[{}] {:?} {:?} (Name: {}, Driver: {}, Device: {})", + i, info.backend, info.device_type, info.name, info.driver, info.device, + )?; + } + Ok(()) + } + } } diff --git a/node-graph/wgpu-executor/src/lib.rs b/node-graph/wgpu-executor/src/lib.rs index 4a1e7f3e36..cd8e3d4263 100644 --- a/node-graph/wgpu-executor/src/lib.rs +++ b/node-graph/wgpu-executor/src/lib.rs @@ -4,7 +4,6 @@ pub mod texture_upload; use crate::shader_runtime::ShaderRuntime; use anyhow::Result; -pub use context::Context; use dyn_any::StaticType; use futures::lock::Mutex; use glam::UVec2; @@ -16,9 +15,14 @@ use vello::{AaConfig, AaSupport, RenderParams, Renderer, RendererOptions, Scene} use wgpu::util::TextureBlitter; use wgpu::{Origin3d, SurfaceConfiguration, TextureAspect}; +pub use context::Context as WgpuContext; +pub use context::ContextBuilder as WgpuContextBuilder; +pub use wgpu::Backends as WgpuBackends; +pub use wgpu::Features as WgpuFeatures; + #[derive(dyn_any::DynAny)] pub struct WgpuExecutor { - pub context: Context, + pub context: WgpuContext, vello_renderer: Mutex, pub shader_runtime: ShaderRuntime, } @@ -182,10 +186,10 @@ impl WgpuExecutor { impl WgpuExecutor { pub async fn new() -> Option { - Self::with_context(Context::new().await?) + Self::with_context(WgpuContext::new().await?) } - pub fn with_context(context: Context) -> Option { + pub fn with_context(context: WgpuContext) -> Option { let vello_renderer = Renderer::new( &context.device, RendererOptions { diff --git a/node-graph/wgpu-executor/src/shader_runtime/mod.rs b/node-graph/wgpu-executor/src/shader_runtime/mod.rs index e7e0df8d94..7260fa6e56 100644 --- a/node-graph/wgpu-executor/src/shader_runtime/mod.rs +++ b/node-graph/wgpu-executor/src/shader_runtime/mod.rs @@ -1,4 +1,4 @@ -use crate::Context; +use crate::WgpuContext; use crate::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustShaderRuntime; pub mod per_pixel_adjust_runtime; @@ -6,12 +6,12 @@ pub mod per_pixel_adjust_runtime; pub const FULLSCREEN_VERTEX_SHADER_NAME: &str = "fullscreen_vertexfullscreen_vertex"; pub struct ShaderRuntime { - context: Context, + context: WgpuContext, per_pixel_adjust: PerPixelAdjustShaderRuntime, } impl ShaderRuntime { - pub fn new(context: &Context) -> Self { + pub fn new(context: &WgpuContext) -> Self { Self { context: context.clone(), per_pixel_adjust: PerPixelAdjustShaderRuntime::new(), diff --git a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs index aa567a5b93..763b1e1e30 100644 --- a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs +++ b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs @@ -1,4 +1,4 @@ -use crate::Context; +use crate::WgpuContext; use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime}; use futures::lock::Mutex; use graphene_core::raster_types::{GPU, Raster}; @@ -31,7 +31,7 @@ impl ShaderRuntime { let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await; let pipeline = cache .entry(shaders.fragment_shader_name.to_owned()) - .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders)); + .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, shaders)); let arg_buffer = args.map(|args| { let device = &self.context.device; @@ -58,7 +58,7 @@ pub struct PerPixelAdjustGraphicsPipeline { } impl PerPixelAdjustGraphicsPipeline { - pub fn new(context: &Context, info: &Shaders) -> Self { + pub fn new(context: &WgpuContext, info: &Shaders) -> Self { let device = &context.device; let name = info.fragment_shader_name.to_owned(); @@ -67,7 +67,7 @@ impl PerPixelAdjustGraphicsPipeline { // TODO workaround to naga removing `:` let fragment_name = fragment_name.replace(":", ""); let shader_module = device.create_shader_module(ShaderModuleDescriptor { - label: Some(&format!("PerPixelAdjust {} wgsl shader", name)), + label: Some(&format!("PerPixelAdjust {name} wgsl shader")), source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)), }); @@ -107,16 +107,16 @@ impl PerPixelAdjustGraphicsPipeline { }] }; let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor { - label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)), + label: Some(&format!("PerPixelAdjust {name} PipelineLayout")), bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor { - label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)), + label: Some(&format!("PerPixelAdjust {name} BindGroupLayout 0")), entries, })], push_constant_ranges: &[], }); let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor { - label: Some(&format!("PerPixelAdjust {} Pipeline", name)), + label: Some(&format!("PerPixelAdjust {name} Pipeline")), layout: Some(&pipeline_layout), vertex: VertexState { module: &shader_module, @@ -155,7 +155,7 @@ impl PerPixelAdjustGraphicsPipeline { } } - pub fn dispatch(&self, context: &Context, textures: Table>, arg_buffer: Option) -> Table> { + pub fn dispatch(&self, context: &WgpuContext, textures: Table>, arg_buffer: Option) -> Table> { assert_eq!(self.has_uniform, arg_buffer.is_some()); let device = &context.device; let name = self.name.as_str();