
Commit 04b5dfa

Merge pull request #233 from rampadc/master
2 parents (9091be8 + acb7d05) · commit 04b5dfa

File tree: 7 files changed, +88 -9 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -70,3 +70,6 @@ examples/speech.mp3
 examples/phoneme_examples/output/*.wav
 examples/assorted_checks/benchmarks/output_audio/*
 uv.lock
+
+# Mac MPS virtualenv for dual testing
+.venv-mps

api/src/core/config.py

Lines changed: 18 additions & 0 deletions
@@ -1,4 +1,5 @@
 from pydantic_settings import BaseSettings
+import torch


 class Settings(BaseSettings):
@@ -15,6 +16,7 @@ class Settings(BaseSettings):
     default_voice: str = "af_heart"
     default_voice_code: str | None = None  # If set, overrides the first letter of voice name, though api call param still takes precedence
     use_gpu: bool = True  # Whether to use GPU acceleration if available
+    device_type: str | None = None  # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
     allow_local_voice_saving: bool = (
         False  # Whether to allow saving combined voices locally
     )
@@ -51,5 +53,21 @@ class Settings(BaseSettings):
     class Config:
         env_file = ".env"

+    def get_device(self) -> str:
+        """Get the appropriate device based on settings and availability"""
+        if not self.use_gpu:
+            return "cpu"
+
+        if self.device_type:
+            return self.device_type
+
+        # Auto-detect device
+        if torch.backends.mps.is_available():
+            return "mps"
+        elif torch.cuda.is_available():
+            return "cuda"
+        return "cpu"
+
+

 settings = Settings()
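
Resolution order in get_device(): use_gpu=False short-circuits to CPU, an explicit device_type wins next, and only then does auto-detection try MPS, then CUDA, then CPU. A minimal usage sketch, assuming the default pydantic-settings mapping of fields to the USE_GPU/DEVICE_TYPE environment variables and that api.src.core.config is importable from the project root:

    # Illustrative only: exercise Settings.get_device() resolution order.
    import os

    os.environ["USE_GPU"] = "true"
    os.environ["DEVICE_TYPE"] = "mps"  # explicit override; unset to auto-detect

    from api.src.core.config import Settings

    print(Settings().get_device())  # "mps" here; auto-detects MPS/CUDA/CPU if unset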

api/src/inference/kokoro_v1.py

Lines changed: 16 additions & 6 deletions
@@ -21,7 +21,7 @@ def __init__(self):
         """Initialize backend with environment-based configuration."""
         super().__init__()
         # Strictly respect settings.use_gpu
-        self._device = "cuda" if settings.use_gpu else "cpu"
+        self._device = settings.get_device()
         self._model: Optional[KModel] = None
         self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

@@ -48,9 +48,14 @@ async def load_model(self, path: str) -> None:

             # Load model and let KModel handle device mapping
             self._model = KModel(config=config_path, model=model_path).eval()
-            # Move to CUDA if needed
-            if self._device == "cuda":
+            # Move to the selected device; on MPS, unsupported ops fall back to CPU
+            if self._device == "mps":
+                logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
+                self._model = self._model.to(torch.device("mps"))
+            elif self._device == "cuda":
                 self._model = self._model.cuda()
+            else:
+                self._model = self._model.cpu()

         except FileNotFoundError as e:
             raise e
@@ -273,7 +278,7 @@ async def generate(
                         continue
                     if not token.text or not token.text.strip():
                         continue
-
+
                     start_time = float(token.start_ts) + current_offset
                     end_time = float(token.end_ts) + current_offset
                     word_timestamps.append(
@@ -291,8 +296,8 @@ async def generate(
                         logger.error(
                             f"Failed to process timestamps for chunk: {e}"
                         )
-
-
+
+
                     yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
                 else:
                     logger.warning("No audio in chunk")
@@ -314,13 +319,18 @@ def _check_memory(self) -> bool:
         if self._device == "cuda":
             memory_gb = torch.cuda.memory_allocated() / 1e9
             return memory_gb > model_config.pytorch_gpu.memory_threshold
+        # MPS doesn't provide memory management APIs
        return False

    def _clear_memory(self) -> None:
        """Clear device memory."""
        if self._device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
+        elif self._device == "mps":
+            # Empty cache if available (future-proofing)
+            if hasattr(torch.mps, 'empty_cache'):
+                torch.mps.empty_cache()

    def unload(self) -> None:
        """Unload model and free resources."""

api/src/inference/voice_manager.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ class VoiceManager:
     def __init__(self):
         """Initialize voice manager."""
         # Strictly respect settings.use_gpu
-        self._device = "cuda" if settings.use_gpu else "cpu"
+        self._device = settings.get_device()
         self._voices: Dict[str, torch.Tensor] = {}

     async def get_voice_path(self, voice_name: str) -> str:

api/src/main.py

Lines changed: 6 additions & 1 deletion
@@ -85,7 +85,12 @@ async def lifespan(app: FastAPI):
    {boundary}
    """
    startup_msg += f"\nModel warmed up on {device}: {model}"
-    startup_msg += f"CUDA: {torch.cuda.is_available()}"
+    if device == "mps":
+        startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
+    elif device == "cuda":
+        startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
+    else:
+        startup_msg += "\nRunning on CPU"
    startup_msg += f"\n{voicepack_count} voice packs loaded"

    # Add web player info if enabled
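
With device == "mps", the lines this block contributes to the banner would read roughly as follows (model name and pack count illustrative); the replaced line also gains the leading newline the old CUDA-only append was missing:

    Model warmed up on mps: kokoro-v1
    Using Apple Metal Performance Shaders (MPS)
    67 voice packs loaded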

api/src/routers/debug.py

Lines changed: 9 additions & 1 deletion
@@ -4,6 +4,7 @@

 import psutil
 from fastapi import APIRouter
+import torch

 try:
     import GPUtil
@@ -113,7 +114,14 @@ async def get_system_info():

     # GPU Info if available
     gpu_info = None
-    if GPU_AVAILABLE:
+    if torch.backends.mps.is_available():
+        gpu_info = {
+            "type": "MPS",
+            "available": True,
+            "device": "Apple Silicon",
+            "backend": "Metal"
+        }
+    elif GPU_AVAILABLE:
         try:
             gpus = GPUtil.getGPUs()
             gpu_info = [
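
A hypothetical probe of the endpoint once the server is up; the /debug/system path and the gpu_info response key are assumptions read off the handler above, not confirmed routes:

    # Illustrative client check of the MPS branch above.
    import requests

    resp = requests.get("http://localhost:8880/debug/system")
    resp.raise_for_status()
    print(resp.json().get("gpu_info"))
    # On Apple Silicon this prints:
    # {'type': 'MPS', 'available': True, 'device': 'Apple Silicon', 'backend': 'Metal'}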

start-gpu_mac.sh

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Get project root directory
+PROJECT_ROOT=$(pwd)
+
+# Create MPS-specific venv directory
+VENV_DIR="$PROJECT_ROOT/.venv-mps"
+if [ ! -d "$VENV_DIR" ]; then
+    echo "Creating MPS-specific virtual environment..."
+    python3 -m venv "$VENV_DIR"
+fi
+
+# Set other environment variables
+export USE_GPU=true
+export USE_ONNX=false
+export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
+export MODEL_DIR=src/models
+export VOICES_DIR=src/voices/v1_0
+export WEB_PLAYER_PATH=$PROJECT_ROOT/web
+
+# Set environment variables
+export USE_GPU=true
+export USE_ONNX=false
+export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
+export MODEL_DIR=src/models
+export VOICES_DIR=src/voices/v1_0
+export WEB_PLAYER_PATH=$PROJECT_ROOT/web
+
+export DEVICE_TYPE=mps
+# Enable MPS fallback for unsupported operations
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+
+# Run FastAPI with GPU extras using uv run
+uv pip install -e .
+uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
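
Run once from the repository root (chmod +x start-gpu_mac.sh && ./start-gpu_mac.sh): the script creates .venv-mps on first use, exports DEVICE_TYPE=mps and PYTORCH_ENABLE_MPS_FALLBACK=1 so operations MPS lacks fall back to CPU transparently, installs the package editably with uv, and serves the API on port 8880.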
