Skip to content

Commit 63a018f

Browse files
rudimentary custom wgpu adapter selection
1 parent f4a0f27 commit 63a018f

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

desktop/src/main.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ fn main() {
3131
return;
3232
}
3333

34-
let wgpu_context = futures::executor::block_on(WgpuContext::new()).unwrap();
34+
let wgpu_context = futures::executor::block_on(init_wgpu_context());
3535

3636
let event_loop = EventLoop::new().unwrap();
3737
let (app_event_sender, app_event_receiver) = std::sync::mpsc::channel();
@@ -67,3 +67,52 @@ fn main() {
6767

6868
event_loop.run_app(&mut app).unwrap();
6969
}
70+
71+
async fn init_wgpu_context() -> WgpuContext {
72+
// TODO: make this configurable via cli flags instead
73+
let adapter_override = std::env::var("GRAPHITE_WGPU_ADAPTER").ok().map(|s| usize::from_str_radix(&s, 10).ok()).flatten();
74+
75+
let instance_descriptor = wgpu::InstanceDescriptor {
76+
backends: wgpu::Backends::all(),
77+
..Default::default()
78+
};
79+
let instance = wgpu::Instance::new(&instance_descriptor);
80+
81+
let mut adapters = instance.enumerate_adapters(wgpu::Backends::all());
82+
83+
// TODO: add a cli flag to list adapters and exit instead of always printing
84+
let adapters_fmt = adapters
85+
.iter()
86+
.enumerate()
87+
.map(|(i, a)| {
88+
let info = a.get_info();
89+
format!(
90+
"\nAdapter {}:\n Name: {}\n Backend: {:?}\n Driver: {}\n Device ID: {}\n Vendor ID: {}",
91+
i, info.name, info.backend, info.driver, info.device, info.vendor
92+
)
93+
})
94+
.collect::<Vec<_>>()
95+
.join("\n");
96+
println!("\nAvailable wgpu adapters:\n {}\n", adapters_fmt);
97+
98+
let adapter_index = if let Some(index) = adapter_override
99+
&& index < adapters.len()
100+
{
101+
index
102+
} else if cfg!(target_os = "windows") {
103+
match adapters.iter().enumerate().find(|(_, a)| a.get_info().backend == wgpu::Backend::Dx12) {
104+
Some((index, _)) => index,
105+
None => 0,
106+
}
107+
} else {
108+
0 // Same behavior as requests adapter
109+
};
110+
111+
tracing::info!("Using WGPU adapter {adapter_index}");
112+
113+
let adapter = adapters.remove(adapter_index);
114+
115+
WgpuContext::new_with_instance_and_adapter(instance, adapter)
116+
.await
117+
.expect("Failed to create WGPU context with specified adapter")
118+
}

node-graph/wgpu-executor/src/context.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ pub struct Context {
1111

1212
impl Context {
1313
pub async fn new() -> Option<Self> {
14-
// Instantiates instance of WebGPU
1514
let instance_descriptor = wgpu::InstanceDescriptor {
1615
backends: wgpu::Backends::all(),
1716
..Default::default()
@@ -23,12 +22,14 @@ impl Context {
2322
compatible_surface: None,
2423
force_fallback_adapter: false,
2524
};
26-
// `request_adapter` instantiates the general connection to the GPU
25+
2726
let adapter = instance.request_adapter(&adapter_options).await.ok()?;
2827

28+
Self::new_with_instance_and_adapter(instance, adapter).await
29+
}
30+
31+
pub async fn new_with_instance_and_adapter(instance: wgpu::Instance, adapter: wgpu::Adapter) -> Option<Self> {
2932
let required_limits = adapter.limits();
30-
// `request_device` instantiates the feature specific connection to the GPU, defining some parameters,
31-
// `features` being the available features.
3233
let (device, queue) = adapter
3334
.request_device(&wgpu::DeviceDescriptor {
3435
label: None,

0 commit comments

Comments
 (0)