15 changes: 12 additions & 3 deletions README.md
@@ -4,7 +4,7 @@

VoiceStreamAI is a Python 3-based server and JavaScript client solution that
enables near-realtime audio streaming and transcription using WebSocket. The
system employs Huggingface's Voice Activity Detection (VAD) and OpenAI's Whisper
system employs Silero Voice Activity Detection (VAD) and OpenAI's Whisper
model ([faster-whisper](https://github.com/SYSTRAN/faster-whisper) being the
default) for accurate speech recognition and processing.

@@ -78,6 +78,8 @@ following packages:
5. `asyncio`
6. `sentence-transformers`
7. `faster-whisper`
8. `silero-vad`
9. `soundfile`

Install these packages using pip:

@@ -96,7 +98,7 @@ allowing you to specify components, host, and port settings according to your
needs.

- `--vad-type`: Specifies the type of Voice Activity Detection (VAD) pipeline to
use (default: `pyannote`) .
use (default: `silero`). The default Silero VAD doesn't require an authentication token.
- `--vad-args`: A JSON string containing additional arguments for the VAD
pipeline. (required for `pyannote`: `'{"auth_token": "VAD_AUTH_HERE"}'`)
- `--asr-type`: Specifies the type of Automatic Speech Recognition (ASR)
@@ -113,12 +115,19 @@ needs.

For running the server with the standard configuration:

```bash
python3 -m src.main
```

Since the default VAD is Silero, which doesn't require an authentication token,
the above command is sufficient. To use the pyannote VAD instead:

1. Obtain the key to the Voice-Activity-Detection model
at [https://huggingface.co/pyannote/segmentation](https://huggingface.co/pyannote/segmentation)
2. Run the server using Python 3.x, passing the VAD key on the command line:

```bash
python3 -m src.main --vad-args '{"auth_token": "vad token here"}'
python3 -m src.main --vad-type 'pyannote' --vad-args '{"auth_token": "vad token here"}'
```
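
For example, to keep the default Silero VAD but load a larger Whisper model (the
`large-v3` value is only an illustration, borrowed from the previous default):

```bash
python3 -m src.main --asr-args '{"model_size": "large-v3"}'
```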

You can see all the command line options with the command:
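
```bash
python3 -m src.main --help
```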
2 changes: 2 additions & 0 deletions requirements.txt
@@ -7,3 +7,5 @@ transformers==4.40.2
faster-whisper==1.0.2
torchvision~=0.18.0
torch~=2.3.0
soundfile==0.12.1
silero-vad==5.1.2
19 changes: 8 additions & 11 deletions src/asr/faster_whisper_asr.py
@@ -1,8 +1,8 @@
import os

import torch
from faster_whisper import WhisperModel

from src.audio_utils import save_audio_to_file
from src.utils.audio_utils import convert_audio_bytes_to_numpy

from .asr_interface import ASRInterface

@@ -109,31 +109,28 @@
"cantonese": "yue",
}


class FasterWhisperASR(ASRInterface):
def __init__(self, **kwargs):
model_size = kwargs.get("model_size", "large-v3")
# Run on GPU with FP16
model_size = kwargs.get("model_size", "tiny")
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if torch.cuda.is_available() else "float32"
self.asr_pipeline = WhisperModel(
model_size, device="cuda", compute_type="float16"
model_size, device=device, compute_type=compute_type
)

async def transcribe(self, client):
file_path = await save_audio_to_file(
client.scratch_buffer, client.get_file_name()
)
audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer)

language = (
None
if client.config["language"] is None
else language_codes.get(client.config["language"].lower())
)
segments, info = self.asr_pipeline.transcribe(
file_path, word_timestamps=True, language=language
audio_np, word_timestamps=True, language=language
)

segments = list(segments) # The transcription will actually run here.
os.remove(file_path)

flattened_words = [
word for segment in segments for word in segment.words
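
The `convert_audio_bytes_to_numpy` helper imported above lives in the new
`src/utils/audio_utils.py`, which this diff view does not show. A minimal sketch of
what it plausibly does, assuming the client streams mono 16-bit little-endian PCM
(the function name and module path come from the diff; the body is an assumption):

```python
import numpy as np


def convert_audio_bytes_to_numpy(audio_bytes: bytes) -> np.ndarray:
    """Convert raw 16-bit PCM bytes to a float32 waveform in [-1.0, 1.0].

    Sketch only: assumes mono, 16-bit little-endian samples, which both
    faster-whisper and the transformers pipeline accept as a numpy array.
    """
    audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
    return audio_int16.astype(np.float32) / 32768.0
```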
12 changes: 4 additions & 8 deletions src/asr/whisper_asr.py
@@ -3,7 +3,7 @@
import torch
from transformers import pipeline

from src.audio_utils import save_audio_to_file
from src.utils.audio_utils import convert_audio_bytes_to_numpy

from .asr_interface import ASRInterface

@@ -19,19 +19,15 @@ def __init__(self, **kwargs):
)

async def transcribe(self, client):
file_path = await save_audio_to_file(
client.scratch_buffer, client.get_file_name()
)
audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer)

if client.config["language"] is not None:
to_return = self.asr_pipeline(
file_path,
audio_np,
generate_kwargs={"language": client.config["language"]},
)["text"]
else:
to_return = self.asr_pipeline(file_path)["text"]

os.remove(file_path)
to_return = self.asr_pipeline(audio_np)["text"]

to_return = {
"language": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER",
28 changes: 0 additions & 28 deletions src/audio_utils.py

This file was deleted.

34 changes: 13 additions & 21 deletions src/buffering_strategy/buffering_strategies.py
@@ -3,6 +3,8 @@
import os
import time

from src.callbacks import AudioProcessingCallbacks

from .buffering_strategy_interface import BufferingStrategyInterface


@@ -49,15 +51,9 @@ def __init__(self, client, **kwargs):
self.chunk_offset_seconds = kwargs.get("chunk_offset_seconds")
self.chunk_offset_seconds = float(self.chunk_offset_seconds)

self.error_if_not_realtime = os.environ.get("ERROR_IF_NOT_REALTIME")
if not self.error_if_not_realtime:
self.error_if_not_realtime = kwargs.get(
"error_if_not_realtime", False
)

self.processing_flag = False

def process_audio(self, websocket, vad_pipeline, asr_pipeline):
def process_audio(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline):
"""
Process audio chunks by checking their length and scheduling
asynchronous processing.
@@ -66,7 +62,7 @@ def process_audio(self, websocket, vad_pipeline, asr_pipeline):
length and, if so, it schedules asynchronous processing of the audio.

Args:
websocket: The WebSocket connection for sending transcriptions.
callbacks: Callbacks for audio processing events.
vad_pipeline: The voice activity detection pipeline.
asr_pipeline: The automatic speech recognition pipeline.
"""
@@ -77,30 +73,26 @@ def process_audio(self, websocket, vad_pipeline, asr_pipeline):
)
if len(self.client.buffer) > chunk_length_in_bytes:
if self.processing_flag:
exit(
"Error in realtime processing: tried processing a new "
"chunk while the previous one was still being processed"
)
self.processing_flag.cancel()

self.client.scratch_buffer += self.client.buffer
self.client.buffer.clear()
self.processing_flag = True
# Schedule the processing in a separate task
asyncio.create_task(
self.process_audio_async(websocket, vad_pipeline, asr_pipeline)
self.processing_flag = asyncio.create_task(
self.process_audio_async(callbacks, vad_pipeline, asr_pipeline)
)

async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline):
async def process_audio_async(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline):
"""
Asynchronously process audio for activity detection and transcription.

This method performs heavy processing, including voice activity
detection and transcription of the audio data. It sends the
transcription results through the WebSocket connection.
detection and transcription of the audio data. If speech is detected, it
transcribes the audio and triggers the transcription callbacks.

Args:
websocket (Websocket): The WebSocket connection for sending
transcriptions.
callbacks: Callbacks for audio processing events.
vad_pipeline: The voice activity detection pipeline.
asr_pipeline: The automatic speech recognition pipeline.
"""
@@ -123,7 +115,7 @@ async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline):
end = time.time()
transcription["processing_time"] = end - start
json_transcription = json.dumps(transcription)
await websocket.send(json_transcription)
await callbacks.trigger_transcription_complete(json_transcription)
self.client.scratch_buffer.clear()
self.client.increment_file_counter()

35 changes: 35 additions & 0 deletions src/callbacks.py
@@ -0,0 +1,35 @@
from typing import Awaitable, Callable, Optional

class AudioProcessingCallbacks:
"""
Callback interface for audio processing events.

This class defines callback functions that can be registered to handle
different events that occur during audio processing, such as:
- Transcription completion

Callbacks are defined as async functions to support asynchronous operations.
"""

def __init__(
self,
on_transcription_complete: Optional[Callable[[str], Awaitable[None]]] = None,
):
"""
Initialize the callback interface.

Args:
on_transcription_complete: Called when transcription is complete with the
transcribed text.
"""
self.on_transcription_complete = on_transcription_complete

async def trigger_transcription_complete(self, text: str):
"""
Trigger the transcription complete callback.

Args:
text: The transcribed text.
"""
if self.on_transcription_complete:
await self.on_transcription_complete(text)
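
A minimal sketch of how a server might wire these callbacks back to a WebSocket
connection (only `AudioProcessingCallbacks` and its constructor argument come from
the diff; the `make_callbacks` helper and the `websocket.send` wiring are
assumptions):

```python
from src.callbacks import AudioProcessingCallbacks


def make_callbacks(websocket) -> AudioProcessingCallbacks:
    """Build callbacks that forward each transcription over a WebSocket.

    Sketch only: the surrounding server code is assumed, not shown in
    this diff.
    """

    async def on_transcription_complete(text: str) -> None:
        # `text` is the JSON-encoded transcription produced by the
        # buffering strategy.
        await websocket.send(text)

    return AudioProcessingCallbacks(
        on_transcription_complete=on_transcription_complete
    )
```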
7 changes: 4 additions & 3 deletions src/client.py
@@ -3,6 +3,7 @@
from src.buffering_strategy.buffering_strategy_factory import (
BufferingStrategyFactory,
)
from src.callbacks import AudioProcessingCallbacks


class Client:
@@ -33,7 +34,7 @@ def __init__(self, client_id, sampling_rate, samples_width):
"language": None,
"processing_strategy": "silence_at_end_of_chunk",
"processing_args": {
"chunk_length_seconds": 5,
"chunk_length_seconds": 3,
"chunk_offset_seconds": 0.1,
},
}
@@ -72,7 +73,7 @@ def increment_file_counter(self):
def get_file_name(self):
return f"{self.client_id}_{self.file_counter}.wav"

def process_audio(self, websocket, vad_pipeline, asr_pipeline):
def process_audio(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline):
self.buffering_strategy.process_audio(
websocket, vad_pipeline, asr_pipeline
callbacks, vad_pipeline, asr_pipeline
)
15 changes: 7 additions & 8 deletions src/main.py
@@ -1,7 +1,7 @@
import argparse
import asyncio
import json
import logging
from src.utils.base_logger import logger, setLogger

from src.asr.asr_factory import ASRFactory
from src.vad.vad_factory import VADFactory
@@ -17,13 +17,13 @@ def parse_args():
parser.add_argument(
"--vad-type",
type=str,
default="pyannote",
help="Type of VAD pipeline to use (e.g., 'pyannote')",
default="silero",
help="Type of VAD pipeline to use (e.g., 'silero')",
)
parser.add_argument(
"--vad-args",
type=str,
default='{"auth_token": "huggingface_token"}',
default=None,
help="JSON string of additional arguments for VAD pipeline",
)
parser.add_argument(
@@ -35,7 +35,7 @@ def parse_args():
parser.add_argument(
"--asr-args",
type=str,
default='{"model_size": "large-v3"}',
default='{"model_size": "tiny"}',
help="JSON string of additional arguments for ASR pipeline",
)
parser.add_argument(
@@ -73,14 +73,13 @@ def parse_args():
def main():
args = parse_args()

logging.basicConfig()
logging.getLogger().setLevel(args.log_level.upper())
setLogger(args.log_level)

try:
vad_args = json.loads(args.vad_args) if args.vad_args else {}
asr_args = json.loads(args.asr_args)
except json.JSONDecodeError as e:
print(f"Error parsing JSON arguments: {e}")
logger.error(f"Error parsing JSON arguments: {e}")
return

vad_pipeline = VADFactory.create_vad_pipeline(args.vad_type, **vad_args)
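
`src/utils/base_logger.py`, imported at the top of this file, is not included in
the diff. A plausible minimal implementation, assuming `setLogger` accepts a level
name such as `"debug"` (inferred from the call site; the rest is an assumption):

```python
# Hypothetical src/utils/base_logger.py -- not part of this PR's visible diff.
import logging

logger = logging.getLogger("voicestreamai")


def setLogger(level_name: str) -> None:
    """Configure the shared logger with a level name like "debug" or "error"."""
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s"
    )
    logger.setLevel(getattr(logging, level_name.upper(), logging.INFO))
```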