diff --git a/.gitignore b/.gitignore index 72a19a8..2bc5d85 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ Cargo.lock .idea vkd3d-proton.cache* +*.swp diff --git a/src/hooks/dx12.rs b/src/hooks/dx12.rs index 7332e3c..c1cfe1c 100644 --- a/src/hooks/dx12.rs +++ b/src/hooks/dx12.rs @@ -2,6 +2,7 @@ use std::ffi::c_void; use std::mem; +use std::sync::atomic::Ordering; use std::sync::OnceLock; use imgui::Context; @@ -28,7 +29,7 @@ use windows::Win32::Graphics::Dxgi::{ use super::DummyHwnd; use crate::mh::MhHook; use crate::renderer::{D3D12RenderEngine, Pipeline}; -use crate::{util, Hooks, ImguiRenderLoop}; +use crate::{perform_eject, util, Hooks, ImguiRenderLoop, EJECT_REQUESTED, HOOK_EJECTION_BARRIER}; type DXGISwapChainPresentType = unsafe extern "system" fn(this: IDXGISwapChain3, sync_interval: u32, flags: u32) -> HRESULT; @@ -194,6 +195,7 @@ unsafe extern "system" fn dxgi_swap_chain_present_impl( sync_interval: u32, flags: u32, ) -> HRESULT { + let _hook_ejection_guard = HOOK_EJECTION_BARRIER.acquire_ejection_guard(); { INITIALIZATION_CONTEXT.lock().insert_swap_chain(&swap_chain); } @@ -207,7 +209,13 @@ unsafe extern "system" fn dxgi_swap_chain_present_impl( } trace!("Call IDXGISwapChain::Present trampoline"); - dxgi_swap_chain_present(swap_chain, sync_interval, flags) + let result = dxgi_swap_chain_present(swap_chain, sync_interval, flags); + + if EJECT_REQUESTED.load(Ordering::SeqCst) { + perform_eject(); + } + + result } unsafe extern "system" fn dxgi_swap_chain_resize_buffers_impl( @@ -218,6 +226,7 @@ unsafe extern "system" fn dxgi_swap_chain_resize_buffers_impl( new_format: DXGI_FORMAT, flags: u32, ) -> HRESULT { + let _hook_ejection_guard = HOOK_EJECTION_BARRIER.acquire_ejection_guard(); let Trampolines { dxgi_swap_chain_resize_buffers, .. } = TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized"); @@ -230,6 +239,7 @@ unsafe extern "system" fn d3d12_command_queue_execute_command_lists_impl( num_command_lists: u32, command_lists: *mut ID3D12CommandList, ) { + let _hook_ejection_guard = HOOK_EJECTION_BARRIER.acquire_ejection_guard(); trace!( "ID3D12CommandQueue::ExecuteCommandLists({command_queue:?}, {num_command_lists}, \ {command_lists:p}) invoked", @@ -390,6 +400,7 @@ impl Hooks for ImguiDx12Hooks { TRAMPOLINES.take(); PIPELINE.take().map(|p| p.into_inner().take()); RENDER_LOOP.take(); // should already be null + *INITIALIZATION_CONTEXT.lock() = InitializationContext::Empty; } } diff --git a/src/lib.rs b/src/lib.rs index 155089e..d5ae7b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,7 +119,7 @@ use std::thread; use imgui::{Context, Io, TextureId, Ui}; use once_cell::sync::OnceCell; -use tracing::error; +use tracing::{error, trace, warn}; use windows::core::Error; use windows::Win32::Foundation::{HINSTANCE, HWND, LPARAM, WPARAM}; use windows::Win32::System::Console::{ @@ -130,6 +130,7 @@ use windows::Win32::System::LibraryLoader::FreeLibraryAndExitThread; pub use {imgui, tracing, windows}; use crate::mh::{MH_ApplyQueued, MH_Initialize, MH_Uninitialize, MhHook, MH_STATUS}; +use crate::util::HookEjectionBarrier; pub mod hooks; #[cfg(feature = "inject")] @@ -145,6 +146,8 @@ pub mod util; static mut MODULE: OnceCell = OnceCell::new(); static mut HUDHOOK: OnceCell = OnceCell::new(); static CONSOLE_ALLOCATED: AtomicBool = AtomicBool::new(false); +static EJECT_REQUESTED: AtomicBool = AtomicBool::new(false); +static HOOK_EJECTION_BARRIER: HookEjectionBarrier = HookEjectionBarrier::new(); /// Texture Loader for ImguiRenderLoop callbacks to load and replace textures pub trait RenderContext { @@ -215,20 +218,35 @@ pub fn free_console() -> Result<(), Error> { /// Befor calling [`eject`], make sure to perform any manual cleanup (e.g. /// dropping/resetting the contents of static mutable variables). pub fn eject() { - thread::spawn(|| unsafe { - if let Err(e) = free_console() { - error!("{e:?}"); - } + trace!("Requesting eject"); + EJECT_REQUESTED.store(true, Ordering::SeqCst); +} - if let Some(mut hudhook) = HUDHOOK.take() { - if let Err(e) = hudhook.unapply() { - error!("Couldn't unapply hooks: {e:?}"); - } +/// Perform the ejection that was previously requested +unsafe fn perform_eject() { + trace!("Performing eject"); + if let Err(e) = free_console() { + error!("{e:?}"); + } + + if let Some(mut hudhook) = HUDHOOK.take() { + if let Err(e) = hudhook.unapply() { + error!("Couldn't unapply hooks: {e:?}"); } + } + + thread::spawn(|| unsafe { + // Wait for all hook ejection guards to complete. As we have + // already called `hudhook.unapply()` above any future invocations + // of the hooked functions will call the original code, so we just + // have to wait for the previous hook invocations to complete before + // we continue to free the library. + HOOK_EJECTION_BARRIER.wait_for_all_guards(); if let Some(module) = MODULE.take() { FreeLibraryAndExitThread(module, 0); } + trace!("Finished ejecting!"); }); } @@ -313,7 +331,10 @@ impl Hudhook { fn new() -> Self { // Initialize minhook. match unsafe { MH_Initialize() } { - MH_STATUS::MH_ERROR_ALREADY_INITIALIZED | MH_STATUS::MH_OK => {}, + MH_STATUS::MH_OK => {}, + MH_STATUS::MH_ERROR_ALREADY_INITIALIZED => { + warn!("Minhook already initialized"); + }, status @ MH_STATUS::MH_ERROR_MEMORY_ALLOC => panic!("MH_Initialize: {status:?}"), _ => unreachable!(), } @@ -343,6 +364,7 @@ impl Hudhook { /// Disable and cleanup the hooks. pub fn unapply(&mut self) -> Result<(), MH_STATUS> { + trace!("Unapply hook"); // Queue disabling all the hooks. for hook in self.hooks() { unsafe { hook.queue_disable()? }; @@ -358,6 +380,7 @@ impl Hudhook { for hook in &mut self.0 { unsafe { hook.unhook() }; } + trace!("Finished removing hook"); Ok(()) } diff --git a/src/util.rs b/src/util.rs index da55306..26f63f4 100644 --- a/src/util.rs +++ b/src/util.rs @@ -8,6 +8,7 @@ use std::os::windows::ffi::OsStringExt; use std::path::PathBuf; use std::sync::atomic::{AtomicU64, Ordering}; +use parking_lot::{RwLock, RwLockReadGuard}; use tracing::{debug, error}; use windows::core::s; use windows::Win32::Foundation::{HANDLE, HMODULE, HWND, MAX_PATH, RECT}; @@ -350,6 +351,50 @@ pub unsafe fn readable_region(ptr: *const T, limit: usize) -> &'static [T] { std::slice::from_raw_parts(ptr, limit) } +/// Implements a barrier to coordinate ejection of hooks +/// +/// # Usave +/// - Hooked functions should call and maintain a guard from +/// `acquire_ejection_guard()` while they are in progress. +/// - Ejecting code should call `hudhook.unapply()` to ensure that no more +/// ejection guards will be acquired and then call `wait_for_all_blocks()` to +/// allow all hooks to exit before calling `FreeLibraryAndExitThread()` +/// +/// This is implemented with a RwLock which allows us to have multiple +/// ejection guards in place without blocking each other, and then wait +/// for all the guards to complete before ejecting. +pub struct HookEjectionBarrier(RwLock<()>); +impl HookEjectionBarrier { + /// Construct a new ejection barrier + pub const fn new() -> Self { + Self(RwLock::new(())) + } + + /// Acquire a guard to prevent ejection while the guard exists + /// + /// Multiple guards can be acquired simultaneously and do not block + /// each other. + pub fn acquire_ejection_guard(&self) -> RwLockReadGuard<'_, ()> { + self.0.read() + } + + /// Wait for ejection to be safe. + /// + /// All ejection guards will be awaited before continuing. After this + /// is called `acquire_ejection_guard()` should not be called again. + pub fn wait_for_all_guards(&self) { + // Note: We immediately drop the write lock once acquired, we just + // need to ensure that all read locks have also been dropped. + let _wait_guard = self.0.write(); + } +} + +impl Default for HookEjectionBarrier { + fn default() -> Self { + Self::new() + } +} + #[cfg(test)] mod tests { use windows::Win32::System::Memory::{VirtualAlloc, VirtualProtect, MEM_COMMIT, PAGE_NOACCESS};