Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Cargo.lock

.idea
vkd3d-proton.cache*
*.swp
15 changes: 13 additions & 2 deletions src/hooks/dx12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::ffi::c_void;
use std::mem;
use std::sync::atomic::Ordering;
use std::sync::OnceLock;

use imgui::Context;
Expand All @@ -28,7 +29,7 @@ use windows::Win32::Graphics::Dxgi::{
use super::DummyHwnd;
use crate::mh::MhHook;
use crate::renderer::{D3D12RenderEngine, Pipeline};
use crate::{util, Hooks, ImguiRenderLoop};
use crate::{perform_eject, util, Hooks, ImguiRenderLoop, EJECT_REQUESTED, HOOK_EJECTION_BARRIER};

type DXGISwapChainPresentType =
unsafe extern "system" fn(this: IDXGISwapChain3, sync_interval: u32, flags: u32) -> HRESULT;
Expand Down Expand Up @@ -194,6 +195,7 @@ unsafe extern "system" fn dxgi_swap_chain_present_impl(
sync_interval: u32,
flags: u32,
) -> HRESULT {
let _hook_ejection_guard = HOOK_EJECTION_BARRIER.acquire_ejection_guard();
{
INITIALIZATION_CONTEXT.lock().insert_swap_chain(&swap_chain);
}
Expand All @@ -207,7 +209,13 @@ unsafe extern "system" fn dxgi_swap_chain_present_impl(
}

trace!("Call IDXGISwapChain::Present trampoline");
dxgi_swap_chain_present(swap_chain, sync_interval, flags)
let result = dxgi_swap_chain_present(swap_chain, sync_interval, flags);

if EJECT_REQUESTED.load(Ordering::SeqCst) {
perform_eject();
}

result
}

unsafe extern "system" fn dxgi_swap_chain_resize_buffers_impl(
Expand All @@ -218,6 +226,7 @@ unsafe extern "system" fn dxgi_swap_chain_resize_buffers_impl(
new_format: DXGI_FORMAT,
flags: u32,
) -> HRESULT {
let _hook_ejection_guard = HOOK_EJECTION_BARRIER.acquire_ejection_guard();
let Trampolines { dxgi_swap_chain_resize_buffers, .. } =
TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized");

Expand All @@ -230,6 +239,7 @@ unsafe extern "system" fn d3d12_command_queue_execute_command_lists_impl(
num_command_lists: u32,
command_lists: *mut ID3D12CommandList,
) {
let _hook_ejection_guard = HOOK_EJECTION_BARRIER.acquire_ejection_guard();
trace!(
"ID3D12CommandQueue::ExecuteCommandLists({command_queue:?}, {num_command_lists}, \
{command_lists:p}) invoked",
Expand Down Expand Up @@ -390,6 +400,7 @@ impl Hooks for ImguiDx12Hooks {
TRAMPOLINES.take();
PIPELINE.take().map(|p| p.into_inner().take());
RENDER_LOOP.take(); // should already be null

*INITIALIZATION_CONTEXT.lock() = InitializationContext::Empty;
}
}
43 changes: 33 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ use std::thread;

use imgui::{Context, Io, TextureId, Ui};
use once_cell::sync::OnceCell;
use tracing::error;
use tracing::{error, trace, warn};
use windows::core::Error;
use windows::Win32::Foundation::{HINSTANCE, HWND, LPARAM, WPARAM};
use windows::Win32::System::Console::{
Expand All @@ -130,6 +130,7 @@ use windows::Win32::System::LibraryLoader::FreeLibraryAndExitThread;
pub use {imgui, tracing, windows};

use crate::mh::{MH_ApplyQueued, MH_Initialize, MH_Uninitialize, MhHook, MH_STATUS};
use crate::util::HookEjectionBarrier;

pub mod hooks;
#[cfg(feature = "inject")]
Expand All @@ -145,6 +146,8 @@ pub mod util;
static mut MODULE: OnceCell<HINSTANCE> = OnceCell::new();
static mut HUDHOOK: OnceCell<Hudhook> = OnceCell::new();
static CONSOLE_ALLOCATED: AtomicBool = AtomicBool::new(false);
static EJECT_REQUESTED: AtomicBool = AtomicBool::new(false);
static HOOK_EJECTION_BARRIER: HookEjectionBarrier = HookEjectionBarrier::new();

/// Texture Loader for ImguiRenderLoop callbacks to load and replace textures
pub trait RenderContext {
Expand Down Expand Up @@ -215,20 +218,35 @@ pub fn free_console() -> Result<(), Error> {
/// Befor calling [`eject`], make sure to perform any manual cleanup (e.g.
/// dropping/resetting the contents of static mutable variables).
pub fn eject() {
thread::spawn(|| unsafe {
if let Err(e) = free_console() {
error!("{e:?}");
}
trace!("Requesting eject");
EJECT_REQUESTED.store(true, Ordering::SeqCst);
}

if let Some(mut hudhook) = HUDHOOK.take() {
if let Err(e) = hudhook.unapply() {
error!("Couldn't unapply hooks: {e:?}");
}
/// Perform the ejection that was previously requested
unsafe fn perform_eject() {
trace!("Performing eject");
if let Err(e) = free_console() {
error!("{e:?}");
}

if let Some(mut hudhook) = HUDHOOK.take() {
if let Err(e) = hudhook.unapply() {
error!("Couldn't unapply hooks: {e:?}");
}
}

thread::spawn(|| unsafe {
// Wait for all hook ejection guards to complete. As we have
// already called `hudhook.unapply()` above any future invocations
// of the hooked functions will call the original code, so we just
// have to wait for the previous hook invocations to complete before
// we continue to free the library.
HOOK_EJECTION_BARRIER.wait_for_all_guards();

if let Some(module) = MODULE.take() {
FreeLibraryAndExitThread(module, 0);
}
trace!("Finished ejecting!");
});
}

Expand Down Expand Up @@ -313,7 +331,10 @@ impl Hudhook {
fn new() -> Self {
// Initialize minhook.
match unsafe { MH_Initialize() } {
MH_STATUS::MH_ERROR_ALREADY_INITIALIZED | MH_STATUS::MH_OK => {},
MH_STATUS::MH_OK => {},
MH_STATUS::MH_ERROR_ALREADY_INITIALIZED => {
warn!("Minhook already initialized");
},
status @ MH_STATUS::MH_ERROR_MEMORY_ALLOC => panic!("MH_Initialize: {status:?}"),
_ => unreachable!(),
}
Expand Down Expand Up @@ -343,6 +364,7 @@ impl Hudhook {

/// Disable and cleanup the hooks.
pub fn unapply(&mut self) -> Result<(), MH_STATUS> {
trace!("Unapply hook");
// Queue disabling all the hooks.
for hook in self.hooks() {
unsafe { hook.queue_disable()? };
Expand All @@ -358,6 +380,7 @@ impl Hudhook {
for hook in &mut self.0 {
unsafe { hook.unhook() };
}
trace!("Finished removing hook");

Ok(())
}
Expand Down
45 changes: 45 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::os::windows::ffi::OsStringExt;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};

use parking_lot::{RwLock, RwLockReadGuard};
use tracing::{debug, error};
use windows::core::s;
use windows::Win32::Foundation::{HANDLE, HMODULE, HWND, MAX_PATH, RECT};
Expand Down Expand Up @@ -350,6 +351,50 @@ pub unsafe fn readable_region<T>(ptr: *const T, limit: usize) -> &'static [T] {
std::slice::from_raw_parts(ptr, limit)
}

/// Implements a barrier to coordinate ejection of hooks
///
/// # Usave
/// - Hooked functions should call and maintain a guard from
/// `acquire_ejection_guard()` while they are in progress.
/// - Ejecting code should call `hudhook.unapply()` to ensure that no more
/// ejection guards will be acquired and then call `wait_for_all_blocks()` to
/// allow all hooks to exit before calling `FreeLibraryAndExitThread()`
///
/// This is implemented with a RwLock which allows us to have multiple
/// ejection guards in place without blocking each other, and then wait
/// for all the guards to complete before ejecting.
pub struct HookEjectionBarrier(RwLock<()>);
impl HookEjectionBarrier {
/// Construct a new ejection barrier
pub const fn new() -> Self {
Self(RwLock::new(()))
}

/// Acquire a guard to prevent ejection while the guard exists
///
/// Multiple guards can be acquired simultaneously and do not block
/// each other.
pub fn acquire_ejection_guard(&self) -> RwLockReadGuard<'_, ()> {
self.0.read()
}

/// Wait for ejection to be safe.
///
/// All ejection guards will be awaited before continuing. After this
/// is called `acquire_ejection_guard()` should not be called again.
pub fn wait_for_all_guards(&self) {
// Note: We immediately drop the write lock once acquired, we just
// need to ensure that all read locks have also been dropped.
let _wait_guard = self.0.write();
}
}

impl Default for HookEjectionBarrier {
fn default() -> Self {
Self::new()
}
}

#[cfg(test)]
mod tests {
use windows::Win32::System::Memory::{VirtualAlloc, VirtualProtect, MEM_COMMIT, PAGE_NOACCESS};
Expand Down