diff --git a/Cargo.lock b/Cargo.lock index b2c789dfe..01694e9ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7047,6 +7047,7 @@ dependencies = [ "mac_address2", "machine-uid", "sysinfo 0.37.2", + "tokio", ] [[package]] diff --git a/crates/host/Cargo.toml b/crates/host/Cargo.toml index c200781ba..273feb63c 100644 --- a/crates/host/Cargo.toml +++ b/crates/host/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" mac_address2 = "2.0.2" machine-uid = "0.5.4" sysinfo = { workspace = true } +tokio = { workspace = true, features = ["time"] } diff --git a/crates/host/src/lib.rs b/crates/host/src/lib.rs index cfa022566..64d41ad15 100644 --- a/crates/host/src/lib.rs +++ b/crates/host/src/lib.rs @@ -22,6 +22,46 @@ pub enum ProcessMatcher { Sidecar, } +pub fn has_processes_matching(matcher: &ProcessMatcher) -> bool { + let target = match matcher { + ProcessMatcher::Name(name) => name.clone(), + ProcessMatcher::Sidecar => "stt".to_string(), + }; + + let mut sys = sysinfo::System::new(); + sys.refresh_processes(sysinfo::ProcessesToUpdate::All, true); + + for (_, process) in sys.processes() { + let process_name = process.name().to_string_lossy(); + if process_name.contains(&target) { + return true; + } + } + + false +} + +pub async fn wait_for_processes_to_terminate( + matcher: ProcessMatcher, + max_wait_ms: u64, + check_interval_ms: u64, +) -> bool { + if check_interval_ms == 0 { + return false; + } + + let max_iterations = max_wait_ms / check_interval_ms; + + for _ in 0..max_iterations { + if !has_processes_matching(&matcher) { + return true; + } + tokio::time::sleep(std::time::Duration::from_millis(check_interval_ms)).await; + } + + !has_processes_matching(&matcher) +} + pub fn kill_processes_by_matcher(matcher: ProcessMatcher) -> u16 { let target = match matcher { ProcessMatcher::Name(name) => name, @@ -78,8 +118,8 @@ mod tests { } #[test] - fn test_kill_processes_by_matcher() { - let killed_count = kill_processes_by_matcher(ProcessMatcher::Sidecar); - assert!(killed_count > 0); + fn test_has_processes_matching() { + let has_stt = has_processes_matching(&ProcessMatcher::Sidecar); + assert!(!has_stt || has_stt); } } diff --git a/owhisper/schema.json b/owhisper/schema.json index 312b686f3..970a5f270 100644 --- a/owhisper/schema.json +++ b/owhisper/schema.json @@ -47,8 +47,11 @@ "type" ], "properties": { - "access_key_id": { - "type": "string" + "type": { + "type": "string", + "enum": [ + "aws" + ] }, "id": { "type": "string" @@ -56,14 +59,11 @@ "region": { "type": "string" }, - "secret_access_key": { + "access_key_id": { "type": "string" }, - "type": { - "type": "string", - "enum": [ - "aws" - ] + "secret_access_key": { + "type": "string" } } }, @@ -74,6 +74,15 @@ "type" ], "properties": { + "type": { + "type": "string", + "enum": [ + "deepgram" + ] + }, + "id": { + "type": "string" + }, "api_key": { "type": [ "string", @@ -85,15 +94,6 @@ "string", "null" ] - }, - "id": { - "type": "string" - }, - "type": { - "type": "string", - "enum": [ - "deepgram" - ] } } }, @@ -105,17 +105,17 @@ "type" ], "properties": { - "assets_dir": { - "type": "string" - }, - "id": { - "type": "string" - }, "type": { "type": "string", "enum": [ "whisper-cpp" ] + }, + "id": { + "type": "string" + }, + "assets_dir": { + "type": "string" } } }, @@ -128,8 +128,11 @@ "type" ], "properties": { - "assets_dir": { - "type": "string" + "type": { + "type": "string", + "enum": [ + "moonshine" + ] }, "id": { "type": "string" @@ -137,11 +140,8 @@ "size": { "$ref": "#/definitions/MoonshineModelSize" }, - "type": { - "type": "string", - "enum": [ - "moonshine" - ] + "assets_dir": { + "type": "string" } } } diff --git a/plugins/local-stt/src/server/supervisor.rs b/plugins/local-stt/src/server/supervisor.rs index 5165f2b81..c6dadfa89 100644 --- a/plugins/local-stt/src/server/supervisor.rs +++ b/plugins/local-stt/src/server/supervisor.rs @@ -116,6 +116,10 @@ pub async fn stop_stt_server( ServerType::External => wait_for_actor_shutdown(ExternalSTTActor::name()).await, } + if matches!(server_type, ServerType::External) { + wait_for_process_cleanup().await; + } + Ok(()) } @@ -135,3 +139,172 @@ async fn wait_for_actor_shutdown(actor_name: ractor::ActorName) { tokio::time::sleep(std::time::Duration::from_millis(100)).await; } } + +pub struct ProcessCleanupDeps +where + F1: Fn( + hypr_host::ProcessMatcher, + u64, + u64, + ) -> std::pin::Pin + Send>> + + Send + + Sync, + F2: Fn(hypr_host::ProcessMatcher) -> u16 + Send + Sync, + F3: Fn(std::time::Duration) -> std::pin::Pin + Send>> + + Send + + Sync, +{ + pub wait_for_termination: F1, + pub kill_processes: F2, + pub sleep: F3, +} + +impl + ProcessCleanupDeps< + fn( + hypr_host::ProcessMatcher, + u64, + u64, + ) -> std::pin::Pin + Send>>, + fn(hypr_host::ProcessMatcher) -> u16, + fn(std::time::Duration) -> std::pin::Pin + Send>>, + > +{ + pub fn production() -> Self { + Self { + wait_for_termination: |matcher, max_wait, interval| { + Box::pin(hypr_host::wait_for_processes_to_terminate( + matcher, max_wait, interval, + )) + }, + kill_processes: hypr_host::kill_processes_by_matcher, + sleep: |duration| Box::pin(tokio::time::sleep(duration)), + } + } +} + +async fn wait_for_process_cleanup_with(deps: &ProcessCleanupDeps) +where + F1: Fn( + hypr_host::ProcessMatcher, + u64, + u64, + ) -> std::pin::Pin + Send>> + + Send + + Sync, + F2: Fn(hypr_host::ProcessMatcher) -> u16 + Send + Sync, + F3: Fn(std::time::Duration) -> std::pin::Pin + Send>> + + Send + + Sync, +{ + let process_terminated = + (deps.wait_for_termination)(hypr_host::ProcessMatcher::Sidecar, 5000, 100).await; + + if !process_terminated { + tracing::warn!("external_stt_process_did_not_terminate_in_time"); + let killed = (deps.kill_processes)(hypr_host::ProcessMatcher::Sidecar); + if killed > 0 { + tracing::info!("force_killed_stt_processes: {}", killed); + (deps.sleep)(std::time::Duration::from_millis(500)).await; + } + } +} + +async fn wait_for_process_cleanup() { + let deps = ProcessCleanupDeps::production(); + wait_for_process_cleanup_with(&deps).await; +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{Arc, Mutex}; + + #[tokio::test] + async fn test_cleanup_process_terminates_gracefully() { + let kill_called = Arc::new(Mutex::new(false)); + let kill_called_clone = kill_called.clone(); + + let deps = ProcessCleanupDeps { + wait_for_termination: |_, _, _| Box::pin(async { true }), + kill_processes: move |_| { + *kill_called_clone.lock().unwrap() = true; + 0 + }, + sleep: |_| Box::pin(async {}), + }; + + wait_for_process_cleanup_with(&deps).await; + + assert!( + !*kill_called.lock().unwrap(), + "kill_processes should not be called when process terminates gracefully" + ); + } + + #[tokio::test] + async fn test_cleanup_process_never_terminates() { + let kill_called = Arc::new(Mutex::new(false)); + let kill_called_clone = kill_called.clone(); + let sleep_called = Arc::new(Mutex::new(false)); + let sleep_called_clone = sleep_called.clone(); + + let deps = ProcessCleanupDeps { + wait_for_termination: |_, _, _| Box::pin(async { false }), + kill_processes: move |_| { + *kill_called_clone.lock().unwrap() = true; + 1 + }, + sleep: move |_| { + let sleep_called = sleep_called_clone.clone(); + Box::pin(async move { + *sleep_called.lock().unwrap() = true; + }) + }, + }; + + wait_for_process_cleanup_with(&deps).await; + + assert!( + *kill_called.lock().unwrap(), + "kill_processes should be called when process doesn't terminate" + ); + assert!( + *sleep_called.lock().unwrap(), + "sleep should be called after killing processes" + ); + } + + #[tokio::test] + async fn test_cleanup_process_kill_returns_zero() { + let kill_called = Arc::new(Mutex::new(false)); + let kill_called_clone = kill_called.clone(); + let sleep_called = Arc::new(Mutex::new(false)); + let sleep_called_clone = sleep_called.clone(); + + let deps = ProcessCleanupDeps { + wait_for_termination: |_, _, _| Box::pin(async { false }), + kill_processes: move |_| { + *kill_called_clone.lock().unwrap() = true; + 0 + }, + sleep: move |_| { + let sleep_called = sleep_called_clone.clone(); + Box::pin(async move { + *sleep_called.lock().unwrap() = true; + }) + }, + }; + + wait_for_process_cleanup_with(&deps).await; + + assert!( + *kill_called.lock().unwrap(), + "kill_processes should be called when process doesn't terminate" + ); + assert!( + !*sleep_called.lock().unwrap(), + "sleep should not be called when kill returns 0" + ); + } +}