diff --git a/Dockerfile b/Dockerfile
index 1a0a11d..a5a4c57 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,27 +1,29 @@
# Use an NVIDIA CUDA base image with Python 3
-FROM nvidia/cuda:11.6.2-base-ubuntu20.04
+FROM nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
# Set the working directory in the container
WORKDIR /usr/src/app
-# Copy the requirements.txt file first to leverage Docker cache
-COPY requirements.txt ./
-
# Avoid interactive prompts from apt-get
ENV DEBIAN_FRONTEND=noninteractive
-# Install any needed packages specified in requirements.txt
-RUN apt-get update && apt-get install -y python3-pip libsndfile1 ffmpeg && \
- pip3 install --no-cache-dir -r requirements.txt
+# Install any needed packages
+RUN apt-get update && \
+ apt-get install -y python3-pip libsndfile1 ffmpeg && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
-# Reset the frontend (not necessary in newer Docker versions)
-ENV DEBIAN_FRONTEND=newt
+# Copy the requirements.txt file
+COPY requirements.txt ./
+
+# Install any needed packages specified in requirements.txt
+RUN pip3 install --no-cache-dir -r requirements.txt
# Copy the rest of your application's code
COPY . .
-# Make port 8765 available to the world outside this container
-EXPOSE 8765
+# Make port 80 available to the world outside this container
+EXPOSE 80
# Define environment variable
ENV NAME VoiceStreamAI
@@ -30,5 +32,4 @@ ENV NAME VoiceStreamAI
ENTRYPOINT ["python3", "-m", "src.main"]
# Provide a default command (can be overridden at runtime)
-CMD ["--host", "0.0.0.0", "--port", "8765"]
-
+CMD ["--host", "0.0.0.0", "--port", "80", "--static-path", "./src/static"]
diff --git a/LICENSE.txt b/LICENSE.txt
index 84609a6..c13f991 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2024 Alessandro Saccoia
+Copyright (c) 2024
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index b44062c..90a044f 100644
--- a/README.md
+++ b/README.md
@@ -13,18 +13,15 @@ VoiceStreamAI is a Python 3 -based server and JavaScript client solution that en
- Customizable audio chunk processing strategies.
- Support for multilingual transcription.
-## Demo Video
+## Demo
+[View Demo Video](https://raw.githubusercontent.com/TyreseDev/VoiceStreamAI/main/img/voicestreamai_test.mp4)
-https://github.com/alesaccoia/VoiceStreamAI/assets/1385023/9b5f2602-fe0b-4c9d-af9e-4662e42e23df
-
-## Demo Client
-
-
+
## Running with Docker
-This will not guide you in detail on how to use CUDA in docker, see for example [here](https://medium.com/@kevinsjy997/configure-docker-to-use-local-gpu-for-training-ml-models-70980168ec9b).
+This guide does not cover in detail how to set up CUDA in Docker; see for example [here](https://medium.com/@kevinsjy997/configure-docker-to-use-local-gpu-for-training-ml-models-70980168ec9b).
Still, these are the commands for Linux:
@@ -52,13 +49,13 @@ After getting your VAD token (see next sections) run:
sudo docker volume create huggingface_models
-sudo docker run --gpus all -p 8765:8765 -v huggingface_models:/root/.cache/huggingface -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
+sudo docker run --gpus all -p 80:80 -v huggingface_models:/root/.cache/huggingface -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
```
The "volume" stuff will allow you not to re-download the huggingface models each time you re-run the container. If you don't need this, just use:
```bash
-sudo docker run --gpus all -p 8765:8765 -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
+sudo docker run --gpus all -p 80:80 -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
```
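+
+Since the Dockerfile only supplies default arguments through `CMD`, they can be overridden when starting the container. A sketch (the `medium` model size is just an illustrative value forwarded via `--asr-args`):
+
+```bash
+sudo docker run --gpus all -p 80:80 -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai \
+  --host 0.0.0.0 --port 80 --static-path ./src/static --asr-args '{"model_size": "medium"}'
+```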
## Normal, Manual Installation
@@ -92,7 +89,7 @@ The VoiceStreamAI server can be customized through command line arguments, allow
- `--asr-type`: Specifies the type of Automatic Speech Recognition (ASR) pipeline to use (default: `faster_whisper`).
- `--asr-args`: A JSON string containing additional arguments for the ASR pipeline (one can for example change `model_name` for whisper)
- `--host`: Sets the host address for the WebSocket server (default: `127.0.0.1`).
-- `--port`: Sets the port on which the server listens (default: `8765`).
+- `--port`: Sets the port on which the server listens (default: `80`).
For running the server with the standard configuration:
@@ -103,7 +100,7 @@ For running the server with the standard configuration:
python3 -m src.main --vad-args '{"auth_token": "vad token here"}'
```
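+
+For example, to switch to the Hugging Face `whisper` pipeline with a different `model_name` (a sketch; any model accepted by the `transformers` ASR pipeline should work the same way):
+
+```bash
+python3 -m src.main --vad-args '{"auth_token": "vad token here"}' \
+  --asr-type whisper --asr-args '{"model_name": "openai/whisper-small"}'
+```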
-You can see all the command line options with the command:
+You can see all the command line options with the command:
```bash
python3 -m src.main --help
@@ -112,13 +109,12 @@ python3 -m src.main --help
## Client Usage
1. Open the `client/VoiceStreamAI_Client.html` file in a web browser.
-2. Enter the WebSocket address (default is `ws://localhost:8765`).
+2. Enter the WebSocket address (default is `ws://localhost/ws`).
3. Configure the audio chunk length and offset. See below.
4. Select the language for transcription.
5. Click 'Connect' to establish a WebSocket connection.
6. Use 'Start Streaming' and 'Stop Streaming' to control audio capture.
-
## Technology Overview
- **Python Server**: Manages WebSocket connections, processes audio streams, and handles voice activity detection and transcription.
@@ -207,10 +203,8 @@ Please make sure that the end variables are in place for example for the VAD aut
### Dependence on Audio Files
-Currently, VoiceStreamAI processes audio by saving chunks to files and then running these files through the models.
+Currently, VoiceStreamAI processes audio by saving chunks to files and then running these files through the models.
## Contributors
-- Alessandro Saccoia - [alessandro.saccoia@gmail.com](mailto:alessandro.saccoia@gmail.com)
-
This project is open for contributions. Feel free to fork the repository and submit pull requests.
diff --git a/client/VoiceStreamAI_Client.html b/client/VoiceStreamAI_Client.html
deleted file mode 100644
index a3c5cb3..0000000
--- a/client/VoiceStreamAI_Client.html
+++ /dev/null
@@ -1,133 +0,0 @@
-
-
-
-
-
- Audio Stream to WebSocket Server
-
-
-
-
-
Transcribe a Web Audio Stream with Huggingface VAD + Whisper
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
WebSocket: Not Connected
-
Detected Language: Undefined
-
Last Processing Time: Undefined
-
-
diff --git a/client/index.html b/client/index.html
new file mode 100644
index 0000000..c23be91
--- /dev/null
+++ b/client/index.html
@@ -0,0 +1,147 @@
+
+
+
+
+
+ Audio Stream to WebSocket Server
+
+
+
+
+
Transcribe a Web Audio Stream with PyAnnote + Whisper
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
WebSocket: Not Connected
+
Detected Language: Undefined
+
Last Processing Time: Undefined
+
+
diff --git a/client/utils.js b/client/utils.js
index abdaf13..48756f1 100644
--- a/client/utils.js
+++ b/client/utils.js
@@ -1,8 +1,6 @@
/**
* VoiceStreamAI Client - WebSocket-based real-time transcription
*
- * Contributor:
- * - Alessandro Saccoia - alessandro.saccoia@gmail.com
*/
let websocket;
@@ -10,207 +8,231 @@ let context;
let processor;
let globalStream;
-const websocket_uri = 'ws://localhost:8765';
const bufferSize = 4096;
let isRecording = false;
function initWebSocket() {
- const websocketAddress = document.getElementById('websocketAddress').value;
- chunk_length_seconds = document.getElementById('chunk_length_seconds').value;
- chunk_offset_seconds = document.getElementById('chunk_offset_seconds').value;
- const selectedLanguage = document.getElementById('languageSelect').value;
- language = selectedLanguage !== 'multilingual' ? selectedLanguage : null;
-
- if (!websocketAddress) {
- console.log("WebSocket address is required.");
- return;
- }
-
- websocket = new WebSocket(websocketAddress);
- websocket.onopen = () => {
- console.log("WebSocket connection established");
- document.getElementById("webSocketStatus").textContent = 'Connected';
- document.getElementById('startButton').disabled = false;
- };
- websocket.onclose = event => {
- console.log("WebSocket connection closed", event);
- document.getElementById("webSocketStatus").textContent = 'Not Connected';
- document.getElementById('startButton').disabled = true;
- document.getElementById('stopButton').disabled = true;
- };
- websocket.onmessage = event => {
- console.log("Message from server:", event.data);
- const transcript_data = JSON.parse(event.data);
- updateTranscription(transcript_data);
- };
+ const websocketAddress = document.getElementById("websocketAddress").value;
+ chunk_length_seconds = document.getElementById("chunk_length_seconds").value;
+ chunk_offset_seconds = document.getElementById("chunk_offset_seconds").value;
+ const selectedLanguage = document.getElementById("languageSelect").value;
+ language = selectedLanguage !== "multilingual" ? selectedLanguage : null;
+
+ if (!websocketAddress) {
+ console.log("WebSocket address is required.");
+ return;
+ }
+
+ websocket = new WebSocket(websocketAddress);
+ websocket.onopen = () => {
+ console.log("WebSocket connection established");
+ document.getElementById("webSocketStatus").textContent = "Connected";
+ document.getElementById("startButton").disabled = false;
+ };
+ websocket.onclose = (event) => {
+ console.log("WebSocket connection closed", event);
+ document.getElementById("webSocketStatus").textContent = "Not Connected";
+ document.getElementById("startButton").disabled = true;
+ document.getElementById("stopButton").disabled = true;
+ };
+ websocket.onmessage = (event) => {
+ console.log("Message from server:", event.data);
+ const transcript_data = JSON.parse(event.data);
+ updateTranscription(transcript_data);
+ };
}
function updateTranscription(transcript_data) {
- const transcriptionDiv = document.getElementById('transcription');
- const languageDiv = document.getElementById('detected_language');
-
- if (transcript_data['words'] && transcript_data['words'].length > 0) {
- // Append words with color based on their probability
- transcript_data['words'].forEach(wordData => {
- const span = document.createElement('span');
- const probability = wordData['probability'];
- span.textContent = wordData['word'] + ' ';
-
- // Set the color based on the probability
- if (probability > 0.9) {
- span.style.color = 'green';
- } else if (probability > 0.6) {
- span.style.color = 'orange';
- } else {
- span.style.color = 'red';
- }
-
- transcriptionDiv.appendChild(span);
- });
-
- // Add a new line at the end
- transcriptionDiv.appendChild(document.createElement('br'));
- } else {
- // Fallback to plain text
- transcriptionDiv.textContent += transcript_data['text'] + '\n';
- }
-
- // Update the language information
- if (transcript_data['language'] && transcript_data['language_probability']) {
- languageDiv.textContent = transcript_data['language'] + ' (' + transcript_data['language_probability'].toFixed(2) + ')';
- }
-
- // Update the processing time, if available
- const processingTimeDiv = document.getElementById('processing_time');
- if (transcript_data['processing_time']) {
- processingTimeDiv.textContent = 'Processing time: ' + transcript_data['processing_time'].toFixed(2) + ' seconds';
- }
+ const transcriptionDiv = document.getElementById("transcription");
+ const languageDiv = document.getElementById("detected_language");
+
+ if (transcript_data.words && transcript_data.words.length > 0) {
+ // Append words with color based on their probability
+ // biome-ignore lint/complexity/noForEach:
+ transcript_data.words.forEach((wordData) => {
+ const span = document.createElement("span");
+ const probability = wordData.probability;
+ span.textContent = `${wordData.word} `;
+
+ // Set the color based on the probability
+ if (probability > 0.9) {
+ span.style.color = "green";
+ } else if (probability > 0.6) {
+ span.style.color = "orange";
+ } else {
+ span.style.color = "red";
+ }
+
+ transcriptionDiv.appendChild(span);
+ });
+
+ // Add a new line at the end
+ transcriptionDiv.appendChild(document.createElement("br"));
+ } else {
+ // Fallback to plain text
+ transcriptionDiv.textContent += `${transcript_data.text}\n`;
+ }
+
+ // Update the language information
+ if (transcript_data.language && transcript_data.language_probability) {
+ languageDiv.textContent = `${
+ transcript_data.language
+ } (${transcript_data.language_probability.toFixed(2)})`;
+ }
+
+ // Update the processing time, if available
+ const processingTimeDiv = document.getElementById("processing_time");
+ if (transcript_data.processing_time) {
+ processingTimeDiv.textContent = `Processing time: ${transcript_data.processing_time.toFixed(
+ 2,
+ )} seconds`;
+ }
}
-
function startRecording() {
- if (isRecording) return;
- isRecording = true;
-
- const AudioContext = window.AudioContext || window.webkitAudioContext;
- context = new AudioContext();
-
- navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => {
- globalStream = stream;
- const input = context.createMediaStreamSource(stream);
- processor = context.createScriptProcessor(bufferSize, 1, 1);
- processor.onaudioprocess = e => processAudio(e);
- input.connect(processor);
- processor.connect(context.destination);
-
- sendAudioConfig();
- }).catch(error => console.error('Error accessing microphone', error));
-
- // Disable start button and enable stop button
- document.getElementById('startButton').disabled = true;
- document.getElementById('stopButton').disabled = false;
+ if (isRecording) return;
+ isRecording = true;
+
+ const AudioContext = window.AudioContext || window.webkitAudioContext;
+ context = new AudioContext();
+
+ navigator.mediaDevices
+ .getUserMedia({ audio: true })
+ .then((stream) => {
+ globalStream = stream;
+ const input = context.createMediaStreamSource(stream);
+ processor = context.createScriptProcessor(bufferSize, 1, 1);
+ processor.onaudioprocess = (e) => processAudio(e);
+ input.connect(processor);
+ processor.connect(context.destination);
+
+ sendAudioConfig();
+ })
+ .catch((error) => console.error("Error accessing microphone", error));
+
+ // Disable start button and enable stop button
+ document.getElementById("startButton").disabled = true;
+ document.getElementById("stopButton").disabled = false;
}
function stopRecording() {
- if (!isRecording) return;
- isRecording = false;
-
- if (globalStream) {
- globalStream.getTracks().forEach(track => track.stop());
- }
- if (processor) {
- processor.disconnect();
- processor = null;
- }
- if (context) {
- context.close().then(() => context = null);
- }
- document.getElementById('startButton').disabled = false;
- document.getElementById('stopButton').disabled = true;
+ if (!isRecording) return;
+ isRecording = false;
+
+ if (globalStream) {
+ // biome-ignore lint/complexity/noForEach:
+ globalStream.getTracks().forEach((track) => track.stop());
+ }
+ if (processor) {
+ processor.disconnect();
+ processor = null;
+ }
+ if (context) {
+ // biome-ignore lint/suspicious/noAssignInExpressions:
+ context.close().then(() => (context = null));
+ }
+ document.getElementById("startButton").disabled = false;
+ document.getElementById("stopButton").disabled = true;
}
function sendAudioConfig() {
- let selectedStrategy = document.getElementById('bufferingStrategySelect').value;
- let processingArgs = {};
-
- if (selectedStrategy === 'silence_at_end_of_chunk') {
- processingArgs = {
- chunk_length_seconds: parseFloat(document.getElementById('chunk_length_seconds').value),
- chunk_offset_seconds: parseFloat(document.getElementById('chunk_offset_seconds').value)
- };
- }
-
- const audioConfig = {
- type: 'config',
- data: {
- sampleRate: context.sampleRate,
- bufferSize: bufferSize,
- channels: 1, // Assuming mono channel
- language: language,
- processing_strategy: selectedStrategy,
- processing_args: processingArgs
- }
+ const selectedStrategy = document.getElementById(
+ "bufferingStrategySelect",
+ ).value;
+ let processingArgs = {};
+
+ if (selectedStrategy === "silence_at_end_of_chunk") {
+ processingArgs = {
+ chunk_length_seconds: Number.parseFloat(
+ document.getElementById("chunk_length_seconds").value,
+ ),
+ chunk_offset_seconds: Number.parseFloat(
+ document.getElementById("chunk_offset_seconds").value,
+ ),
};
-
- websocket.send(JSON.stringify(audioConfig));
+ }
+
+ const audioConfig = {
+ type: "config",
+ data: {
+ sampleRate: context.sampleRate,
+ bufferSize: bufferSize,
+ channels: 1, // Assuming mono channel
+ language: language,
+ processing_strategy: selectedStrategy,
+ processing_args: processingArgs,
+ },
+ };
+
+ websocket.send(JSON.stringify(audioConfig));
}
function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) {
- if (inputSampleRate === outputSampleRate) {
- return buffer;
+ if (inputSampleRate === outputSampleRate) {
+ return buffer;
+ }
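+  // Average each window of input samples into one output sample (simple box filter)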
+ const sampleRateRatio = inputSampleRate / outputSampleRate;
+ const newLength = Math.round(buffer.length / sampleRateRatio);
+ const result = new Float32Array(newLength);
+ let offsetResult = 0;
+ let offsetBuffer = 0;
+ while (offsetResult < result.length) {
+ const nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio);
+ let accum = 0;
+ let count = 0;
+ for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
+ accum += buffer[i];
+ count++;
}
- var sampleRateRatio = inputSampleRate / outputSampleRate;
- var newLength = Math.round(buffer.length / sampleRateRatio);
- var result = new Float32Array(newLength);
- var offsetResult = 0;
- var offsetBuffer = 0;
- while (offsetResult < result.length) {
- var nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio);
- var accum = 0, count = 0;
- for (var i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
- accum += buffer[i];
- count++;
- }
- result[offsetResult] = accum / count;
- offsetResult++;
- offsetBuffer = nextOffsetBuffer;
- }
- return result;
+ result[offsetResult] = accum / count;
+ offsetResult++;
+ offsetBuffer = nextOffsetBuffer;
+ }
+ return result;
}
function processAudio(e) {
- const inputSampleRate = context.sampleRate;
- const outputSampleRate = 16000; // Target sample rate
-
- const left = e.inputBuffer.getChannelData(0);
- const downsampledBuffer = downsampleBuffer(left, inputSampleRate, outputSampleRate);
- const audioData = convertFloat32ToInt16(downsampledBuffer);
-
- if (websocket && websocket.readyState === WebSocket.OPEN) {
- websocket.send(audioData);
- }
+ const inputSampleRate = context.sampleRate;
+ const outputSampleRate = 16000; // Target sample rate
+
+ const left = e.inputBuffer.getChannelData(0);
+ const downsampledBuffer = downsampleBuffer(
+ left,
+ inputSampleRate,
+ outputSampleRate,
+ );
+ const audioData = convertFloat32ToInt16(downsampledBuffer);
+
+ if (websocket && websocket.readyState === WebSocket.OPEN) {
+ websocket.send(audioData);
+ }
}
function convertFloat32ToInt16(buffer) {
- let l = buffer.length;
- const buf = new Int16Array(l);
- while (l--) {
- buf[l] = Math.min(1, buffer[l]) * 0x7FFF;
- }
- return buf.buffer;
+ let l = buffer.length;
+ const buf = new Int16Array(l);
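+  // Clamp positive values at 1.0 and scale to the signed 16-bit range expected by the server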
+ while (l--) {
+ buf[l] = Math.min(1, buffer[l]) * 0x7fff;
+ }
+ return buf.buffer;
}
-// Initialize WebSocket on page load
-// window.onload = initWebSocket;
-
function toggleBufferingStrategyPanel() {
- var selectedStrategy = document.getElementById('bufferingStrategySelect').value;
- if (selectedStrategy === 'silence_at_end_of_chunk') {
- var panel = document.getElementById('silence_at_end_of_chunk_options_panel');
- panel.classList.remove('hidden');
- } else {
- var panel = document.getElementById('silence_at_end_of_chunk_options_panel');
- panel.classList.add('hidden');
- }
+ const selectedStrategy = document.getElementById(
+ "bufferingStrategySelect",
+ ).value;
+ if (selectedStrategy === "silence_at_end_of_chunk") {
+ const panel = document.getElementById(
+ "silence_at_end_of_chunk_options_panel",
+ );
+ panel.classList.remove("hidden");
+ } else {
+ const panel = document.getElementById(
+ "silence_at_end_of_chunk_options_panel",
+ );
+ panel.classList.add("hidden");
+ }
}
+// Initialize WebSocket on page load
+// window.onload = initWebSocket;
diff --git a/dockerhub_push.sh b/dockerhub_push.sh
new file mode 100644
index 0000000..40cf311
--- /dev/null
+++ b/dockerhub_push.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+# Define variables
+# DOCKER_USERNAME="tyrese3915"
+read -p "Enter your Docker Hub username: " DOCKER_USERNAME
+IMAGE_NAME="voice-stream-ai"
+TAG="latest"
+
+# Step 1: Log in to Docker Hub
+echo "Logging in to Docker Hub..."
+docker login --username "$DOCKER_USERNAME"
+if [ $? -ne 0 ]; then
+ echo "Docker login failed. Exiting..."
+ exit 1
+fi
+
+# Step 2: Build your Docker image
+echo "Building the Docker image..."
+docker build -t "$IMAGE_NAME:$TAG" .
+if [ $? -ne 0 ]; then
+ echo "Docker build failed. Exiting..."
+ exit 1
+fi
+
+# Step 3: Tag the image for your Docker Hub repository
+echo "Tagging the image..."
+docker tag "$IMAGE_NAME:$TAG" "$DOCKER_USERNAME/$IMAGE_NAME:$TAG"
+if [ $? -ne 0 ]; then
+ echo "Docker tag failed. Exiting..."
+ exit 1
+fi
+
+# Step 4: Push the image to Docker Hub
+echo "Pushing the image to Docker Hub..."
+docker push "$DOCKER_USERNAME/$IMAGE_NAME:$TAG"
+if [ $? -ne 0 ]; then
+ echo "Docker push failed. Exiting..."
+ exit 1
+fi
+
+echo "Docker image has been successfully pushed to Docker Hub."
diff --git a/img/voicestreamai_test.mp4 b/img/voicestreamai_test.mp4
new file mode 100644
index 0000000..9d9c848
Binary files /dev/null and b/img/voicestreamai_test.mp4 differ
diff --git a/requirements.txt b/requirements.txt
index 966d009..b2eded3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
-websockets
-speechbrain
pyannote-audio
-asyncio
sentence-transformers
transformers
faster-whisper
+torchvision
+aiohttp
\ No newline at end of file
diff --git a/src/asr/asr_factory.py b/src/asr/asr_factory.py
index 433291a..bd8ba77 100644
--- a/src/asr/asr_factory.py
+++ b/src/asr/asr_factory.py
@@ -1,6 +1,7 @@
from .whisper_asr import WhisperASR
from .faster_whisper_asr import FasterWhisperASR
+
class ASRFactory:
@staticmethod
def create_asr_pipeline(type, **kwargs):
diff --git a/src/asr/faster_whisper_asr.py b/src/asr/faster_whisper_asr.py
index 8cc1734..a27a606 100644
--- a/src/asr/faster_whisper_asr.py
+++ b/src/asr/faster_whisper_asr.py
@@ -110,15 +110,25 @@
class FasterWhisperASR(ASRInterface):
def __init__(self, **kwargs):
- model_size = kwargs.get('model_size', "large-v3")
+ model_size = kwargs.get("model_size", "large-v3")
# Run on GPU with FP16
- self.asr_pipeline = WhisperModel(model_size, device="cuda", compute_type="float16")
+ self.asr_pipeline = WhisperModel(
+ model_size, device="cuda", compute_type="float16"
+ )
async def transcribe(self, client):
- file_path = await save_audio_to_file(client.scratch_buffer, client.get_file_name())
+ file_path = await save_audio_to_file(
+ client.scratch_buffer, client.get_file_name()
+ )
- language = None if client.config['language'] is None else language_codes.get(client.config['language'].lower())
- segments, info = self.asr_pipeline.transcribe(file_path, word_timestamps=True, language=language)
+ language = (
+ None
+ if client.config["language"] is None
+ else language_codes.get(client.config["language"].lower())
+ )
+ segments, info = self.asr_pipeline.transcribe(
+ file_path, word_timestamps=True, language=language
+ )
segments = list(segments) # The transcription will actually run here.
os.remove(file_path)
@@ -126,13 +136,17 @@ async def transcribe(self, client):
flattened_words = [word for segment in segments for word in segment.words]
to_return = {
- "language": info.language,
- "language_probability": info.language_probability,
- "text": ' '.join([s.text.strip() for s in segments]),
- "words":
- [
- {"word": w.word, "start": w.start, "end": w.end, "probability":w.probability} for w in flattened_words
- ]
+ "language": info.language,
+ "language_probability": info.language_probability,
+ "text": " ".join([s.text.strip() for s in segments]),
+ "words": [
+ {
+ "word": w.word,
+ "start": w.start,
+ "end": w.end,
+ "probability": w.probability,
+ }
+ for w in flattened_words
+ ],
}
return to_return
-
diff --git a/src/asr/whisper_asr.py b/src/asr/whisper_asr.py
index 38d14f6..1c31da3 100644
--- a/src/asr/whisper_asr.py
+++ b/src/asr/whisper_asr.py
@@ -3,18 +3,23 @@
from src.audio_utils import save_audio_to_file
import os
+
class WhisperASR(ASRInterface):
def __init__(self, **kwargs):
- model_name = kwargs.get('model_name', "openai/whisper-large-v3")
+ model_name = kwargs.get("model_name", "openai/whisper-large-v3")
self.asr_pipeline = pipeline("automatic-speech-recognition", model=model_name)
async def transcribe(self, client):
- file_path = await save_audio_to_file(client.scratch_buffer, client.get_file_name())
-
- if client.config['language'] is not None:
- to_return = self.asr_pipeline(file_path, generate_kwargs={"language": client.config['language']})['text']
+ file_path = await save_audio_to_file(
+ client.scratch_buffer, client.get_file_name()
+ )
+
+ if client.config["language"] is not None:
+ to_return = self.asr_pipeline(
+ file_path, generate_kwargs={"language": client.config["language"]}
+ )["text"]
else:
- to_return = self.asr_pipeline(file_path)['text']
+ to_return = self.asr_pipeline(file_path)["text"]
os.remove(file_path)
@@ -22,6 +27,6 @@ async def transcribe(self, client):
"language": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER",
"language_probability": None,
"text": to_return.strip(),
- "words": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER"
+ "words": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER",
}
return to_return
diff --git a/src/audio_utils.py b/src/audio_utils.py
index 9b8d2b0..370e920 100644
--- a/src/audio_utils.py
+++ b/src/audio_utils.py
@@ -1,7 +1,10 @@
import wave
import os
-async def save_audio_to_file(audio_data, file_name, audio_dir="audio_files", audio_format="wav"):
+
+async def save_audio_to_file(
+ audio_data, file_name, audio_dir="audio_files", audio_format="wav"
+):
"""
Saves the audio data to a file.
@@ -14,10 +17,10 @@ async def save_audio_to_file(audio_data, file_name, audio_dir="audio_files", aud
"""
os.makedirs(audio_dir, exist_ok=True)
-
+
file_path = os.path.join(audio_dir, file_name)
- with wave.open(file_path, 'wb') as wav_file:
+ with wave.open(file_path, "wb") as wav_file:
wav_file.setnchannels(1) # Assuming mono audio
wav_file.setsampwidth(2)
wav_file.setframerate(16000)
diff --git a/src/buffering_strategy/buffering_strategies.py b/src/buffering_strategy/buffering_strategies.py
index aca1da5..bb5cc5d 100644
--- a/src/buffering_strategy/buffering_strategies.py
+++ b/src/buffering_strategy/buffering_strategies.py
@@ -5,6 +5,7 @@
from .buffering_strategy_interface import BufferingStrategyInterface
+
class SilenceAtEndOfChunk(BufferingStrategyInterface):
"""
A buffering strategy that processes audio at the end of each chunk with silence detection.
@@ -28,20 +29,20 @@ def __init__(self, client, **kwargs):
"""
self.client = client
- self.chunk_length_seconds = os.environ.get('BUFFERING_CHUNK_LENGTH_SECONDS')
+ self.chunk_length_seconds = os.environ.get("BUFFERING_CHUNK_LENGTH_SECONDS")
if not self.chunk_length_seconds:
- self.chunk_length_seconds = kwargs.get('chunk_length_seconds')
+ self.chunk_length_seconds = kwargs.get("chunk_length_seconds")
self.chunk_length_seconds = float(self.chunk_length_seconds)
- self.chunk_offset_seconds = os.environ.get('BUFFERING_CHUNK_OFFSET_SECONDS')
+ self.chunk_offset_seconds = os.environ.get("BUFFERING_CHUNK_OFFSET_SECONDS")
if not self.chunk_offset_seconds:
- self.chunk_offset_seconds = kwargs.get('chunk_offset_seconds')
+ self.chunk_offset_seconds = kwargs.get("chunk_offset_seconds")
self.chunk_offset_seconds = float(self.chunk_offset_seconds)
- self.error_if_not_realtime = os.environ.get('ERROR_IF_NOT_REALTIME')
+ self.error_if_not_realtime = os.environ.get("ERROR_IF_NOT_REALTIME")
if not self.error_if_not_realtime:
- self.error_if_not_realtime = kwargs.get('error_if_not_realtime', False)
-
+ self.error_if_not_realtime = kwargs.get("error_if_not_realtime", False)
+
self.processing_flag = False
def process_audio(self, websocket, vad_pipeline, asr_pipeline):
@@ -56,17 +57,26 @@ def process_audio(self, websocket, vad_pipeline, asr_pipeline):
vad_pipeline: The voice activity detection pipeline.
asr_pipeline: The automatic speech recognition pipeline.
"""
- chunk_length_in_bytes = self.chunk_length_seconds * self.client.sampling_rate * self.client.samples_width
+ chunk_length_in_bytes = (
+ self.chunk_length_seconds
+ * self.client.sampling_rate
+ * self.client.samples_width
+ )
if len(self.client.buffer) > chunk_length_in_bytes:
if self.processing_flag:
- exit("Error in realtime processing: tried processing a new chunk while the previous one was still being processed")
+ print(
+ "Error in realtime processing: tried processing a new chunk while the previous one was still being processed"
+ )
+ return
self.client.scratch_buffer += self.client.buffer
self.client.buffer.clear()
self.processing_flag = True
# Schedule the processing in a separate task
- asyncio.create_task(self.process_audio_async(websocket, vad_pipeline, asr_pipeline))
-
+ asyncio.create_task(
+ self.process_audio_async(websocket, vad_pipeline, asr_pipeline)
+ )
+
async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline):
"""
Asynchronously process audio for activity detection and transcription.
@@ -78,7 +88,7 @@ async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline):
websocket (Websocket): The WebSocket connection for sending transcriptions.
vad_pipeline: The voice activity detection pipeline.
asr_pipeline: The automatic speech recognition pipeline.
- """
+ """
start = time.time()
vad_results = await vad_pipeline.detect_activity(self.client)
@@ -88,15 +98,18 @@ async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline):
self.processing_flag = False
return
- last_segment_should_end_before = ((len(self.client.scratch_buffer) / (self.client.sampling_rate * self.client.samples_width)) - self.chunk_offset_seconds)
- if vad_results[-1]['end'] < last_segment_should_end_before:
+ last_segment_should_end_before = (
+ len(self.client.scratch_buffer)
+ / (self.client.sampling_rate * self.client.samples_width)
+ ) - self.chunk_offset_seconds
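+        # Transcribe only if the last detected speech ends at least chunk_offset_seconds
+        # before the end of the buffer, i.e. the chunk tail is silence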
+ if vad_results[-1]["end"] < last_segment_should_end_before:
transcription = await asr_pipeline.transcribe(self.client)
- if transcription['text'] != '':
+ if transcription["text"] != "":
end = time.time()
- transcription['processing_time'] = end - start
- json_transcription = json.dumps(transcription)
- await websocket.send(json_transcription)
+ transcription["processing_time"] = end - start
+ json_transcription = json.dumps(transcription)
+ await websocket.send_str(json_transcription)
self.client.scratch_buffer.clear()
self.client.increment_file_counter()
-
- self.processing_flag = False
\ No newline at end of file
+
+ self.processing_flag = False
diff --git a/src/buffering_strategy/buffering_strategy_factory.py b/src/buffering_strategy/buffering_strategy_factory.py
index 8d3f452..6f8131c 100644
--- a/src/buffering_strategy/buffering_strategy_factory.py
+++ b/src/buffering_strategy/buffering_strategy_factory.py
@@ -1,5 +1,6 @@
from .buffering_strategies import SilenceAtEndOfChunk
+
class BufferingStrategyFactory:
"""
A factory class for creating instances of different buffering strategies.
diff --git a/src/client.py b/src/client.py
index a735d9e..95ab70c 100644
--- a/src/client.py
+++ b/src/client.py
@@ -1,5 +1,6 @@
from src.buffering_strategy.buffering_strategy_factory import BufferingStrategyFactory
+
class Client:
"""
Represents a client connected to the VoiceStreamAI server.
@@ -16,26 +17,29 @@ class Client:
sampling_rate (int): The sampling rate of the audio data in Hz.
samples_width (int): The width of each audio sample in bits.
"""
+
def __init__(self, client_id, sampling_rate, samples_width):
self.client_id = client_id
self.buffer = bytearray()
self.scratch_buffer = bytearray()
- self.config = {"language": None,
- "processing_strategy": "silence_at_end_of_chunk",
- "processing_args": {
- "chunk_length_seconds": 5,
- "chunk_offset_seconds": 0.1
- }
- }
+ self.config = {
+ "language": None,
+ "processing_strategy": "silence_at_end_of_chunk",
+ "processing_args": {"chunk_length_seconds": 5, "chunk_offset_seconds": 0.1},
+ }
self.file_counter = 0
self.total_samples = 0
self.sampling_rate = sampling_rate
self.samples_width = samples_width
- self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy(self.config['processing_strategy'], self, **self.config['processing_args'])
+ self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy(
+ self.config["processing_strategy"], self, **self.config["processing_args"]
+ )
def update_config(self, config_data):
self.config.update(config_data)
- self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy(self.config['processing_strategy'], self, **self.config['processing_args'])
+ self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy(
+ self.config["processing_strategy"], self, **self.config["processing_args"]
+ )
def append_audio_data(self, audio_data):
self.buffer.extend(audio_data)
@@ -49,6 +53,6 @@ def increment_file_counter(self):
def get_file_name(self):
return f"{self.client_id}_{self.file_counter}.wav"
-
+
def process_audio(self, websocket, vad_pipeline, asr_pipeline):
self.buffering_strategy.process_audio(websocket, vad_pipeline, asr_pipeline)
diff --git a/src/main.py b/src/main.py
index fd9f8fa..89b0841 100644
--- a/src/main.py
+++ b/src/main.py
@@ -2,20 +2,54 @@
import asyncio
import json
-from .server import Server
+from src.server import Server
from src.asr.asr_factory import ASRFactory
from src.vad.vad_factory import VADFactory
+
def parse_args():
- parser = argparse.ArgumentParser(description="VoiceStreamAI Server: Real-time audio transcription using self-hosted Whisper and WebSocket")
- parser.add_argument("--vad-type", type=str, default="pyannote", help="Type of VAD pipeline to use (e.g., 'pyannote')")
- parser.add_argument("--vad-args", type=str, default='{"auth_token": "huggingface_token"}', help="JSON string of additional arguments for VAD pipeline")
- parser.add_argument("--asr-type", type=str, default="faster_whisper", help="Type of ASR pipeline to use (e.g., 'whisper')")
- parser.add_argument("--asr-args", type=str, default='{"model_size": "large-v3"}', help="JSON string of additional arguments for ASR pipeline")
- parser.add_argument("--host", type=str, default="127.0.0.1", help="Host for the WebSocket server")
- parser.add_argument("--port", type=int, default=8765, help="Port for the WebSocket server")
+ parser = argparse.ArgumentParser(
+ description="VoiceStreamAI Server: Real-time audio transcription using self-hosted Whisper and WebSocket"
+ )
+ parser.add_argument(
+ "--vad-type",
+ type=str,
+ default="pyannote",
+ help="Type of VAD pipeline to use (e.g., 'pyannote')",
+ )
+ parser.add_argument(
+ "--vad-args",
+ type=str,
+ default='{"auth_token": "huggingface_token"}',
+ help="JSON string of additional arguments for VAD pipeline",
+ )
+ parser.add_argument(
+ "--asr-type",
+ type=str,
+ default="faster_whisper",
+ help="Type of ASR pipeline to use (e.g., 'whisper')",
+ )
+ parser.add_argument(
+ "--asr-args",
+ type=str,
+ default='{"model_size": "large-v3"}',
+ help="JSON string of additional arguments for ASR pipeline",
+ )
+ parser.add_argument(
+ "--host", type=str, default="127.0.0.1", help="Host for the WebSocket server"
+ )
+ parser.add_argument(
+ "--port", type=int, default=80, help="Port for the WebSocket server"
+ )
+ parser.add_argument(
+ "--static-path",
+ type=str,
+ default="./static",
+ help="Port for the WebSocket server",
+ )
return parser.parse_args()
+
def main():
args = parse_args()
@@ -29,10 +63,19 @@ def main():
vad_pipeline = VADFactory.create_vad_pipeline(args.vad_type, **vad_args)
asr_pipeline = ASRFactory.create_asr_pipeline(args.asr_type, **asr_args)
- server = Server(vad_pipeline, asr_pipeline, host=args.host, port=args.port, sampling_rate=16000, samples_width=2)
+ server = Server(
+ vad_pipeline,
+ asr_pipeline,
+ host=args.host,
+ port=args.port,
+ sampling_rate=16000,
+ samples_width=2,
+ static_files_path=args.static_path,
+ )
-    asyncio.get_event_loop().run_until_complete(server.start())
-    asyncio.get_event_loop().run_forever()
+    # web.run_app() inside Server.start() blocks and manages its own event loop
+    server.start()
+
if __name__ == "__main__":
main()
diff --git a/src/server.py b/src/server.py
index a9028d4..ab2031d 100644
--- a/src/server.py
+++ b/src/server.py
@@ -1,11 +1,10 @@
-import websockets
+from aiohttp import web
import uuid
import json
-import asyncio
-from src.audio_utils import save_audio_to_file
from src.client import Client
+
class Server:
"""
Represents the WebSocket server for handling real-time audio transcription.
@@ -23,7 +22,17 @@ class Server:
samples_width (int): The width of each audio sample in bits.
connected_clients (dict): A dictionary mapping client IDs to Client objects.
"""
- def __init__(self, vad_pipeline, asr_pipeline, host='localhost', port=8765, sampling_rate=16000, samples_width=2):
+
+ def __init__(
+ self,
+ vad_pipeline,
+ asr_pipeline,
+ host="localhost",
+ port=80,
+ sampling_rate=16000,
+ samples_width=2,
+ static_files_path="./static",
+ ):
self.vad_pipeline = vad_pipeline
self.asr_pipeline = asr_pipeline
self.host = host
@@ -31,39 +40,57 @@ def __init__(self, vad_pipeline, asr_pipeline, host='localhost', port=8765, samp
self.sampling_rate = sampling_rate
self.samples_width = samples_width
self.connected_clients = {}
+ self.static_files_path = static_files_path
+ self.app = web.Application()
+ self.setup_routes()
+
- async def handle_audio(self, client, websocket):
- while True:
- message = await websocket.recv()
- if isinstance(message, bytes):
- client.append_audio_data(message)
- elif isinstance(message, str):
- config = json.loads(message)
- if config.get('type') == 'config':
- client.update_config(config['data'])
- continue
- else:
- print(f"Unexpected message type from {client.client_id}")
+ def setup_routes(self):
+ self.app.router.add_get("/ws", self.websocket_handler)
+ self.app.router.add_get("/", self.serve_index)
+ self.app.router.add_static("/", path=self.static_files_path, name="static")
- # this is synchronous, any async operation is in BufferingStrategy
- client.process_audio(websocket, self.vad_pipeline, self.asr_pipeline)
+ async def serve_index(self, request):
+ return web.FileResponse(path=f"{self.static_files_path}/index.html")
+ async def websocket_handler(self, request):
+ ws = web.WebSocketResponse()
+ await ws.prepare(request)
- async def handle_websocket(self, websocket, path):
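+        # Each WebSocket connection gets its own Client, keyed by a fresh UUID,
+        # with per-connection audio buffers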
client_id = str(uuid.uuid4())
client = Client(client_id, self.sampling_rate, self.samples_width)
self.connected_clients[client_id] = client
- print(f"Client {client_id} connected")
+ print(f"Client {client_id} connected.")
+ async for msg in ws:
+ if msg.type == web.WSMsgType.TEXT:
+ message_text = msg.data
+ if message_text == "close":
+ await ws.close()
+ else:
+ # Handle textual WebSocket messages
+ config = json.loads(message_text)
+ if config.get("type") == "config":
+ client.update_config(config["data"])
+ elif msg.type == web.WSMsgType.BINARY:
+ # Handle binary WebSocket messages
+ client.append_audio_data(msg.data)
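+                # process_audio() returns immediately; transcription runs in an
+                # asyncio task created by the buffering strategy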
+ client.process_audio(ws, self.vad_pipeline, self.asr_pipeline)
+ elif msg.type == web.WSMsgType.ERROR:
+ print(f"WebSocket connection closed with exception {ws.exception()}")
- try:
- await self.handle_audio(client, websocket)
- except websockets.ConnectionClosed as e:
- print(f"Connection with {client_id} closed: {e}")
- finally:
- del self.connected_clients[client_id]
+ print(f"Client {client_id} disconnected.")
+ del self.connected_clients[client_id]
+ return ws
def start(self):
- print("Websocket server ready to accept connections")
- return websockets.serve(self.handle_websocket, self.host, self.port)
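+        # web.run_app() blocks until shutdown and manages its own event loop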
+ web.run_app(self.app, host=self.host, port=self.port)
diff --git a/src/static/index.html b/src/static/index.html
new file mode 100644
index 0000000..c23be91
--- /dev/null
+++ b/src/static/index.html
@@ -0,0 +1,147 @@
+
+
+
+
+
+ Audio Stream to WebSocket Server
+
+
+
+
+
Transcribe a Web Audio Stream with PyAnnote + Whisper
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
WebSocket: Not Connected
+
Detected Language: Undefined
+
Last Processing Time: Undefined
+
+
diff --git a/src/static/utils.js b/src/static/utils.js
new file mode 100644
index 0000000..bf362b3
--- /dev/null
+++ b/src/static/utils.js
@@ -0,0 +1,256 @@
+/**
+ * VoiceStreamAI Client - WebSocket-based real-time transcription
+ *
+ */
+
+let websocket;
+let context;
+let processor;
+let globalStream;
+
+const bufferSize = 4096;
+let isRecording = false;
+
+function initWebSocket() {
+ const websocketAddress = document.getElementById("websocketAddress").value;
+ chunk_length_seconds = document.getElementById("chunk_length_seconds").value;
+ chunk_offset_seconds = document.getElementById("chunk_offset_seconds").value;
+ const selectedLanguage = document.getElementById("languageSelect").value;
+ language = selectedLanguage !== "multilingual" ? selectedLanguage : null;
+
+ if (!websocketAddress) {
+ console.log("WebSocket address is required.");
+ return;
+ }
+
+ websocket = new WebSocket(websocketAddress);
+ websocket.onopen = () => {
+ console.log("WebSocket connection established");
+ document.getElementById("webSocketStatus").textContent = "Connected";
+ document.getElementById("startButton").disabled = false;
+ };
+ websocket.onclose = (event) => {
+ console.log("WebSocket connection closed", event);
+ document.getElementById("webSocketStatus").textContent = "Not Connected";
+ document.getElementById("startButton").disabled = true;
+ document.getElementById("stopButton").disabled = true;
+ };
+ websocket.onmessage = (event) => {
+ console.log("Message from server:", event.data);
+ const transcript_data = JSON.parse(event.data);
+ updateTranscription(transcript_data);
+ };
+}
+
+function updateTranscription(transcript_data) {
+ const transcriptionDiv = document.getElementById("transcription");
+ const languageDiv = document.getElementById("detected_language");
+
+ if (transcript_data.words && transcript_data.words.length > 0) {
+ // Append words with color based on their probability
+ // biome-ignore lint/complexity/noForEach:
+ transcript_data.words.forEach((wordData) => {
+ const span = document.createElement("span");
+ const probability = wordData.probability;
+ span.textContent = `${wordData.word} `;
+
+ // Set the color based on the probability
+ if (probability > 0.9) {
+ span.style.color = "green";
+ } else if (probability > 0.6) {
+ span.style.color = "orange";
+ } else {
+ span.style.color = "red";
+ }
+
+ transcriptionDiv.appendChild(span);
+ });
+
+ // Add a new line at the end
+ transcriptionDiv.appendChild(document.createElement("br"));
+ } else {
+ // Fallback to plain text
+ transcriptionDiv.textContent += `${transcript_data.text}\n`;
+ }
+
+ // Update the language information
+ if (transcript_data.language && transcript_data.language_probability) {
+ languageDiv.textContent = `${
+ transcript_data.language
+ } (${transcript_data.language_probability.toFixed(2)})`;
+ }
+
+ // Update the processing time, if available
+ const processingTimeDiv = document.getElementById("processing_time");
+ if (transcript_data.processing_time) {
+ processingTimeDiv.textContent = `Processing time: ${transcript_data.processing_time.toFixed(
+ 2,
+ )} seconds`;
+ }
+}
+
+function startRecording() {
+ if (isRecording) return;
+ isRecording = true;
+
+ const AudioContext = window.AudioContext || window.webkitAudioContext;
+ context = new AudioContext();
+
+ navigator.mediaDevices
+ .getUserMedia({ audio: true })
+ .then((stream) => {
+ globalStream = stream;
+ const input = context.createMediaStreamSource(stream);
+ processor = context.createScriptProcessor(bufferSize, 1, 1);
+ processor.onaudioprocess = (e) => processAudio(e);
+ input.connect(processor);
+ processor.connect(context.destination);
+
+ sendAudioConfig();
+ })
+ .catch((error) => console.error("Error accessing microphone", error));
+
+ // Disable start button and enable stop button
+ document.getElementById("startButton").disabled = true;
+ document.getElementById("stopButton").disabled = false;
+}
+
+function stopRecording() {
+ if (!isRecording) return;
+ isRecording = false;
+
+ if (globalStream) {
+ // biome-ignore lint/complexity/noForEach:
+ globalStream.getTracks().forEach((track) => track.stop());
+ }
+ if (processor) {
+ processor.disconnect();
+ processor = null;
+ }
+ if (context) {
+ // biome-ignore lint/suspicious/noAssignInExpressions:
+ context.close().then(() => (context = null));
+ }
+ document.getElementById("startButton").disabled = false;
+ document.getElementById("stopButton").disabled = true;
+}
+
+function sendAudioConfig() {
+ const selectedStrategy = document.getElementById(
+ "bufferingStrategySelect",
+ ).value;
+ let processingArgs = {};
+
+ if (selectedStrategy === "silence_at_end_of_chunk") {
+ processingArgs = {
+ chunk_length_seconds: Number.parseFloat(
+ document.getElementById("chunk_length_seconds").value,
+ ),
+ chunk_offset_seconds: Number.parseFloat(
+ document.getElementById("chunk_offset_seconds").value,
+ ),
+ };
+ }
+
+ const audioConfig = {
+ type: "config",
+ data: {
+ sampleRate: context.sampleRate,
+ bufferSize: bufferSize,
+ channels: 1, // Assuming mono channel
+ language: language,
+ processing_strategy: selectedStrategy,
+ processing_args: processingArgs,
+ },
+ };
+
+ websocket.send(JSON.stringify(audioConfig));
+}
+
+function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) {
+ if (inputSampleRate === outputSampleRate) {
+ return buffer;
+ }
+ const sampleRateRatio = inputSampleRate / outputSampleRate;
+ const newLength = Math.round(buffer.length / sampleRateRatio);
+ const result = new Float32Array(newLength);
+ let offsetResult = 0;
+ let offsetBuffer = 0;
+ while (offsetResult < result.length) {
+ const nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio);
+ let accum = 0;
+ let count = 0;
+ for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
+ accum += buffer[i];
+ count++;
+ }
+ result[offsetResult] = accum / count;
+ offsetResult++;
+ offsetBuffer = nextOffsetBuffer;
+ }
+ return result;
+}
+
+function processAudio(e) {
+ const inputSampleRate = context.sampleRate;
+ const outputSampleRate = 16000; // Target sample rate
+
+ const left = e.inputBuffer.getChannelData(0);
+ const downsampledBuffer = downsampleBuffer(
+ left,
+ inputSampleRate,
+ outputSampleRate,
+ );
+ const audioData = convertFloat32ToInt16(downsampledBuffer);
+
+ if (websocket && websocket.readyState === WebSocket.OPEN) {
+ websocket.send(audioData);
+ }
+}
+
+function convertFloat32ToInt16(buffer) {
+ let l = buffer.length;
+ const buf = new Int16Array(l);
+ while (l--) {
+ buf[l] = Math.min(1, buffer[l]) * 0x7fff;
+ }
+ return buf.buffer;
+}
+
+function toggleBufferingStrategyPanel() {
+ const selectedStrategy = document.getElementById(
+ "bufferingStrategySelect",
+ ).value;
+ if (selectedStrategy === "silence_at_end_of_chunk") {
+ const panel = document.getElementById(
+ "silence_at_end_of_chunk_options_panel",
+ );
+ panel.classList.remove("hidden");
+ } else {
+ const panel = document.getElementById(
+ "silence_at_end_of_chunk_options_panel",
+ );
+ panel.classList.add("hidden");
+ }
+}
+
+function getWebSocketUrl() {
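+  // Build the "/ws" endpoint from the current page URL (wss:// when served over https)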
+ if (
+ window.location.protocol !== "https:" &&
+ window.location.protocol !== "http:"
+ )
+ return null;
+ const wsProtocol = window.location.protocol === "https:" ? "wss" : "ws";
+ const wsUrl = `${wsProtocol}://${window.location.host}/ws`;
+ return wsUrl;
+}
+
+// Initialize WebSocket on page load
+// window.onload = initWebSocket;
+
+window.onload = () => {
+ const url = getWebSocketUrl();
+ document.getElementById("websocketAddress").value =
+ url ?? "ws://localhost/ws";
+ initWebSocket();
+};
diff --git a/src/vad/pyannote_vad.py b/src/vad/pyannote_vad.py
index 758e919..7469a6b 100644
--- a/src/vad/pyannote_vad.py
+++ b/src/vad/pyannote_vad.py
@@ -1,7 +1,6 @@
from os import remove
import os
-from pyannote.core import Segment
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection
@@ -22,23 +21,35 @@ def __init__(self, **kwargs):
model_name (str): The model name for Pyannote.
auth_token (str, optional): Authentication token for Hugging Face.
"""
-
- model_name = kwargs.get('model_name', "pyannote/segmentation")
- auth_token = os.environ.get('PYANNOTE_AUTH_TOKEN')
+ model_name = kwargs.get("model_name", "pyannote/segmentation")
+
+ auth_token = os.environ.get("PYANNOTE_AUTH_TOKEN")
if not auth_token:
- auth_token = kwargs.get('auth_token')
-
+ auth_token = kwargs.get("auth_token")
+
if auth_token is None:
- raise ValueError("Missing required env var in PYANNOTE_AUTH_TOKEN or argument in --vad-args: 'auth_token'")
-
- pyannote_args = kwargs.get('pyannote_args', {"onset": 0.5, "offset": 0.5, "min_duration_on": 0.3, "min_duration_off": 0.3})
+ raise ValueError(
+ "Missing required env var in PYANNOTE_AUTH_TOKEN or argument in --vad-args: 'auth_token'"
+ )
+
+ pyannote_args = kwargs.get(
+ "pyannote_args",
+ {
+ "onset": 0.5,
+ "offset": 0.5,
+ "min_duration_on": 0.1,
+ "min_duration_off": 0.1,
+ },
+ )
self.model = Model.from_pretrained(model_name, use_auth_token=auth_token)
self.vad_pipeline = VoiceActivityDetection(segmentation=self.model)
self.vad_pipeline.instantiate(pyannote_args)
async def detect_activity(self, client):
- audio_file_path = await save_audio_to_file(client.scratch_buffer, client.get_file_name())
+ audio_file_path = await save_audio_to_file(
+ client.scratch_buffer, client.get_file_name()
+ )
vad_results = self.vad_pipeline(audio_file_path)
remove(audio_file_path)
vad_segments = []
diff --git a/src/vad/vad_factory.py b/src/vad/vad_factory.py
index 600864c..e20d3f9 100644
--- a/src/vad/vad_factory.py
+++ b/src/vad/vad_factory.py
@@ -1,5 +1,6 @@
from .pyannote_vad import PyannoteVAD
+
class VADFactory:
"""
Factory for creating instances of VAD systems.
diff --git a/test/asr/test_asr.py b/test/asr/test_asr.py
index d9fafcf..b385c75 100644
--- a/test/asr/test_asr.py
+++ b/test/asr/test_asr.py
@@ -4,51 +4,65 @@
import asyncio
from sentence_transformers import SentenceTransformer, util
from pydub import AudioSegment
-import argparse
from src.asr.asr_factory import ASRFactory
from src.client import Client
+
class TestWhisperASR(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Use an environment variable to get the ASR model type
- cls.asr_type = os.getenv('ASR_TYPE', 'whisper')
+ cls.asr_type = os.getenv("ASR_TYPE", "whisper")
def setUp(self):
self.asr = ASRFactory.create_asr_pipeline(self.asr_type)
- self.annotations_path = os.path.join(os.path.dirname(__file__), "../audio_files/annotations.json")
+ self.annotations_path = os.path.join(
+ os.path.dirname(__file__), "../audio_files/annotations.json"
+ )
self.client = Client("test_client", 16000, 2) # Example client
- self.similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+ self.similarity_model = SentenceTransformer(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ )
def load_annotations(self):
- with open(self.annotations_path, 'r') as file:
+ with open(self.annotations_path, "r") as file:
return json.load(file)
def get_audio_segment(self, file_path, start, end):
- with open(file_path, 'rb') as file:
+ with open(file_path, "rb") as file:
audio = AudioSegment.from_file(file, format="wav")
- return audio[start * 1000:end * 1000] # pydub works in milliseconds
+ return audio[start * 1000 : end * 1000] # pydub works in milliseconds
def test_transcribe_segments(self):
annotations = self.load_annotations()
for audio_file, data in annotations.items():
- audio_file_path = os.path.join(os.path.dirname(__file__), f"../audio_files/{audio_file}")
+ audio_file_path = os.path.join(
+ os.path.dirname(__file__), f"../audio_files/{audio_file}"
+ )
similarities = []
for segment in data["segments"]:
- audio_segment = self.get_audio_segment(audio_file_path, segment["start"], segment["end"])
+ audio_segment = self.get_audio_segment(
+ audio_file_path, segment["start"], segment["end"]
+ )
self.client.scratch_buffer = bytearray(audio_segment.raw_data)
- self.client.config['language'] = None
+ self.client.config["language"] = None
transcription = asyncio.run(self.asr.transcribe(self.client))["text"]
- embedding_1 = self.similarity_model.encode(transcription.lower().strip(), convert_to_tensor=True)
- embedding_2 = self.similarity_model.encode(segment["transcription"].lower().strip(), convert_to_tensor=True)
+ embedding_1 = self.similarity_model.encode(
+ transcription.lower().strip(), convert_to_tensor=True
+ )
+ embedding_2 = self.similarity_model.encode(
+ segment["transcription"].lower().strip(), convert_to_tensor=True
+ )
similarity = util.pytorch_cos_sim(embedding_1, embedding_2).item()
similarities.append(similarity)
- print(f"\nSegment from '{audio_file}' ({segment['start']}-{segment['end']}s):")
+ print(
+ f"\nSegment from '{audio_file}' ({segment['start']}-{segment['end']}s):"
+ )
print(f"Expected: {segment['transcription']}")
print(f"Actual: {transcription}")
print(f"Similarity: {similarity}")
@@ -60,7 +74,10 @@ def test_transcribe_segments(self):
print(f"\nAverage similarity for '{audio_file}': {avg_similarity}")
# Assert that the average similarity is above the threshold
- self.assertGreaterEqual(avg_similarity, 0.7) # Adjust the threshold as needed
+ self.assertGreaterEqual(
+ avg_similarity, 0.7
+ ) # Adjust the threshold as needed
+
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/test/requirements.txt b/test/requirements.txt
new file mode 100644
index 0000000..d4ad60d
--- /dev/null
+++ b/test/requirements.txt
@@ -0,0 +1,3 @@
+pydub
+sentence_transformers
+websockets
\ No newline at end of file
diff --git a/test/server/test_server.py b/test/server/test_server.py
index 23b1734..0155fff 100644
--- a/test/server/test_server.py
+++ b/test/server/test_server.py
@@ -12,6 +12,7 @@
from src.vad.vad_factory import VADFactory
from src.asr.asr_factory import ASRFactory
+
class TestServer(unittest.TestCase):
"""
Test suite for testing the Server class responsible for real-time audio transcription.
@@ -27,11 +28,12 @@ class TestServer(unittest.TestCase):
test_server_response: Tests the server's response accuracy by comparing received and expected transcriptions.
load_annotations: Loads transcription annotations for comparison with server responses.
"""
+
@classmethod
def setUpClass(cls):
# Use an environment variable to get the ASR model type
- cls.asr_type = os.getenv('ASR_TYPE', 'faster_whisper')
- cls.vad_type = os.getenv('VAD_TYPE', 'pyannote')
+ cls.asr_type = os.getenv("ASR_TYPE", "faster_whisper")
+ cls.vad_type = os.getenv("VAD_TYPE", "pyannote")
def setUp(self):
"""
@@ -42,10 +44,16 @@ def setUp(self):
"""
self.vad_pipeline = VADFactory.create_vad_pipeline(self.vad_type)
self.asr_pipeline = ASRFactory.create_asr_pipeline(self.asr_type)
- self.server = Server(self.vad_pipeline, self.asr_pipeline, host='127.0.0.1', port=8767)
- self.annotations_path = os.path.join(os.path.dirname(__file__), "../audio_files/annotations.json")
+ self.server = Server(
+ self.vad_pipeline, self.asr_pipeline, host="127.0.0.1", port=8767
+ )
+ self.annotations_path = os.path.join(
+ os.path.dirname(__file__), "../audio_files/annotations.json"
+ )
self.received_transcriptions = []
- self.similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+ self.similarity_model = SentenceTransformer(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ )
async def receive_transcriptions(self, websocket):
"""
@@ -57,9 +65,11 @@ async def receive_transcriptions(self, websocket):
try:
while True:
transcription_str = await websocket.recv()
- transcription = json.loads(transcription_str)
- self.received_transcriptions.append(transcription['text'])
- print(f"Received transcription: {transcription['text']}, processing time: {transcription['processing_time']}")
+ transcription = json.loads(transcription_str)
+ self.received_transcriptions.append(transcription["text"])
+ print(
+ f"Received transcription: {transcription['text']}, processing time: {transcription['processing_time']}"
+ )
except websockets.exceptions.ConnectionClosed:
pass # Expected when server closes the connection
@@ -72,20 +82,20 @@ async def mock_client(self, audio_file):
Args:
audio_file (str): Path to the audio file to be sent to the server.
"""
- uri = "ws://127.0.0.1:8767"
+ uri = "ws://127.0.0.1/ws"
async with websockets.connect(uri) as websocket:
# Start receiving transcriptions in a separate task
receive_task = asyncio.create_task(self.receive_transcriptions(websocket))
# Stream the entire audio file in chunks
- with open(audio_file, 'rb') as file:
+ with open(audio_file, "rb") as file:
audio = AudioSegment.from_file(file, format="wav")
-
+
for i in range(0, len(audio), 250): # 4000 samples = 250 ms at 16000 Hz
- chunk = audio[i:i+250]
+ chunk = audio[i : i + 250]
await websocket.send(chunk.raw_data)
await asyncio.sleep(0.25) # Wait for the chunk duration
-
+
# Stream 10 seconds of silence
silence = AudioSegment.silent(duration=10000)
await websocket.send(silence.raw_data)
@@ -107,17 +117,27 @@ def test_server_response(self):
annotations = self.load_annotations()
for audio_file_name, data in annotations.items():
- audio_file_path = os.path.join(os.path.dirname(__file__), f"../audio_files/{audio_file_name}")
+ audio_file_path = os.path.join(
+ os.path.dirname(__file__), f"../audio_files/{audio_file_name}"
+ )
# Run the mock client for each audio file
- asyncio.get_event_loop().run_until_complete(self.mock_client(audio_file_path))
+ asyncio.get_event_loop().run_until_complete(
+ self.mock_client(audio_file_path)
+ )
# Compare received transcriptions with expected transcriptions
- expected_transcriptions = ' '.join([seg["transcription"] for seg in data['segments']])
- received_transcriptions = ' '.join(self.received_transcriptions)
-
- embedding_1 = self.similarity_model.encode(expected_transcriptions.lower().strip(), convert_to_tensor=True)
- embedding_2 = self.similarity_model.encode(received_transcriptions.lower().strip(), convert_to_tensor=True)
+ expected_transcriptions = " ".join(
+ [seg["transcription"] for seg in data["segments"]]
+ )
+ received_transcriptions = " ".join(self.received_transcriptions)
+
+ embedding_1 = self.similarity_model.encode(
+ expected_transcriptions.lower().strip(), convert_to_tensor=True
+ )
+ embedding_2 = self.similarity_model.encode(
+ received_transcriptions.lower().strip(), convert_to_tensor=True
+ )
similarity = util.pytorch_cos_sim(embedding_1, embedding_2).item()
# Print summary before assertion
@@ -137,8 +157,9 @@ def load_annotations(self):
Returns:
dict: A dictionary containing expected transcriptions for test audio files.
"""
- with open(self.annotations_path, 'r') as file:
+ with open(self.annotations_path, "r") as file:
return json.load(file)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
unittest.main()
diff --git a/test/vad/test_pyannote_vad.py b/test/vad/test_pyannote_vad.py
index 807750a..22a98f8 100644
--- a/test/vad/test_pyannote_vad.py
+++ b/test/vad/test_pyannote_vad.py
@@ -8,50 +8,72 @@
from src.vad.pyannote_vad import PyannoteVAD
from src.client import Client
+
class TestPyannoteVAD(unittest.TestCase):
def setUp(self):
self.vad = PyannoteVAD()
- self.annotations_path = os.path.join(os.path.dirname(__file__), "../audio_files/annotations.json")
+ self.annotations_path = os.path.join(
+ os.path.dirname(__file__), "../audio_files/annotations.json"
+ )
self.client = Client("test_client", 16000, 2) # Example client
def load_annotations(self):
- with open(self.annotations_path, 'r') as file:
+ with open(self.annotations_path, "r") as file:
return json.load(file)
def test_detect_activity(self):
annotations = self.load_annotations()
for audio_file, data in annotations.items():
- audio_file_path = os.path.join(os.path.dirname(__file__), f"../audio_files/{audio_file}")
+ audio_file_path = os.path.join(
+ os.path.dirname(__file__), f"../audio_files/{audio_file}"
+ )
for annotated_segment in data["segments"]:
# Load the specific audio segment for VAD
- audio_segment = self.get_audio_segment(audio_file_path, annotated_segment["start"], annotated_segment["end"])
+ audio_segment = self.get_audio_segment(
+ audio_file_path,
+ annotated_segment["start"],
+ annotated_segment["end"],
+ )
self.client.scratch_buffer = bytearray(audio_segment.raw_data)
vad_results = asyncio.run(self.vad.detect_activity(self.client))
# Adjust VAD-detected times by adding the start time of the annotated segment
- adjusted_vad_results = [{"start": segment["start"] + annotated_segment["start"],
- "end": segment["end"] + annotated_segment["start"]}
- for segment in vad_results]
+ adjusted_vad_results = [
+ {
+ "start": segment["start"] + annotated_segment["start"],
+ "end": segment["end"] + annotated_segment["start"],
+ }
+ for segment in vad_results
+ ]
- detected_segments = [segment for segment in adjusted_vad_results if
- segment["start"] <= annotated_segment["start"] + 1.0 and
- segment["end"] <= annotated_segment["end"] + 2.0]
+ detected_segments = [
+ segment
+ for segment in adjusted_vad_results
+ if segment["start"] <= annotated_segment["start"] + 1.0
+ and segment["end"] <= annotated_segment["end"] + 2.0
+ ]
# Print formatted information about the test
- print(f"\nTesting segment from '{audio_file}': Annotated Start: {annotated_segment['start']}, Annotated End: {annotated_segment['end']}")
+ print(
+ f"\nTesting segment from '{audio_file}': Annotated Start: {annotated_segment['start']}, Annotated End: {annotated_segment['end']}"
+ )
print(f"VAD segments: {adjusted_vad_results}")
print(f"Overlapping, Detected segments: {detected_segments}")
# Assert that at least one detected segment meets the condition
- self.assertTrue(len(detected_segments) > 0, "No detected segment matches the annotated segment")
+ self.assertTrue(
+ len(detected_segments) > 0,
+ "No detected segment matches the annotated segment",
+ )
def get_audio_segment(self, file_path, start, end):
- with open(file_path, 'rb') as file:
+ with open(file_path, "rb") as file:
audio = AudioSegment.from_file(file, format="wav")
- return audio[start * 1000:end * 1000] # pydub works in milliseconds
+ return audio[start * 1000 : end * 1000] # pydub works in milliseconds
+
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()