diff --git a/Cargo.lock b/Cargo.lock index 0c38e924e6..3bd3c5cb93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8991,7 +8991,6 @@ dependencies = [ "futures", "guestmem", "guid", - "headervec", "mesh", "pal", "pal_async", diff --git a/vm/devices/vmbus/vmbus_proxy/Cargo.toml b/vm/devices/vmbus/vmbus_proxy/Cargo.toml index 5495a98bbb..d40ed03fba 100644 --- a/vm/devices/vmbus/vmbus_proxy/Cargo.toml +++ b/vm/devices/vmbus/vmbus_proxy/Cargo.toml @@ -9,7 +9,6 @@ rust-version.workspace = true [target.'cfg(windows)'.dependencies] guestmem.workspace = true guid.workspace = true -headervec.workspace = true vmbus_core.workspace = true mesh.workspace = true pal.workspace = true diff --git a/vm/devices/vmbus/vmbus_proxy/src/lib.rs b/vm/devices/vmbus/vmbus_proxy/src/lib.rs index ef23a72c28..ca8eabe91e 100644 --- a/vm/devices/vmbus/vmbus_proxy/src/lib.rs +++ b/vm/devices/vmbus/vmbus_proxy/src/lib.rs @@ -10,7 +10,6 @@ use futures::poll; use guestmem::GuestMemory; use guid::Guid; -use headervec::HeaderVec; use mesh::CancelContext; use mesh::MeshPayload; use pal::windows::ObjectAttributes; @@ -30,7 +29,6 @@ use vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS; use widestring::Utf16Str; use widestring::utf16str; use windows::Wdk::Storage::FileSystem::NtOpenFile; -use windows::Win32::Foundation::ERROR_MORE_DATA; use windows::Win32::Foundation::ERROR_OPERATION_ABORTED; use windows::Win32::Foundation::HANDLE; use windows::Win32::Foundation::NTSTATUS; @@ -457,60 +455,6 @@ impl VmbusProxy { Ok(()) } - pub fn get_numa_node_map(&self) -> Result> { - unsafe { - // This is a synchronous operation, so don't use the async IO infrastructure. - let mut output = - HeaderVec::::with_capacity( - zeroed::(), - 8, - ); - let mut bytes = 0; - if let Err(e) = DeviceIoControl( - HANDLE(self.file.get().as_raw_handle()), - proxyioctl::IOCTL_VMBUS_PROXY_GET_NUMA_MAP, - None, - 0, - Some(output.as_mut_ptr().cast()), - output.total_byte_capacity() as u32, - Some(&mut bytes), - None, - ) { - if e.code() == ERROR_MORE_DATA.into() { - // The buffer was too small, resize and try again. The proxy returns the required buffer size - // in VpCount on overflow, so use that. - assert!( - bytes as usize >= size_of::() - ); - - let required_len = output.head.VpCount as usize; - output.reserve_tail(required_len - output.tail_capacity()); - DeviceIoControl( - HANDLE(self.file.get().as_raw_handle()), - proxyioctl::IOCTL_VMBUS_PROXY_GET_NUMA_MAP, - None, - 0, - Some(output.as_mut_ptr().cast()), - output.total_byte_capacity() as u32, - Some(&mut bytes), - None, - )?; - } else { - return Err(e); - } - } - - assert!( - bytes as usize - >= size_of::() - + output.head.VpCount as usize - ); - - output.set_tail_len(output.head.VpCount as usize); - Ok(output.tail.to_vec()) - } - } - /// Adds GPADL ioctl data to a buffer. fn add_gpadl( buffer: &mut Vec, diff --git a/vm/devices/vmbus/vmbus_proxy/src/proxyioctl.rs b/vm/devices/vmbus/vmbus_proxy/src/proxyioctl.rs index 1edf583b7d..a1a7ece08f 100644 --- a/vm/devices/vmbus/vmbus_proxy/src/proxyioctl.rs +++ b/vm/devices/vmbus/vmbus_proxy/src/proxyioctl.rs @@ -49,7 +49,6 @@ pub const IOCTL_VMBUS_PROXY_TL_CONNECT_REQUEST: u32 = VMBUS_PROXY_IOCTL(0xc); pub const IOCTL_VMBUS_PROXY_RESTORE_CHANNEL: u32 = VMBUS_PROXY_IOCTL(0xd); pub const IOCTL_VMBUS_PROXY_REVOKE_UNCLAIMED_CHANNELS: u32 = VMBUS_PROXY_IOCTL(0xe); pub const IOCTL_VMBUS_PROXY_RESTORE_SET_INTERRUPT: u32 = VMBUS_PROXY_IOCTL(0xf); -pub const IOCTL_VMBUS_PROXY_GET_NUMA_MAP: u32 = VMBUS_PROXY_IOCTL(0x14); #[repr(C)] #[derive(Copy, Clone, zerocopy::IntoBytes)] @@ -207,9 +206,3 @@ pub struct VMBUS_PROXY_TL_CONNECT_REQUEST_INPUT { pub Vtl: u8, pub Padding: [u8; 3], } - -#[repr(C)] -#[derive(Copy, Clone)] -pub struct VMBUS_PROXY_GET_NUMA_MAP_OUTPUT { - pub VpCount: u32, -} diff --git a/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs b/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs index c973e88e9e..e7e8eb8fc9 100644 --- a/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs +++ b/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs @@ -64,7 +64,6 @@ use vmbus_proxy::ProxyAction; use vmbus_proxy::VmbusProxy; use vmbus_proxy::vmbusioctl::VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS; use vmcore::interrupt::Interrupt; -use windows::Win32::Foundation::ERROR_INVALID_FUNCTION; use windows::Win32::Foundation::ERROR_NOT_FOUND; use windows::Win32::Foundation::ERROR_OPERATION_ABORTED; use zerocopy::IntoBytes; @@ -111,6 +110,7 @@ pub struct ProxyIntegrationBuilder<'a, T: SpawnDriver + Clone> { vtl2_server: Option, mem: Option<&'a GuestMemory>, require_flush_before_start: bool, + vp_to_physical_node_map: Vec, } impl<'a, T: SpawnDriver + Clone> ProxyIntegrationBuilder<'a, T> { @@ -132,6 +132,13 @@ impl<'a, T: SpawnDriver + Clone> ProxyIntegrationBuilder<'a, T> { self } + /// Adds a NUMA node map to be passed to the proxy driver. This map is of the format + /// VP -> Physical NUMA Node. For example, `map[0]` is the physical NUMA node for VP 0. + pub fn vp_to_physical_node_map(mut self, map: Vec) -> Self { + self.vp_to_physical_node_map = map; + self + } + /// Builds and starts the `ProxyIntegration`. pub async fn build(self) -> io::Result { let (cancel_ctx, cancel) = CancelContext::new().with_cancel(); @@ -151,6 +158,7 @@ impl<'a, T: SpawnDriver + Clone> ProxyIntegrationBuilder<'a, T> { self.vtl2_server, flush_recv, self.require_flush_before_start, + self.vp_to_physical_node_map, ), ); @@ -184,6 +192,7 @@ impl ProxyIntegration { vtl2_server: None, mem: None, require_flush_before_start: false, + vp_to_physical_node_map: vec![], } } @@ -242,15 +251,11 @@ impl SavedStatePair { } } -struct VpToPhysicalNodeMap(Vec); +struct VpToPhysicalNodeMap(Vec); impl VpToPhysicalNodeMap { - fn new(nodes: Vec) -> Self { - Self(nodes) - } - fn get_numa_node(&self, vp_index: u32) -> u16 { - self.0.get(vp_index as usize).copied().unwrap_or(0).into() + self.0.get(vp_index as usize).copied().unwrap_or(0) } } @@ -263,7 +268,7 @@ struct ProxyTask { hvsock_response_send: Option>, vtl2_hvsock_response_send: Option>, saved_states: Arc>, - numa_node_map: VpToPhysicalNodeMap, + vp_to_physical_node_map: VpToPhysicalNodeMap, } impl ProxyTask { @@ -273,7 +278,7 @@ impl ProxyTask { hvsock_response_send: Option>, vtl2_hvsock_response_send: Option>, proxy: Arc, - numa_node_map: VpToPhysicalNodeMap, + vp_to_physical_node_map: VpToPhysicalNodeMap, ) -> Self { Self { channels: Arc::new(Mutex::new(HashMap::new())), @@ -287,7 +292,7 @@ impl ProxyTask { saved_state: None, vtl2_saved_state: None, })), - numa_node_map, + vp_to_physical_node_map, } } @@ -343,7 +348,7 @@ impl ProxyTask { RingBufferGpadlHandle: open_request.open_data.ring_gpadl_id.0, DownstreamRingBufferPageOffset: open_request.open_data.ring_offset, NodeNumber: self - .numa_node_map + .vp_to_physical_node_map .get_numa_node(open_request.open_data.target_vp), Padding: 0, }, @@ -973,7 +978,9 @@ impl ProxyTask { .map(|request| VMBUS_SERVER_OPEN_CHANNEL_OUTPUT_PARAMETERS { RingBufferGpadlHandle: request.ring_buffer_gpadl_id.0, DownstreamRingBufferPageOffset: request.downstream_ring_buffer_page_offset, - NodeNumber: self.numa_node_map.get_numa_node(request.target_vp), + NodeNumber: self + .vp_to_physical_node_map + .get_numa_node(request.target_vp), Padding: 0, }); @@ -1116,6 +1123,7 @@ async fn proxy_thread( vtl2_server: Option, flush_recv: mesh::Receiver>, await_flush: bool, + vp_to_physical_node_map: Vec, ) { // Separate the hvsocket relay channels. let (hvsock_request_recv, hvsock_response_send) = server @@ -1143,24 +1151,13 @@ async fn proxy_thread( let (send, recv) = mesh::channel(); let proxy = Arc::new(proxy); - let numa_node_map = VpToPhysicalNodeMap::new(proxy.get_numa_node_map().unwrap_or_else(|err| { - if err.code() == ERROR_INVALID_FUNCTION.into() { - tracing::info!("proxy does not support NUMA node map ioctl"); - } else { - tracing::warn!( - error = &err as &dyn std::error::Error, - "failed to get NUMA node map from proxy" - ); - } - Vec::new() - })); let task = Arc::new(ProxyTask::new( server.control, vtl2_control, hvsock_response_send, vtl2_hvsock_response_send, Arc::clone(&proxy), - numa_node_map, + VpToPhysicalNodeMap(vp_to_physical_node_map), )); let offers = task.run_proxy_actions(send, flush_recv, await_flush); let requests = task.run_server_requests(