diff --git a/.cargo/config.toml b/.cargo/config.toml
index b547570c3..4b4d98777 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -1,5 +1,6 @@
 [env]
 MACOSX_DEPLOYMENT_TARGET = { value = "14.2", force = true }
+TOOLCHAINS = { value = "com.apple.dt.toolchain.XcodeDefault", force = true }
 
 [target.'cfg(all(windows, target_env = "msvc"))']
 rustflags = [
diff --git a/.swift-version b/.swift-version
deleted file mode 100644
index 92f2ea299..000000000
--- a/.swift-version
+++ /dev/null
@@ -1 +0,0 @@
-6.1.2
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
index e83909b2c..bcc1dcea0 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -22,7 +22,10 @@
   "rust-analyzer.checkOnSave": true,
   "rust-analyzer.check.allTargets": true,
   "rust-analyzer.cargo.targetDir": "target/analyzer",
-  "rust-analyzer.cargo.extraEnv": { "MACOSX_DEPLOYMENT_TARGET": "14.2" },
+  "rust-analyzer.cargo.extraEnv": {
+    "MACOSX_DEPLOYMENT_TARGET": "14.2",
+    "TOOLCHAINS": "com.apple.dt.toolchain.XcodeDefault"
+  },
   "deno.enable": true,
   "deno.lint": true,
   "deno.enablePaths": [
diff --git a/crates/am2/.gitignore b/crates/am2/.gitignore
index 9a1d9108e..fe5693ade 100644
--- a/crates/am2/.gitignore
+++ b/crates/am2/.gitignore
@@ -1 +1 @@
-swift-lib/.build
+temp-resolver/
diff --git a/crates/am2/build.rs b/crates/am2/build.rs
index dd5adff1d..f40a30517 100644
--- a/crates/am2/build.rs
+++ b/crates/am2/build.rs
@@ -1,15 +1,100 @@
 fn main() {
-    // Build skipped - uncomment below to re-enable Swift linking
-
-    // #[cfg(target_os = "macos")]
-    // {
-    //     swift_rs::SwiftLinker::new("13.0")
-    //         .with_package("swift-lib", "./swift-lib/")
-    //         .link();
-    // }
-    //
-    // #[cfg(not(target_os = "macos"))]
-    // {
-    //     println!("cargo:warning=Swift linking is only available on macOS");
-    // }
+    #[cfg(target_os = "macos")]
+    {
+        use std::path::Path;
+        use std::process::Command;
+
+        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
+        let swift_lib_path = manifest_dir.join("swift-lib");
+        let frameworks_dir = swift_lib_path.join("frameworks");
+        let xcframework_dest = frameworks_dir.join("ArgmaxSDK.xcframework");
+
+        if !xcframework_dest.exists() {
+            println!("cargo:warning=ArgmaxSDK.xcframework not found, extracting from registry...");
+
+            let temp_dir = manifest_dir.join("temp-resolver");
+            std::fs::create_dir_all(&temp_dir).expect("Failed to create temp directory");
+
+            let resolver_package = r#"// swift-tools-version:5.10
+import PackageDescription
+let package = Package(
+    name: "resolver",
+    platforms: [.macOS(.v13)],
+    dependencies: [
+        .package(id: "argmaxinc.argmax-sdk-swift", exact: "1.9.3")
+    ],
+    targets: []
+)
+"#;
+            std::fs::write(temp_dir.join("Package.swift"), resolver_package)
+                .expect("Failed to write resolver Package.swift");
+
+            let status = Command::new("swift")
+                .args(["package", "resolve"])
+                .current_dir(&temp_dir)
+                .status()
+                .expect("Failed to run swift package resolve");
+
+            if !status.success() {
+                panic!("swift package resolve failed");
+            }
+
+            let artifacts_path = temp_dir.join(
+                ".build/artifacts/argmaxinc.argmax-sdk-swift/ArgmaxSDK/ArgmaxSDK.xcframework",
+            );
+
+            if !artifacts_path.exists() {
+                panic!(
+                    "ArgmaxSDK.xcframework not found in artifacts at: {:?}",
+                    artifacts_path
+                );
+            }
+
+            std::fs::create_dir_all(&frameworks_dir)
+                .expect("Failed to create frameworks directory");
+
+            // Pass the paths as `OsStr` arguments so non-UTF-8 paths cannot panic.
+            let status = Command::new("cp")
+                .arg("-R")
+                .arg(&artifacts_path)
+                .arg(&xcframework_dest)
+                .status()
+                .expect("Failed to copy xcframework");
+
+            if !status.success() {
+                panic!("Failed to copy ArgmaxSDK.xcframework");
+            }
+
+            std::fs::remove_dir_all(&temp_dir).ok();
+
+            println!("cargo:warning=ArgmaxSDK.xcframework extracted successfully");
+        }
+
+        let out_dir = std::env::var("OUT_DIR").unwrap();
+        let swift_build_dir = Path::new(&out_dir).join("swift-rs/swift-lib");
+        let workspace_state = swift_build_dir.join("workspace-state.json");
+        if workspace_state.exists() {
+            std::fs::remove_file(&workspace_state).ok();
+        }
+
+        swift_rs::SwiftLinker::new("13.0")
+            .with_package("swift-lib", "./swift-lib/")
+            .link();
+
+        let framework_path = xcframework_dest.join("macos-arm64");
+        println!(
+            "cargo:rustc-link-search=framework={}",
+            framework_path.display()
+        );
+        println!("cargo:rustc-link-lib=framework=ArgmaxSDK");
+        println!(
+            "cargo:rustc-link-arg=-Wl,-rpath,{}",
+            framework_path.display()
+        );
+    }
+
+    #[cfg(not(target_os = "macos"))]
+    {
+        println!("cargo:warning=Swift linking is only available on macOS");
+    }
 }
diff --git a/crates/am2/src/lib.rs b/crates/am2/src/lib.rs
index a7f5b1e58..41912ec20 100644
--- a/crates/am2/src/lib.rs
+++ b/crates/am2/src/lib.rs
@@ -1,12 +1,18 @@
-use swift_rs::swift;
+use swift_rs::{swift, SRString};
 
-swift!(fn initialize_am2_sdk());
+swift!(fn initialize_am2_sdk(api_key: &SRString));
 
 swift!(fn check_am2_ready() -> bool);
 
-pub fn init() {
+swift!(fn transcribe_audio_file(path: &SRString) -> SRString);
+
+/// Initializes the Argmax SDK on the Swift side with the given API key.
+pub fn init(api_key: &str) {
+    let key = SRString::from(api_key);
+    // SAFETY: `key` is a live `SRString` for the duration of the call, matching
+    // the `initialize_am2_sdk(apiKey: SRString)` signature exported from Swift.
     unsafe {
-        initialize_am2_sdk();
+        initialize_am2_sdk(&key);
     }
 }
 
@@ -14,13 +20,43 @@ pub fn is_ready() -> bool {
     unsafe { check_am2_ready() }
 }
 
+/// Transcribes the audio file at `audio_path` through the Swift bridge.
+///
+/// Failures are reported in-band: on error the returned string starts with
+/// `"Error:"`.
+pub fn transcribe(audio_path: &str) -> String {
+    let path = SRString::from(audio_path);
+    // SAFETY: `path` is a live `SRString` for the duration of the call, matching
+    // the `transcribe_audio_file(path: SRString)` signature exported from Swift.
+    unsafe { transcribe_audio_file(&path).to_string() }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
     #[test]
     fn test_am2_swift_compilation() {
-        init();
+        let api_key = std::env::var("AM_API_KEY").expect("AM_API_KEY env var required");
+        init(&api_key);
         assert!(is_ready());
     }
+
+    #[test]
+    fn test_transcribe_audio() {
+        let api_key = std::env::var("AM_API_KEY").expect("AM_API_KEY env var required");
+        init(&api_key);
+
+        let audio_path = concat!(
+            env!("CARGO_MANIFEST_DIR"),
+            "/../data/src/english_1/audio.wav"
+        );
+        println!("Audio path: {}", audio_path);
+
+        let result = transcribe(audio_path);
+        println!("Transcription result: {}", result);
+
+        assert!(!result.is_empty());
+        assert!(!result.starts_with("Error:"));
+    }
 }
diff --git a/crates/am2/swift-lib/.gitignore b/crates/am2/swift-lib/.gitignore
new file mode 100644
index 000000000..83d6c3c84
--- /dev/null
+++ b/crates/am2/swift-lib/.gitignore
@@ -0,0 +1,3 @@
+.build/
+frameworks/
+
diff --git a/crates/am2/swift-lib/Package.resolved b/crates/am2/swift-lib/Package.resolved
index 4f7f75831..17a43d3b3 100644
--- a/crates/am2/swift-lib/Package.resolved
+++ b/crates/am2/swift-lib/Package.resolved
@@ -1,14 +1,6 @@
 {
-  "originHash" : "63df145a8e597cfd8fde517d9454c3a25d71fa829aafde57c8cf6c7e366c01d9",
+  "originHash" : "db058fcc5de37fe3ac3128ad03ec785365fe9373022279ff77d83a0d63d5dad0",
   "pins" : [
-    {
-      "identity" : "argmaxinc.argmax-sdk-swift",
-      "kind" : "registry",
-      "location" : "",
-      "state" : {
-        "version" : "1.9.3"
-      }
-    },
     {
       "identity" : "jinja",
       "kind" : "remoteSourceControl",
@@ -23,8 +15,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/apple/swift-argument-parser.git",
       "state" : {
-        "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
-        "version" : "1.4.0"
+        "revision" : "cdd0ef3755280949551dc26dee5de9ddeda89f54",
+        "version" : "1.6.2"
       }
     },
     {
@@ -49,8 +41,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers.git",
       "state" : {
-        "revision" : "f000aa7aec0e78acd0211685e4094e1fca84cd8b",
-        "version" : "0.1.24"
+        "revision" : "8a83416cc00ab07a5de9991e6ad817a9b8588d20",
+        "version" : "0.1.15"
       }
     },
     {
diff --git a/crates/am2/swift-lib/Package.swift b/crates/am2/swift-lib/Package.swift
index a874c1574..aa0a6fb9b 100644
--- a/crates/am2/swift-lib/Package.swift
+++ b/crates/am2/swift-lib/Package.swift
@@ -1,10 +1,10 @@
-// swift-tools-version:6.0
+// swift-tools-version:5.10
 
 import PackageDescription
 
 let package = Package(
     name: "swift-lib",
-    platforms: [.macOS("13.0")],
+    platforms: [.macOS(.v13)],
     products: [
         .library(
             name: "swift-lib",
@@ -15,24 +15,21 @@ let package = Package(
         .package(
             url: "https://github.com/Brendonovich/swift-rs",
             revision: "01980f981bc642a6da382cc0788f18fdd4cde6df"),
-        .package(
-            id: "argmaxinc.argmax-sdk-swift",
-            exact: "1.9.3"),
-        .package(
-            url: "https://github.com/huggingface/swift-transformers.git",
-            exact: "0.1.24")
+        .package(url: "https://github.com/argmaxinc/WhisperKit.git", exact: "0.14.1")
     ],
     targets: [
+        .binaryTarget(
+            name: "ArgmaxSDK",
+            path: "frameworks/ArgmaxSDK.xcframework"
+        ),
         .target(
             name: "swift-lib",
             dependencies: [
                 .product(name: "SwiftRs", package: "swift-rs"),
-                .product(name: "Argmax", package: "argmaxinc.argmax-sdk-swift")
+                .product(name: "WhisperKit", package: "WhisperKit"),
+                "ArgmaxSDK"
             ],
-            path: "src",
-            swiftSettings: [
-                .swiftLanguageMode(.v5)
-            ]
+            path: "src"
         )
     ]
 )
diff --git a/crates/am2/swift-lib/src/lib.swift b/crates/am2/swift-lib/src/lib.swift
index 1d1c95d43..f1cbdbea3 100644
--- a/crates/am2/swift-lib/src/lib.swift
+++ b/crates/am2/swift-lib/src/lib.swift
@@ -1,13 +1,23 @@
-import Argmax
+import ArgmaxSDK
 import Foundation
 import SwiftRs
 
 private var isAM2Ready = false
+private var whisperKitPro: WhisperKitPro?
 
 @_cdecl("initialize_am2_sdk")
-public func initialize_am2_sdk() {
-    isAM2Ready = true
-    print("AM2 SDK initialized successfully")
+public func initialize_am2_sdk(apiKey: SRString) {
+    let key = apiKey.toString()
+    let semaphore = DispatchSemaphore(value: 0)
+
+    Task {
+        await ArgmaxSDK.with(ArgmaxConfig(apiKey: key))
+        isAM2Ready = true
+        print("AM2 SDK initialized successfully")
+        semaphore.signal()
+    }
+
+    semaphore.wait()
 }
 
 @_cdecl("check_am2_ready")
@@ -15,3 +25,40 @@ public func check_am2_ready() -> Bool {
     return isAM2Ready
 }
 
+@_cdecl("transcribe_audio_file")
+public func transcribe_audio_file(path: SRString) -> SRString {
+    let audioPath = path.toString()
+    print("Transcribing: \(audioPath)")
+
+    var result = ""
+
+    let semaphore = DispatchSemaphore(value: 0)
+
+    Task {
+        do {
+            // Reuse the cached pipeline so the model is loaded only once per process.
+            let kit: WhisperKitPro
+            if let cached = whisperKitPro {
+                kit = cached
+            } else {
+                let config = WhisperKitProConfig(model: "large-v3-v20240930_626MB")
+                kit = try await WhisperKitPro(config)
+                whisperKitPro = kit
+            }
+
+            print("WhisperKitPro initialized, starting transcription...")
+
+            let results = try await kit.transcribe(audioPath: audioPath)
+            let transcript = WhisperKitProUtils.mergeTranscriptionResults(results).text
+            result = transcript
+            print("Transcription complete: \(result)")
+        } catch {
+            result = "Error: \(error.localizedDescription)"
+            print("Transcription error: \(error)")
+        }
+        semaphore.signal()
+    }
+
+    semaphore.wait()
+    return SRString(result)
+}