-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstt.py
More file actions
102 lines (80 loc) · 3.29 KB
/
stt.py
File metadata and controls
102 lines (80 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Speech-to-text module using lightning-whisper-mlx.
Uses the lightning-fast Whisper implementation optimized for Apple Silicon.
"""
import contextlib
import io
import os
import sys
import tempfile

import numpy as np
import scipy.io.wavfile as wav
from lightning_whisper_mlx import LightningWhisperMLX
class WhisperTranscriber:
"""Transcribes audio to text using Lightning Whisper MLX."""
# Available models: tiny, base, small, medium, large, large-v2, large-v3
# distil variants: distil-small.en, distil-medium.en, distil-large-v2, distil-large-v3
DEFAULT_MODEL = "distil-medium.en" # Fast & accurate for English
SAMPLE_RATE = 16000 # Whisper expects 16kHz
def __init__(self, model: str | None = None, batch_size: int = 12):
"""Initialize transcriber.
Args:
model: Whisper model name (tiny, base, small, medium, large, etc.)
batch_size: Batch size for processing.
"""
self.model_name = model or self.DEFAULT_MODEL
self.batch_size = batch_size
self._model = None
def _load_model(self):
"""Lazy load the model on first use."""
if self._model is None:
import sys
import io
import contextlib
print(f"Loading Whisper model: {self.model_name}...", file=sys.stderr)
# Suppress stdout from LightningWhisperMLX (it prints loading messages)
with contextlib.redirect_stdout(io.StringIO()):
self._model = LightningWhisperMLX(
model=self.model_name,
batch_size=self.batch_size
)
print("Whisper model ready!", file=sys.stderr)
def transcribe(self, audio: np.ndarray, sample_rate: int = 24000) -> str:
"""Transcribe audio to text.
Args:
audio: Audio data as numpy array.
sample_rate: Sample rate of the input audio (will resample to 16kHz).
Returns:
Transcribed text.
"""
if len(audio) == 0:
return ""
self._load_model()
# Ensure audio is float32
audio = audio.astype(np.float32)
# Flatten if needed
if audio.ndim > 1:
audio = audio.flatten()
# Resample to 16kHz if needed (Whisper expects 16kHz)
if sample_rate != self.SAMPLE_RATE:
from scipy import signal
num_samples = int(len(audio) * self.SAMPLE_RATE / sample_rate)
audio = signal.resample(audio, num_samples)
# Normalize audio to [-1, 1] range
max_val = np.abs(audio).max()
if max_val > 0:
audio = audio / max_val
# Convert to int16 for WAV file
audio_int16 = (audio * 32767).astype(np.int16)
# Write to temp WAV file (lightning-whisper-mlx needs a file path)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
temp_path = f.name
wav.write(temp_path, self.SAMPLE_RATE, audio_int16)
try:
# Transcribe
result = self._model.transcribe(temp_path)
return result.get("text", "").strip()
finally:
# Clean up temp file
import os
try:
os.unlink(temp_path)
except OSError:
pass
# Backwards-compatibility alias: earlier versions presumably exposed a
# Moshi-based transcriber under this name — callers importing
# MoshiTranscriber now get the Whisper implementation. TODO confirm against
# module history before removing.
MoshiTranscriber = WhisperTranscriber