diff --git a/README.md b/README.md
index 657a47c..2975fbc 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 
 VoiceStreamAI is a Python 3 -based server and JavaScript client solution that
 enables near-realtime audio streaming and transcription using WebSocket. The
-system employs Huggingface's Voice Activity Detection (VAD) and OpenAI's Whisper
+system employs Silero Voice Activity Detection (VAD) and OpenAI's Whisper
 model ([faster-whisper](https://github.com/SYSTRAN/faster-whisper) being the
 default) for accurate speech recognition and processing.
 
@@ -78,6 +78,8 @@ following packages:
 5. `asyncio`
 6. `sentence-transformers`
 7. `faster-whisper`
+8. `silero-vad`
+9. `soundfile`
 
 Install these packages using pip:
 
@@ -96,7 +98,7 @@ allowing you to specify components, host, and port settings according to your
 needs.
 
 - `--vad-type`: Specifies the type of Voice Activity Detection (VAD) pipeline to
-  use (default: `pyannote`) .
+  use (default: `silero`). The default Silero VAD doesn't require an authentication token.
 - `--vad-args`: A JSON string containing additional arguments for the VAD
   pipeline. (required for `pyannote`: `'{"auth_token": "VAD_AUTH_HERE"}'`)
 - `--asr-type`: Specifies the type of Automatic Speech Recognition (ASR)
@@ -113,12 +115,19 @@ needs.
 
 For running the server with the standard configuration:
 
+```bash
+python3 -m src.main
+```
+
+Since the default VAD is Silero, which doesn't require an authentication token,
+the above command is sufficient. If you want to use pyannote VAD instead:
+
 1. Obtain the key to the Voice-Activity-Detection model at
    [https://huggingface.co/pyannote/segmentation](https://huggingface.co/pyannote/segmentation)
 2. Run the server using Python 3.x, please add the VAD key in the command line:
 
 ```bash
-python3 -m src.main --vad-args '{"auth_token": "vad token here"}'
+python3 -m src.main --vad-type 'pyannote' --vad-args '{"auth_token": "vad token here"}'
 ```
 
 You can see all the command line options with the command:
diff --git a/requirements.txt b/requirements.txt
index 34aea8e..f4b0b25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,5 @@ transformers==4.40.2
 faster-whisper==1.0.2
 torchvision~=0.18.0
 torch~=2.3.0
+soundfile==0.12.1
+silero-vad==5.1.2
diff --git a/src/asr/faster_whisper_asr.py b/src/asr/faster_whisper_asr.py
index 96c475c..c16b944 100644
--- a/src/asr/faster_whisper_asr.py
+++ b/src/asr/faster_whisper_asr.py
@@ -1,8 +1,8 @@
 import os
-
+import torch
 from faster_whisper import WhisperModel
 
-from src.audio_utils import save_audio_to_file
+from src.utils.audio_utils import convert_audio_bytes_to_numpy
 
 from .asr_interface import ASRInterface
 
@@ -109,19 +109,17 @@
     "cantonese": "yue",
 }
 
-
 class FasterWhisperASR(ASRInterface):
     def __init__(self, **kwargs):
-        model_size = kwargs.get("model_size", "large-v3")
-        # Run on GPU with FP16
+        model_size = kwargs.get("model_size", "tiny")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        compute_type = "float16" if torch.cuda.is_available() else "float32"
         self.asr_pipeline = WhisperModel(
-            model_size, device="cuda", compute_type="float16"
+            model_size, device=device, compute_type=compute_type
         )
 
     async def transcribe(self, client):
-        file_path = await save_audio_to_file(
-            client.scratch_buffer, client.get_file_name()
-        )
+        audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer)
 
         language = (
             None
@@ -129,11 +127,10 @@ async def transcribe(self, client):
             else language_codes.get(client.config["language"].lower())
         )
         segments, info = self.asr_pipeline.transcribe(
-            file_path, word_timestamps=True, language=language
+            audio_np, word_timestamps=True, language=language
         )
         segments = list(segments)  # The transcription will actually run here.
 
-        os.remove(file_path)
 
         flattened_words = [
             word for segment in segments for word in segment.words
diff --git a/src/asr/whisper_asr.py b/src/asr/whisper_asr.py
index b472e43..f4a01d4 100644
--- a/src/asr/whisper_asr.py
+++ b/src/asr/whisper_asr.py
@@ -3,7 +3,7 @@
 import torch
 from transformers import pipeline
 
-from src.audio_utils import save_audio_to_file
+from src.utils.audio_utils import convert_audio_bytes_to_numpy
 
 from .asr_interface import ASRInterface
 
@@ -19,19 +19,15 @@ def __init__(self, **kwargs):
         )
 
     async def transcribe(self, client):
-        file_path = await save_audio_to_file(
-            client.scratch_buffer, client.get_file_name()
-        )
+        audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer)
 
         if client.config["language"] is not None:
             to_return = self.asr_pipeline(
-                file_path,
+                audio_np,
                 generate_kwargs={"language": client.config["language"]},
             )["text"]
         else:
-            to_return = self.asr_pipeline(file_path)["text"]
-
-        os.remove(file_path)
+            to_return = self.asr_pipeline(audio_np)["text"]
 
         to_return = {
             "language": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER",
diff --git a/src/audio_utils.py b/src/audio_utils.py
deleted file mode 100644
index f9aa203..0000000
--- a/src/audio_utils.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import os
-import wave
-
-
-async def save_audio_to_file(
-    audio_data, file_name, audio_dir="audio_files", audio_format="wav"
-):
-    """
-    Saves the audio data to a file.
-
-    :param audio_data: The audio data to save.
-    :param file_name: The name of the file.
-    :param audio_dir: Directory where audio files will be saved.
-    :param audio_format: Format of the audio file.
-    :return: Path to the saved audio file.
-    """
-
-    os.makedirs(audio_dir, exist_ok=True)
-
-    file_path = os.path.join(audio_dir, file_name)
-
-    with wave.open(file_path, "wb") as wav_file:
-        wav_file.setnchannels(1)  # Assuming mono audio
-        wav_file.setsampwidth(2)
-        wav_file.setframerate(16000)
-        wav_file.writeframes(audio_data)
-
-    return file_path
diff --git a/src/buffering_strategy/buffering_strategies.py b/src/buffering_strategy/buffering_strategies.py
index bc1a0d4..b4d214d 100644
--- a/src/buffering_strategy/buffering_strategies.py
+++ b/src/buffering_strategy/buffering_strategies.py
@@ -3,6 +3,8 @@
 import os
 import time
 
+from src.callbacks import AudioProcessingCallbacks
+
 from .buffering_strategy_interface import BufferingStrategyInterface
 
 
@@ -49,15 +51,9 @@ def __init__(self, client, **kwargs):
         self.chunk_offset_seconds = kwargs.get("chunk_offset_seconds")
         self.chunk_offset_seconds = float(self.chunk_offset_seconds)
 
-        self.error_if_not_realtime = os.environ.get("ERROR_IF_NOT_REALTIME")
-        if not self.error_if_not_realtime:
-            self.error_if_not_realtime = kwargs.get(
-                "error_if_not_realtime", False
-            )
-
         self.processing_flag = False
 
-    def process_audio(self, websocket, vad_pipeline, asr_pipeline):
+    def process_audio(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline):
         """
         Process audio chunks by checking their length and scheduling
         asynchronous processing.
@@ -66,7 +62,7 @@ def process_audio(self, websocket, vad_pipeline, asr_pipeline):
         length and, if so, it schedules asynchronous processing of the audio.
 
         Args:
-            websocket: The WebSocket connection for sending transcriptions.
+            callbacks: Callbacks for audio processing events.
             vad_pipeline: The voice activity detection pipeline.
             asr_pipeline: The automatic speech recognition pipeline.
         """
@@ -77,30 +73,26 @@ def process_audio(self, websocket, vad_pipeline, asr_pipeline):
         )
         if len(self.client.buffer) > chunk_length_in_bytes:
             if self.processing_flag:
-                exit(
-                    "Error in realtime processing: tried processing a new "
-                    "chunk while the previous one was still being processed"
-                )
+                self.processing_flag.cancel()
 
             self.client.scratch_buffer += self.client.buffer
             self.client.buffer.clear()
-            self.processing_flag = True
             # Schedule the processing in a separate task
-            asyncio.create_task(
-                self.process_audio_async(websocket, vad_pipeline, asr_pipeline)
+            self.processing_flag = asyncio.create_task(
+                self.process_audio_async(callbacks, vad_pipeline, asr_pipeline)
             )
 
-    async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline):
+    async def process_audio_async(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline):
         """
         Asynchronously process audio for activity detection and transcription.
 
         This method performs heavy processing, including voice activity
-        detection and transcription of the audio data. It sends the
-        transcription results through the WebSocket connection.
+        detection and transcription of the audio data. If voice activity is
+        detected, it transcribes the audio and triggers the transcription
+        callback with the result.
 
         Args:
-            websocket (Websocket): The WebSocket connection for sending
-                transcriptions.
+            callbacks: Callbacks for audio processing events.
             vad_pipeline: The voice activity detection pipeline.
             asr_pipeline: The automatic speech recognition pipeline.
         """
@@ -123,7 +115,7 @@ async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline):
             end = time.time()
             transcription["processing_time"] = end - start
             json_transcription = json.dumps(transcription)
-            await websocket.send(json_transcription)
+            await callbacks.trigger_transcription_complete(json_transcription)
 
         self.client.scratch_buffer.clear()
         self.client.increment_file_counter()
diff --git a/src/callbacks.py b/src/callbacks.py
new file mode 100644
index 0000000..abcf9fd
--- /dev/null
+++ b/src/callbacks.py
@@ -0,0 +1,35 @@
+from typing import Callable, Optional, Awaitable
+
+class AudioProcessingCallbacks:
+    """
+    Callback interface for audio processing events.
+
+    This class defines callback functions that can be registered to handle
+    different events that occur during audio processing, such as:
+    - Transcription completion
+
+    Callbacks are defined as async functions to support asynchronous operations.
+    """
+
+    def __init__(
+        self,
+        on_transcription_complete: Optional[Callable[[str], Awaitable[None]]] = None,
+    ):
+        """
+        Initialize the callback interface.
+
+        Args:
+            on_transcription_complete: Called when transcription is complete with the
+                transcribed text.
+        """
+        self.on_transcription_complete = on_transcription_complete
+
+    async def trigger_transcription_complete(self, text: str):
+        """
+        Trigger the transcription complete callback.
+
+        Args:
+            text: The transcribed text.
+        """
+        if self.on_transcription_complete:
+            await self.on_transcription_complete(text)
\ No newline at end of file
diff --git a/src/client.py b/src/client.py
index 414a0fd..d674a20 100644
--- a/src/client.py
+++ b/src/client.py
@@ -3,6 +3,7 @@
 from src.buffering_strategy.buffering_strategy_factory import (
     BufferingStrategyFactory,
 )
+from src.callbacks import AudioProcessingCallbacks
 
 
 class Client:
@@ -33,7 +34,7 @@ def __init__(self, client_id, sampling_rate, samples_width):
             "language": None,
             "processing_strategy": "silence_at_end_of_chunk",
             "processing_args": {
-                "chunk_length_seconds": 5,
+                "chunk_length_seconds": 3,
                 "chunk_offset_seconds": 0.1,
             },
         }
@@ -72,7 +73,7 @@ def increment_file_counter(self):
     def get_file_name(self):
         return f"{self.client_id}_{self.file_counter}.wav"
 
-    def process_audio(self, websocket, vad_pipeline, asr_pipeline):
+    def process_audio(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline):
         self.buffering_strategy.process_audio(
-            websocket, vad_pipeline, asr_pipeline
+            callbacks, vad_pipeline, asr_pipeline
         )
diff --git a/src/main.py b/src/main.py
index fad9c2d..9161df5 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,7 +1,7 @@
 import argparse
 import asyncio
 import json
-import logging
+from src.utils.base_logger import logger, setLogger
 
 from src.asr.asr_factory import ASRFactory
 from src.vad.vad_factory import VADFactory
@@ -17,13 +17,13 @@ def parse_args():
     parser.add_argument(
         "--vad-type",
         type=str,
-        default="pyannote",
-        help="Type of VAD pipeline to use (e.g., 'pyannote')",
+        default="silero",
+        help="Type of VAD pipeline to use (e.g., 'silero')",
    )
     parser.add_argument(
         "--vad-args",
         type=str,
-        default='{"auth_token": "huggingface_token"}',
+        default="{}",
         help="JSON string of additional arguments for VAD pipeline",
     )
     parser.add_argument(
@@ -35,7 +35,7 @@ def parse_args():
     parser.add_argument(
         "--asr-args",
         type=str,
-        default='{"model_size": "large-v3"}',
+        default='{"model_size": "tiny"}',
         help="JSON string of additional arguments for ASR pipeline",
     )
     parser.add_argument(
@@ -73,14 +73,13 @@ def parse_args():
 def main():
     args = parse_args()
 
-    logging.basicConfig()
-    logging.getLogger().setLevel(args.log_level.upper())
+    setLogger(args.log_level.lower())
 
     try:
         vad_args = json.loads(args.vad_args)
         asr_args = json.loads(args.asr_args)
     except json.JSONDecodeError as e:
-        print(f"Error parsing JSON arguments: {e}")
+        logger.error(f"Error parsing JSON arguments: {e}")
         return
 
     vad_pipeline = VADFactory.create_vad_pipeline(args.vad_type, **vad_args)
diff --git a/src/server.py b/src/server.py
index 809c7ad..60006d8 100644
--- a/src/server.py
+++ b/src/server.py
@@ -1,11 +1,12 @@
 import json
-import logging
 import ssl
 import uuid
 
 import websockets
 
 from src.client import Client
+from src.utils.base_logger import logger
+from .callbacks import AudioProcessingCallbacks
 
 
 class Server:
@@ -49,6 +50,20 @@ def __init__(
         self.connected_clients = {}
 
     async def handle_audio(self, client, websocket):
+
+        async def on_transcription_complete(message):
+            # Forward the transcription JSON to the connected client
+            try:
+                await websocket.send(message)
+            except Exception as e:
+                logger.error(f"Error sending transcription: {e}")
+                # The connection may already be closed; keep serving the loop
+
+        # Initialize callbacks
+        callbacks = AudioProcessingCallbacks(
+            on_transcription_complete=on_transcription_complete,
+        )
+
         while True:
             message = await websocket.recv()
 
@@ -58,14 +73,14 @@ async def handle_audio(self, client, websocket):
                 config = json.loads(message)
                 if config.get("type") == "config":
                     client.update_config(config["data"])
-                    logging.debug(f"Updated config: {client.config}")
+                    logger.debug(f"Updated config: {client.config}")
                     continue
             else:
-                print(f"Unexpected message type from {client.client_id}")
+                logger.error(f"Unexpected message type from {client.client_id}")
 
             # this is synchronous, any async operation is in BufferingStrategy
             client.process_audio(
-                websocket, self.vad_pipeline, self.asr_pipeline
+                callbacks, self.vad_pipeline, self.asr_pipeline
             )
 
     async def handle_websocket(self, websocket):
@@ -73,12 +88,12 @@
         client = Client(client_id, self.sampling_rate, self.samples_width)
         self.connected_clients[client_id] = client
 
-        print(f"Client {client_id} connected")
+        logger.info(f"Client {client_id} connected")
 
         try:
             await self.handle_audio(client, websocket)
         except websockets.ConnectionClosed as e:
-            print(f"Connection with {client_id} closed: {e}")
+            logger.error(f"Connection with {client_id} closed: {e}")
         finally:
             del self.connected_clients[client_id]
 
@@ -94,7 +109,7 @@ def start(self):
                 certfile=self.certfile, keyfile=self.keyfile
             )
 
-            print(
+            logger.info(
                 f"WebSocket server ready to accept secure connections on "
                 f"{self.host}:{self.port}"
             )
@@ -106,7 +121,7 @@ def start(self):
                 self.handle_websocket, self.host, self.port, ssl=ssl_context
             )
         else:
-            print(
+            logger.error(
                 f"WebSocket server ready to accept secure connections on "
                 f"{self.host}:{self.port}"
             )
diff --git a/src/utils/audio_utils.py b/src/utils/audio_utils.py
new file mode 100644
index 0000000..8e8e98f
--- /dev/null
+++ b/src/utils/audio_utils.py
@@ -0,0 +1,18 @@
+from numpy import frombuffer, int16, float32
+
+def convert_audio_bytes_to_numpy(audio_bytes):
+    """
+    Convert raw audio bytes from scratch_buffer (bytearray)
+    directly to the numpy format required by Whisper and VAD
+
+    :param audio_bytes: Raw audio bytes as bytearray
+    :return: Numpy array with the audio data in the format expected by Whisper
+    """
+    # Convert bytearray directly to numpy array
+    # Assuming 16-bit PCM audio format
+    audio_as_np_int16 = frombuffer(audio_bytes, dtype=int16)
+
+    # Convert to float32 and normalize to [-1, 1] range as expected by Whisper
+    audio_as_np_float32 = audio_as_np_int16.astype(float32) / 32768.0
+
+    return audio_as_np_float32
\ No newline at end of file
diff --git a/src/utils/base_logger.py b/src/utils/base_logger.py
new file mode 100644
index 0000000..c8bf6b0
--- /dev/null
+++ b/src/utils/base_logger.py
@@ -0,0 +1,23 @@
+import logging
+
+logger = logging
+
+
+class BinaryLogFilter(logging.Filter):
+    def filter(self, record):
+        return '< BINARY' not in record.getMessage()
+
+
+def setLogger(level: str):
+    global logger
+
+    levels = {
+        'debug': logging.DEBUG,
+        'info': logging.INFO,
+        'warning': logging.WARNING,
+        'error': logging.ERROR,
+    }
+
+    logger.basicConfig(format='%(asctime)s - %(message)s', level=levels[level])
+    for handler in logger.getLogger().handlers:
+        handler.addFilter(BinaryLogFilter())
diff --git a/src/vad/pyannote_vad.py b/src/vad/pyannote_vad.py
index 4f65c67..52d0942 100644
--- a/src/vad/pyannote_vad.py
+++ b/src/vad/pyannote_vad.py
@@ -1,10 +1,10 @@
 import os
-from os import remove
-
+import io
+import soundfile as sf
 from pyannote.audio import Model
 from pyannote.audio.pipelines import VoiceActivityDetection
 
-from src.audio_utils import save_audio_to_file
+from src.utils.audio_utils import convert_audio_bytes_to_numpy
 
 from .vad_interface import VADInterface
 
@@ -51,11 +51,19 @@ def __init__(self, **kwargs):
         self.vad_pipeline.instantiate(pyannote_args)
 
     async def detect_activity(self, client):
-        audio_file_path = await save_audio_to_file(
-            client.scratch_buffer, client.get_file_name()
-        )
-        vad_results = self.vad_pipeline(audio_file_path)
-        remove(audio_file_path)
+        audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer)
+
+        # Create an in-memory audio file
+        audio_buffer = io.BytesIO()
+        # Save as WAV at 16kHz sample rate
+        sf.write(audio_buffer, audio_np, 16000, format='WAV')
+
+        # Reset buffer position for reading
+        audio_buffer.seek(0)
+
+        # Process with Pyannote directly from the in-memory buffer
+        vad_results = self.vad_pipeline(audio_buffer)
+
         vad_segments = []
         if len(vad_results) > 0:
             vad_segments = [
diff --git a/src/vad/silero_vad.py b/src/vad/silero_vad.py
new file mode 100644
index 0000000..edd4840
--- /dev/null
+++ b/src/vad/silero_vad.py
@@ -0,0 +1,33 @@
+from math import floor
+
+from silero_vad import load_silero_vad, get_speech_timestamps
+
+from src.utils.audio_utils import convert_audio_bytes_to_numpy
+from src.utils.base_logger import logger
+
+from .vad_interface import VADInterface
+
+class SileroVAD(VADInterface):
+    """
+    Silero-based implementation of the VADInterface that works with in-memory audio.
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Initializes the Silero VAD pipeline.
+        """
+        self.model = load_silero_vad()
+
+    async def detect_activity(self, client):
+        # Convert bytearray to numpy array
+        audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer)
+
+        speech_timestamps = get_speech_timestamps(audio_np, self.model)
+
+        # get_speech_timestamps returns sample offsets; convert them to seconds
+        new_timestamps = [
+            {'start': floor(timestamp['start'] / client.sampling_rate), 'end': floor(timestamp['end'] / client.sampling_rate)}
+            for timestamp in speech_timestamps
+        ]
+
+        return new_timestamps
\ No newline at end of file
diff --git a/src/vad/vad_factory.py b/src/vad/vad_factory.py
index e20d3f9..b7a8b97 100644
--- a/src/vad/vad_factory.py
+++ b/src/vad/vad_factory.py
@@ -1,4 +1,5 @@
 from .pyannote_vad import PyannoteVAD
+from .silero_vad import SileroVAD
 
 
 class VADFactory:
@@ -20,5 +21,7 @@ def create_vad_pipeline(type, **kwargs):
         """
         if type == "pyannote":
             return PyannoteVAD(**kwargs)
+        elif type == "silero":
+            return SileroVAD(**kwargs)
         else:
             raise ValueError(f"Unknown VAD pipeline type: {type}")
diff --git a/test/vad/test_pyannote_vad.py b/test/vad/test_pyannote_vad.py
index a3b95d6..af711f3 100644
--- a/test/vad/test_pyannote_vad.py
+++ b/test/vad/test_pyannote_vad.py
@@ -41,7 +41,7 @@ def test_detect_activity(self):
                 self.client.scratch_buffer = bytearray(audio_segment.raw_data)
 
                 vad_results = asyncio.run(
-                    self.vad.detect_activity(self.client)
+                    self.vad.detect_activity(self.client.scratch_buffer)
                 )
 
                 # Adjust VAD-detected times by adding the start time of the
diff --git a/test/vad/test_silero_vad.py b/test/vad/test_silero_vad.py
new file mode 100644
index 0000000..c87205f
--- /dev/null
+++ b/test/vad/test_silero_vad.py
@@ -0,0 +1,88 @@
+# test/vad/test_silero_vad.py
+
+import asyncio
+import json
+import os
+import unittest
+
+from pydub import AudioSegment
+
+from src.client import Client
+from src.vad.silero_vad import SileroVAD
+
+
+class TestSileroVAD(unittest.TestCase):
+    def setUp(self):
+        self.vad = SileroVAD()
+        self.annotations_path = os.path.join(
+            os.path.dirname(__file__), "../audio_files/annotations.json"
+        )
+        self.client = Client("test_client", 16000, 2)  # Example client
+
+    def load_annotations(self):
+        with open(self.annotations_path, "r") as file:
+            return json.load(file)
+
+    def test_detect_activity(self):
+        annotations = self.load_annotations()
+
+        for audio_file, data in annotations.items():
+            audio_file_path = os.path.join(
+                os.path.dirname(__file__), f"../audio_files/{audio_file}"
+            )
+
+            for annotated_segment in data["segments"]:
+                print(annotated_segment['transcription'])
+                # Load the specific audio segment for VAD
+                audio_segment = self.get_audio_segment(
+                    audio_file_path,
+                    annotated_segment["start"],
+                    annotated_segment["end"],
+                )
+                self.client.scratch_buffer = bytearray(audio_segment.raw_data)
+
+                vad_results = asyncio.run(
+                    self.vad.detect_activity(self.client)
+                )
+
+                # Adjust VAD-detected times by adding the start time of the
+                # annotated segment
+                adjusted_vad_results = [
+                    {
+                        "start": segment["start"] + annotated_segment["start"],
+                        "end": segment["end"] + annotated_segment["start"],
+                    }
+                    for segment in vad_results
+                ]
+
+                detected_segments = [
+                    segment
+                    for segment in adjusted_vad_results
+                    if segment["start"] <= annotated_segment["start"] + 1
+                    and segment["end"] <= annotated_segment["end"] + 4.2
+                ]
+
+                # Print formatted information about the test
+                print(
+                    f"\nTesting segment from '{audio_file}': Annotated Start: "
+                    f"{annotated_segment['start']}, Annotated End: "
+                    f"{annotated_segment['end']}"
+                )
+                print(f"VAD segments: {adjusted_vad_results}")
+                print(f"Overlapping, Detected segments: {detected_segments}")
+
+                # Assert that at least one detected segment meets the condition
+                self.assertTrue(
+                    len(detected_segments) > 0,
+                    "No detected segment matches the annotated segment",
+                )
+
+    def get_audio_segment(self, file_path, start, end):
+        with open(file_path, "rb") as file:
+            audio = AudioSegment.from_file(file, format="wav")
+        # pydub works in milliseconds
+        return audio[start * 1000 : end * 1000]  # noqa: E203
+
+
+if __name__ == "__main__":
+    unittest.main()
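
Usage sketch (illustrative only, not part of the patch): `AudioProcessingCallbacks` from `src/callbacks.py` can be driven without a running WebSocket server, which is exactly how the buffering strategy exercises it once a chunk has been transcribed. The `demo` coroutine name below is hypothetical.

```python
import asyncio

from src.callbacks import AudioProcessingCallbacks


async def demo():
    received = []

    # Any awaitable that accepts the transcription JSON string will do here.
    async def on_transcription_complete(text: str):
        received.append(text)

    callbacks = AudioProcessingCallbacks(
        on_transcription_complete=on_transcription_complete,
    )

    # The buffering strategy awaits this once a chunk has been transcribed.
    await callbacks.trigger_transcription_complete('{"text": "hello"}')
    print(received)  # ['{"text": "hello"}']


asyncio.run(demo())
```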
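
Similarly, a small sanity check for `convert_audio_bytes_to_numpy` from `src/utils/audio_utils.py`, which replaces the temporary WAV files (synthetic input, assuming the 16 kHz, 16-bit mono PCM the client sends):

```python
import numpy as np

from src.utils.audio_utils import convert_audio_bytes_to_numpy

# One second of 16 kHz, 16-bit mono PCM: silence plus a single full-scale sample.
pcm = np.zeros(16000, dtype=np.int16)
pcm[-1] = 32767

audio = convert_audio_bytes_to_numpy(bytearray(pcm.tobytes()))
print(audio.dtype, audio.shape)  # float32 (16000,)
print(float(audio.max()))        # 32767 / 32768, i.e. just below 1.0
```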