diff --git a/Dockerfile b/Dockerfile
index 1a0a11d..a5a4c57 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,27 +1,29 @@
 # Use an NVIDIA CUDA base image with Python 3
-FROM nvidia/cuda:11.6.2-base-ubuntu20.04
+FROM nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
 
 # Set the working directory in the container
 WORKDIR /usr/src/app
 
-# Copy the requirements.txt file first to leverage Docker cache
-COPY requirements.txt ./
-
 # Avoid interactive prompts from apt-get
 ENV DEBIAN_FRONTEND=noninteractive
 
-# Install any needed packages specified in requirements.txt
-RUN apt-get update && apt-get install -y python3-pip libsndfile1 ffmpeg && \
-    pip3 install --no-cache-dir -r requirements.txt
+# Install any needed packages
+RUN apt-get update && \
+    apt-get install -y python3-pip libsndfile1 ffmpeg && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
 
-# Reset the frontend (not necessary in newer Docker versions)
-ENV DEBIAN_FRONTEND=newt
+# Copy the requirements.txt file
+COPY requirements.txt ./
+
+# Install any needed packages specified in requirements.txt
+RUN pip3 install --no-cache-dir -r requirements.txt
 
 # Copy the rest of your application's code
 COPY . .
 
-# Make port 8765 available to the world outside this container
-EXPOSE 8765
+# Make port 80 available to the world outside this container
+EXPOSE 80
 
 # Define environment variable
 ENV NAME VoiceStreamAI
@@ -30,5 +32,4 @@ ENV NAME VoiceStreamAI
 ENTRYPOINT ["python3", "-m", "src.main"]
 
 # Provide a default command (can be overridden at runtime)
-CMD ["--host", "0.0.0.0", "--port", "8765"]
-
+CMD ["--host", "0.0.0.0", "--port", "80", "--static-path", "./src/static"]
diff --git a/LICENSE.txt b/LICENSE.txt
index 84609a6..c13f991 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2024 Alessandro Saccoia
+Copyright (c) 2024
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index b44062c..90a044f 100644
--- a/README.md
+++ b/README.md
@@ -13,18 +13,15 @@ VoiceStreamAI is a Python 3-based server and JavaScript client solution that en
 - Customizable audio chunk processing strategies.
 - Support for multilingual transcription.
 
-## Demo Video
+## Demo
+[View Demo Video](https://raw.githubusercontent.com/TyreseDev/VoiceStreamAI/main/img/voicestreamai_test.mp4)
 
-https://github.com/alesaccoia/VoiceStreamAI/assets/1385023/9b5f2602-fe0b-4c9d-af9e-4662e42e23df
-
-## Demo Client
-
-![Client Demo](/img/client.png "Client Demo")
+![Demo Image](https://raw.githubusercontent.com/TyreseDev/VoiceStreamAI/main/img/client.png)
 
 ## Running with Docker
 
-This will not guide you in detail on how to use CUDA in docker, see for example [here](https://medium.com/@kevinsjy997/configure-docker-to-use-local-gpu-for-training-ml-models-70980168ec9b).
+This guide does not cover configuring CUDA for Docker in detail; see for example [here](https://medium.com/@kevinsjy997/configure-docker-to-use-local-gpu-for-training-ml-models-70980168ec9b).
 Still, these are the commands for Linux:
@@ -52,13 +49,13 @@ After getting your VAD token (see next sections) run:
 sudo docker volume create huggingface_models
 
-sudo docker run --gpus all -p 8765:8765 -v huggingface_models:/root/.cache/huggingface -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
+sudo docker run --gpus all -p 80:80 -v huggingface_models:/root/.cache/huggingface -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
 ```
 
 The volume keeps the Hugging Face models cached so they are not re-downloaded each time you re-run the container. If you don't need this, just use:
 
 ```bash
-sudo docker run --gpus all -p 8765:8765 -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
+sudo docker run --gpus all -p 80:80 -e PYANNOTE_AUTH_TOKEN='VAD_TOKEN_HERE' voicestreamai
 ```
 
 ## Normal, Manual Installation
@@ -92,7 +89,7 @@ The VoiceStreamAI server can be customized through command line arguments, allow
 - `--asr-type`: Specifies the type of Automatic Speech Recognition (ASR) pipeline to use (default: `faster_whisper`).
 - `--asr-args`: A JSON string containing additional arguments for the ASR pipeline (one can for example change `model_name` for whisper)
 - `--host`: Sets the host address for the WebSocket server (default: `127.0.0.1`).
-- `--port`: Sets the port on which the server listens (default: `8765`).
+- `--port`: Sets the port on which the server listens (default: `80`).
 
 For running the server with the standard configuration:
 
@@ -103,7 +100,7 @@ For running the server with the standard configuration:
 python3 -m src.main --vad-args '{"auth_token": "vad token here"}'
 ```
 
-You can see all the command line options with the command: 
+You can see all the command line options with the command:
 
 ```bash
 python3 -m src.main --help
@@ -112,13 +109,12 @@ python3 -m src.main --help
 
 ## Client Usage
 
-1. Open the `client/VoiceStreamAI_Client.html` file in a web browser.
-2. Enter the WebSocket address (default is `ws://localhost:8765`).
+1. Open the `client/index.html` file in a web browser.
+2. Enter the WebSocket address (default is `ws://localhost/ws`).
 3. Configure the audio chunk length and offset. See below.
 4. Select the language for transcription.
 5. Click 'Connect' to establish a WebSocket connection.
 6. Use 'Start Streaming' and 'Stop Streaming' to control audio capture.
-
 ## Technology Overview
 
 - **Python Server**: Manages WebSocket connections, processes audio streams, and handles voice activity detection and transcription.
@@ -207,10 +203,8 @@ Please make sure that the env variables are in place for example for the VAD aut
 
 ### Dependence on Audio Files
 
-Currently, VoiceStreamAI processes audio by saving chunks to files and then running these files through the models. 
+Currently, VoiceStreamAI processes audio by saving chunks to files and then running these files through the models.
 
 ## Contributors
 
-- Alessandro Saccoia - [alessandro.saccoia@gmail.com](mailto:alessandro.saccoia@gmail.com)
-
 This project is open for contributions. Feel free to fork the repository and submit pull requests.
diff --git a/client/VoiceStreamAI_Client.html b/client/VoiceStreamAI_Client.html
deleted file mode 100644
index a3c5cb3..0000000
--- a/client/VoiceStreamAI_Client.html
+++ /dev/null
@@ -1,133 +0,0 @@
[The 133 deleted lines of HTML markup are not recoverable from this extraction. The removed page was the standalone client, titled "Audio Stream to WebSocket Server", with the heading "Transcribe a Web Audio Stream with Huggingface VAD + Whisper", inputs for the WebSocket address, buffering strategy, chunk length/offset and language, Start/Stop Streaming buttons, and the status fields "WebSocket: Not Connected", "Detected Language: Undefined", and "Last Processing Time: Undefined".]
diff --git a/client/index.html b/client/index.html
new file mode 100644
index 0000000..c23be91
--- /dev/null
+++ b/client/index.html
@@ -0,0 +1,147 @@
[The 147 added lines of HTML markup are likewise not recoverable. The new page is the same client, re-titled "Transcribe a Web Audio Stream with PyAnnote + Whisper", with the same controls and status fields.]
+ + diff --git a/client/utils.js b/client/utils.js index abdaf13..48756f1 100644 --- a/client/utils.js +++ b/client/utils.js @@ -1,8 +1,6 @@ /** * VoiceStreamAI Client - WebSocket-based real-time transcription * - * Contributor: - * - Alessandro Saccoia - alessandro.saccoia@gmail.com */ let websocket; @@ -10,207 +8,231 @@ let context; let processor; let globalStream; -const websocket_uri = 'ws://localhost:8765'; const bufferSize = 4096; let isRecording = false; function initWebSocket() { - const websocketAddress = document.getElementById('websocketAddress').value; - chunk_length_seconds = document.getElementById('chunk_length_seconds').value; - chunk_offset_seconds = document.getElementById('chunk_offset_seconds').value; - const selectedLanguage = document.getElementById('languageSelect').value; - language = selectedLanguage !== 'multilingual' ? selectedLanguage : null; - - if (!websocketAddress) { - console.log("WebSocket address is required."); - return; - } - - websocket = new WebSocket(websocketAddress); - websocket.onopen = () => { - console.log("WebSocket connection established"); - document.getElementById("webSocketStatus").textContent = 'Connected'; - document.getElementById('startButton').disabled = false; - }; - websocket.onclose = event => { - console.log("WebSocket connection closed", event); - document.getElementById("webSocketStatus").textContent = 'Not Connected'; - document.getElementById('startButton').disabled = true; - document.getElementById('stopButton').disabled = true; - }; - websocket.onmessage = event => { - console.log("Message from server:", event.data); - const transcript_data = JSON.parse(event.data); - updateTranscription(transcript_data); - }; + const websocketAddress = document.getElementById("websocketAddress").value; + chunk_length_seconds = document.getElementById("chunk_length_seconds").value; + chunk_offset_seconds = document.getElementById("chunk_offset_seconds").value; + const selectedLanguage = document.getElementById("languageSelect").value; + language = selectedLanguage !== "multilingual" ? 
selectedLanguage : null; + + if (!websocketAddress) { + console.log("WebSocket address is required."); + return; + } + + websocket = new WebSocket(websocketAddress); + websocket.onopen = () => { + console.log("WebSocket connection established"); + document.getElementById("webSocketStatus").textContent = "Connected"; + document.getElementById("startButton").disabled = false; + }; + websocket.onclose = (event) => { + console.log("WebSocket connection closed", event); + document.getElementById("webSocketStatus").textContent = "Not Connected"; + document.getElementById("startButton").disabled = true; + document.getElementById("stopButton").disabled = true; + }; + websocket.onmessage = (event) => { + console.log("Message from server:", event.data); + const transcript_data = JSON.parse(event.data); + updateTranscription(transcript_data); + }; } function updateTranscription(transcript_data) { - const transcriptionDiv = document.getElementById('transcription'); - const languageDiv = document.getElementById('detected_language'); - - if (transcript_data['words'] && transcript_data['words'].length > 0) { - // Append words with color based on their probability - transcript_data['words'].forEach(wordData => { - const span = document.createElement('span'); - const probability = wordData['probability']; - span.textContent = wordData['word'] + ' '; - - // Set the color based on the probability - if (probability > 0.9) { - span.style.color = 'green'; - } else if (probability > 0.6) { - span.style.color = 'orange'; - } else { - span.style.color = 'red'; - } - - transcriptionDiv.appendChild(span); - }); - - // Add a new line at the end - transcriptionDiv.appendChild(document.createElement('br')); - } else { - // Fallback to plain text - transcriptionDiv.textContent += transcript_data['text'] + '\n'; - } - - // Update the language information - if (transcript_data['language'] && transcript_data['language_probability']) { - languageDiv.textContent = transcript_data['language'] + ' (' + transcript_data['language_probability'].toFixed(2) + ')'; - } - - // Update the processing time, if available - const processingTimeDiv = document.getElementById('processing_time'); - if (transcript_data['processing_time']) { - processingTimeDiv.textContent = 'Processing time: ' + transcript_data['processing_time'].toFixed(2) + ' seconds'; - } + const transcriptionDiv = document.getElementById("transcription"); + const languageDiv = document.getElementById("detected_language"); + + if (transcript_data.words && transcript_data.words.length > 0) { + // Append words with color based on their probability + // biome-ignore lint/complexity/noForEach: + transcript_data.words.forEach((wordData) => { + const span = document.createElement("span"); + const probability = wordData.probability; + span.textContent = `${wordData.word} `; + + // Set the color based on the probability + if (probability > 0.9) { + span.style.color = "green"; + } else if (probability > 0.6) { + span.style.color = "orange"; + } else { + span.style.color = "red"; + } + + transcriptionDiv.appendChild(span); + }); + + // Add a new line at the end + transcriptionDiv.appendChild(document.createElement("br")); + } else { + // Fallback to plain text + transcriptionDiv.textContent += `${transcript_data.text}\n`; + } + + // Update the language information + if (transcript_data.language && transcript_data.language_probability) { + languageDiv.textContent = `${ + transcript_data.language + } (${transcript_data.language_probability.toFixed(2)})`; + } + + // Update the 
processing time, if available + const processingTimeDiv = document.getElementById("processing_time"); + if (transcript_data.processing_time) { + processingTimeDiv.textContent = `Processing time: ${transcript_data.processing_time.toFixed( + 2, + )} seconds`; + } } - function startRecording() { - if (isRecording) return; - isRecording = true; - - const AudioContext = window.AudioContext || window.webkitAudioContext; - context = new AudioContext(); - - navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => { - globalStream = stream; - const input = context.createMediaStreamSource(stream); - processor = context.createScriptProcessor(bufferSize, 1, 1); - processor.onaudioprocess = e => processAudio(e); - input.connect(processor); - processor.connect(context.destination); - - sendAudioConfig(); - }).catch(error => console.error('Error accessing microphone', error)); - - // Disable start button and enable stop button - document.getElementById('startButton').disabled = true; - document.getElementById('stopButton').disabled = false; + if (isRecording) return; + isRecording = true; + + const AudioContext = window.AudioContext || window.webkitAudioContext; + context = new AudioContext(); + + navigator.mediaDevices + .getUserMedia({ audio: true }) + .then((stream) => { + globalStream = stream; + const input = context.createMediaStreamSource(stream); + processor = context.createScriptProcessor(bufferSize, 1, 1); + processor.onaudioprocess = (e) => processAudio(e); + input.connect(processor); + processor.connect(context.destination); + + sendAudioConfig(); + }) + .catch((error) => console.error("Error accessing microphone", error)); + + // Disable start button and enable stop button + document.getElementById("startButton").disabled = true; + document.getElementById("stopButton").disabled = false; } function stopRecording() { - if (!isRecording) return; - isRecording = false; - - if (globalStream) { - globalStream.getTracks().forEach(track => track.stop()); - } - if (processor) { - processor.disconnect(); - processor = null; - } - if (context) { - context.close().then(() => context = null); - } - document.getElementById('startButton').disabled = false; - document.getElementById('stopButton').disabled = true; + if (!isRecording) return; + isRecording = false; + + if (globalStream) { + // biome-ignore lint/complexity/noForEach: + globalStream.getTracks().forEach((track) => track.stop()); + } + if (processor) { + processor.disconnect(); + processor = null; + } + if (context) { + // biome-ignore lint/suspicious/noAssignInExpressions: + context.close().then(() => (context = null)); + } + document.getElementById("startButton").disabled = false; + document.getElementById("stopButton").disabled = true; } function sendAudioConfig() { - let selectedStrategy = document.getElementById('bufferingStrategySelect').value; - let processingArgs = {}; - - if (selectedStrategy === 'silence_at_end_of_chunk') { - processingArgs = { - chunk_length_seconds: parseFloat(document.getElementById('chunk_length_seconds').value), - chunk_offset_seconds: parseFloat(document.getElementById('chunk_offset_seconds').value) - }; - } - - const audioConfig = { - type: 'config', - data: { - sampleRate: context.sampleRate, - bufferSize: bufferSize, - channels: 1, // Assuming mono channel - language: language, - processing_strategy: selectedStrategy, - processing_args: processingArgs - } + const selectedStrategy = document.getElementById( + "bufferingStrategySelect", + ).value; + let processingArgs = {}; + + if (selectedStrategy === 
"silence_at_end_of_chunk") { + processingArgs = { + chunk_length_seconds: Number.parseFloat( + document.getElementById("chunk_length_seconds").value, + ), + chunk_offset_seconds: Number.parseFloat( + document.getElementById("chunk_offset_seconds").value, + ), }; - - websocket.send(JSON.stringify(audioConfig)); + } + + const audioConfig = { + type: "config", + data: { + sampleRate: context.sampleRate, + bufferSize: bufferSize, + channels: 1, // Assuming mono channel + language: language, + processing_strategy: selectedStrategy, + processing_args: processingArgs, + }, + }; + + websocket.send(JSON.stringify(audioConfig)); } function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) { - if (inputSampleRate === outputSampleRate) { - return buffer; + if (inputSampleRate === outputSampleRate) { + return buffer; + } + const sampleRateRatio = inputSampleRate / outputSampleRate; + const newLength = Math.round(buffer.length / sampleRateRatio); + const result = new Float32Array(newLength); + let offsetResult = 0; + let offsetBuffer = 0; + while (offsetResult < result.length) { + const nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio); + let accum = 0; + let count = 0; + for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) { + accum += buffer[i]; + count++; } - var sampleRateRatio = inputSampleRate / outputSampleRate; - var newLength = Math.round(buffer.length / sampleRateRatio); - var result = new Float32Array(newLength); - var offsetResult = 0; - var offsetBuffer = 0; - while (offsetResult < result.length) { - var nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio); - var accum = 0, count = 0; - for (var i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) { - accum += buffer[i]; - count++; - } - result[offsetResult] = accum / count; - offsetResult++; - offsetBuffer = nextOffsetBuffer; - } - return result; + result[offsetResult] = accum / count; + offsetResult++; + offsetBuffer = nextOffsetBuffer; + } + return result; } function processAudio(e) { - const inputSampleRate = context.sampleRate; - const outputSampleRate = 16000; // Target sample rate - - const left = e.inputBuffer.getChannelData(0); - const downsampledBuffer = downsampleBuffer(left, inputSampleRate, outputSampleRate); - const audioData = convertFloat32ToInt16(downsampledBuffer); - - if (websocket && websocket.readyState === WebSocket.OPEN) { - websocket.send(audioData); - } + const inputSampleRate = context.sampleRate; + const outputSampleRate = 16000; // Target sample rate + + const left = e.inputBuffer.getChannelData(0); + const downsampledBuffer = downsampleBuffer( + left, + inputSampleRate, + outputSampleRate, + ); + const audioData = convertFloat32ToInt16(downsampledBuffer); + + if (websocket && websocket.readyState === WebSocket.OPEN) { + websocket.send(audioData); + } } function convertFloat32ToInt16(buffer) { - let l = buffer.length; - const buf = new Int16Array(l); - while (l--) { - buf[l] = Math.min(1, buffer[l]) * 0x7FFF; - } - return buf.buffer; + let l = buffer.length; + const buf = new Int16Array(l); + while (l--) { + buf[l] = Math.min(1, buffer[l]) * 0x7fff; + } + return buf.buffer; } -// Initialize WebSocket on page load -// window.onload = initWebSocket; - function toggleBufferingStrategyPanel() { - var selectedStrategy = document.getElementById('bufferingStrategySelect').value; - if (selectedStrategy === 'silence_at_end_of_chunk') { - var panel = document.getElementById('silence_at_end_of_chunk_options_panel'); - 
panel.classList.remove('hidden'); - } else { - var panel = document.getElementById('silence_at_end_of_chunk_options_panel'); - panel.classList.add('hidden'); - } + const selectedStrategy = document.getElementById( + "bufferingStrategySelect", + ).value; + if (selectedStrategy === "silence_at_end_of_chunk") { + const panel = document.getElementById( + "silence_at_end_of_chunk_options_panel", + ); + panel.classList.remove("hidden"); + } else { + const panel = document.getElementById( + "silence_at_end_of_chunk_options_panel", + ); + panel.classList.add("hidden"); + } } +// // Initialize WebSocket on page load +// window.onload = initWebSocket; diff --git a/dockerhub_push.sh b/dockerhub_push.sh new file mode 100644 index 0000000..40cf311 --- /dev/null +++ b/dockerhub_push.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# Define variables +# DOCKER_USERNAME="tyrese3915" +read -p "Enter your Docker Hub username: " DOCKER_USERNAME +IMAGE_NAME="voice-stream-ai" +TAG="latest" + +# Step 1: Log in to Docker Hub +echo "Logging in to Docker Hub..." +docker login --username "$DOCKER_USERNAME" +if [ $? -ne 0 ]; then + echo "Docker login failed. Exiting..." + exit 1 +fi + +# Step 2: Build your Docker image +echo "Building the Docker image..." +docker build -t "$IMAGE_NAME:$TAG" . +if [ $? -ne 0 ]; then + echo "Docker build failed. Exiting..." + exit 1 +fi + +# Step 3: Tag the image for your Docker Hub repository +echo "Tagging the image..." +docker tag "$IMAGE_NAME:$TAG" "$DOCKER_USERNAME/$IMAGE_NAME:$TAG" +if [ $? -ne 0 ]; then + echo "Docker tag failed. Exiting..." + exit 1 +fi + +# Step 4: Push the image to Docker Hub +echo "Pushing the image to Docker Hub..." +docker push "$DOCKER_USERNAME/$IMAGE_NAME:$TAG" +if [ $? -ne 0 ]; then + echo "Docker push failed. Exiting..." + exit 1 +fi + +echo "Docker image has been successfully pushed to Docker Hub." 
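Both copies of `utils.js` (the `client/` one above and the `src/static/` one added later in this diff) downsample the browser's Float32 samples to 16 kHz and send them as little-endian Int16 PCM. For readers writing a non-browser client, here is a minimal Python sketch of the same conversion (NumPy is an assumption here, not a dependency of this PR; unlike the JS `Math.min(1, ...)`, it also clamps the negative side):

```python
import numpy as np

def float32_to_int16_pcm(samples: np.ndarray) -> bytes:
    # Clamp to [-1.0, 1.0] and scale to the Int16 range, mirroring the
    # client's convertFloat32ToInt16 (which only clamps the positive side).
    clamped = np.clip(samples, -1.0, 1.0)
    return (clamped * 0x7FFF).astype(np.int16).tobytes()

# One second of silence at the 16 kHz rate the server expects:
chunk = float32_to_int16_pcm(np.zeros(16000, dtype=np.float32))
assert len(chunk) == 32000  # 16000 samples * 2 bytes each
```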
diff --git a/img/voicestreamai_test.mp4 b/img/voicestreamai_test.mp4 new file mode 100644 index 0000000..9d9c848 Binary files /dev/null and b/img/voicestreamai_test.mp4 differ diff --git a/requirements.txt b/requirements.txt index 966d009..b2eded3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -websockets -speechbrain pyannote-audio -asyncio sentence-transformers transformers faster-whisper +torchvision +aiohttp \ No newline at end of file diff --git a/src/asr/asr_factory.py b/src/asr/asr_factory.py index 433291a..bd8ba77 100644 --- a/src/asr/asr_factory.py +++ b/src/asr/asr_factory.py @@ -1,6 +1,7 @@ from .whisper_asr import WhisperASR from .faster_whisper_asr import FasterWhisperASR + class ASRFactory: @staticmethod def create_asr_pipeline(type, **kwargs): diff --git a/src/asr/faster_whisper_asr.py b/src/asr/faster_whisper_asr.py index 8cc1734..a27a606 100644 --- a/src/asr/faster_whisper_asr.py +++ b/src/asr/faster_whisper_asr.py @@ -110,15 +110,25 @@ class FasterWhisperASR(ASRInterface): def __init__(self, **kwargs): - model_size = kwargs.get('model_size', "large-v3") + model_size = kwargs.get("model_size", "large-v3") # Run on GPU with FP16 - self.asr_pipeline = WhisperModel(model_size, device="cuda", compute_type="float16") + self.asr_pipeline = WhisperModel( + model_size, device="cuda", compute_type="float16" + ) async def transcribe(self, client): - file_path = await save_audio_to_file(client.scratch_buffer, client.get_file_name()) + file_path = await save_audio_to_file( + client.scratch_buffer, client.get_file_name() + ) - language = None if client.config['language'] is None else language_codes.get(client.config['language'].lower()) - segments, info = self.asr_pipeline.transcribe(file_path, word_timestamps=True, language=language) + language = ( + None + if client.config["language"] is None + else language_codes.get(client.config["language"].lower()) + ) + segments, info = self.asr_pipeline.transcribe( + file_path, word_timestamps=True, language=language + ) segments = list(segments) # The transcription will actually run here. 
os.remove(file_path) @@ -126,13 +136,17 @@ async def transcribe(self, client): flattened_words = [word for segment in segments for word in segment.words] to_return = { - "language": info.language, - "language_probability": info.language_probability, - "text": ' '.join([s.text.strip() for s in segments]), - "words": - [ - {"word": w.word, "start": w.start, "end": w.end, "probability":w.probability} for w in flattened_words - ] + "language": info.language, + "language_probability": info.language_probability, + "text": " ".join([s.text.strip() for s in segments]), + "words": [ + { + "word": w.word, + "start": w.start, + "end": w.end, + "probability": w.probability, + } + for w in flattened_words + ], } return to_return - diff --git a/src/asr/whisper_asr.py b/src/asr/whisper_asr.py index 38d14f6..1c31da3 100644 --- a/src/asr/whisper_asr.py +++ b/src/asr/whisper_asr.py @@ -3,18 +3,23 @@ from src.audio_utils import save_audio_to_file import os + class WhisperASR(ASRInterface): def __init__(self, **kwargs): - model_name = kwargs.get('model_name', "openai/whisper-large-v3") + model_name = kwargs.get("model_name", "openai/whisper-large-v3") self.asr_pipeline = pipeline("automatic-speech-recognition", model=model_name) async def transcribe(self, client): - file_path = await save_audio_to_file(client.scratch_buffer, client.get_file_name()) - - if client.config['language'] is not None: - to_return = self.asr_pipeline(file_path, generate_kwargs={"language": client.config['language']})['text'] + file_path = await save_audio_to_file( + client.scratch_buffer, client.get_file_name() + ) + + if client.config["language"] is not None: + to_return = self.asr_pipeline( + file_path, generate_kwargs={"language": client.config["language"]} + )["text"] else: - to_return = self.asr_pipeline(file_path)['text'] + to_return = self.asr_pipeline(file_path)["text"] os.remove(file_path) @@ -22,6 +27,6 @@ async def transcribe(self, client): "language": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER", "language_probability": None, "text": to_return.strip(), - "words": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER" + "words": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER", } return to_return diff --git a/src/audio_utils.py b/src/audio_utils.py index 9b8d2b0..370e920 100644 --- a/src/audio_utils.py +++ b/src/audio_utils.py @@ -1,7 +1,10 @@ import wave import os -async def save_audio_to_file(audio_data, file_name, audio_dir="audio_files", audio_format="wav"): + +async def save_audio_to_file( + audio_data, file_name, audio_dir="audio_files", audio_format="wav" +): """ Saves the audio data to a file. @@ -14,10 +17,10 @@ async def save_audio_to_file(audio_data, file_name, audio_dir="audio_files", aud """ os.makedirs(audio_dir, exist_ok=True) - + file_path = os.path.join(audio_dir, file_name) - with wave.open(file_path, 'wb') as wav_file: + with wave.open(file_path, "wb") as wav_file: wav_file.setnchannels(1) # Assuming mono audio wav_file.setsampwidth(2) wav_file.setframerate(16000) diff --git a/src/buffering_strategy/buffering_strategies.py b/src/buffering_strategy/buffering_strategies.py index aca1da5..bb5cc5d 100644 --- a/src/buffering_strategy/buffering_strategies.py +++ b/src/buffering_strategy/buffering_strategies.py @@ -5,6 +5,7 @@ from .buffering_strategy_interface import BufferingStrategyInterface + class SilenceAtEndOfChunk(BufferingStrategyInterface): """ A buffering strategy that processes audio at the end of each chunk with silence detection. 
@@ -28,20 +29,20 @@ def __init__(self, client, **kwargs): """ self.client = client - self.chunk_length_seconds = os.environ.get('BUFFERING_CHUNK_LENGTH_SECONDS') + self.chunk_length_seconds = os.environ.get("BUFFERING_CHUNK_LENGTH_SECONDS") if not self.chunk_length_seconds: - self.chunk_length_seconds = kwargs.get('chunk_length_seconds') + self.chunk_length_seconds = kwargs.get("chunk_length_seconds") self.chunk_length_seconds = float(self.chunk_length_seconds) - self.chunk_offset_seconds = os.environ.get('BUFFERING_CHUNK_OFFSET_SECONDS') + self.chunk_offset_seconds = os.environ.get("BUFFERING_CHUNK_OFFSET_SECONDS") if not self.chunk_offset_seconds: - self.chunk_offset_seconds = kwargs.get('chunk_offset_seconds') + self.chunk_offset_seconds = kwargs.get("chunk_offset_seconds") self.chunk_offset_seconds = float(self.chunk_offset_seconds) - self.error_if_not_realtime = os.environ.get('ERROR_IF_NOT_REALTIME') + self.error_if_not_realtime = os.environ.get("ERROR_IF_NOT_REALTIME") if not self.error_if_not_realtime: - self.error_if_not_realtime = kwargs.get('error_if_not_realtime', False) - + self.error_if_not_realtime = kwargs.get("error_if_not_realtime", False) + self.processing_flag = False def process_audio(self, websocket, vad_pipeline, asr_pipeline): @@ -56,17 +57,26 @@ def process_audio(self, websocket, vad_pipeline, asr_pipeline): vad_pipeline: The voice activity detection pipeline. asr_pipeline: The automatic speech recognition pipeline. """ - chunk_length_in_bytes = self.chunk_length_seconds * self.client.sampling_rate * self.client.samples_width + chunk_length_in_bytes = ( + self.chunk_length_seconds + * self.client.sampling_rate + * self.client.samples_width + ) if len(self.client.buffer) > chunk_length_in_bytes: if self.processing_flag: - exit("Error in realtime processing: tried processing a new chunk while the previous one was still being processed") + print( + "Error in realtime processing: tried processing a new chunk while the previous one was still being processed" + ) + return self.client.scratch_buffer += self.client.buffer self.client.buffer.clear() self.processing_flag = True # Schedule the processing in a separate task - asyncio.create_task(self.process_audio_async(websocket, vad_pipeline, asr_pipeline)) - + asyncio.create_task( + self.process_audio_async(websocket, vad_pipeline, asr_pipeline) + ) + async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline): """ Asynchronously process audio for activity detection and transcription. @@ -78,7 +88,7 @@ async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline): websocket (Websocket): The WebSocket connection for sending transcriptions. vad_pipeline: The voice activity detection pipeline. asr_pipeline: The automatic speech recognition pipeline. 
- """ + """ start = time.time() vad_results = await vad_pipeline.detect_activity(self.client) @@ -88,15 +98,18 @@ async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline): self.processing_flag = False return - last_segment_should_end_before = ((len(self.client.scratch_buffer) / (self.client.sampling_rate * self.client.samples_width)) - self.chunk_offset_seconds) - if vad_results[-1]['end'] < last_segment_should_end_before: + last_segment_should_end_before = ( + len(self.client.scratch_buffer) + / (self.client.sampling_rate * self.client.samples_width) + ) - self.chunk_offset_seconds + if vad_results[-1]["end"] < last_segment_should_end_before: transcription = await asr_pipeline.transcribe(self.client) - if transcription['text'] != '': + if transcription["text"] != "": end = time.time() - transcription['processing_time'] = end - start - json_transcription = json.dumps(transcription) - await websocket.send(json_transcription) + transcription["processing_time"] = end - start + json_transcription = json.dumps(transcription) + await websocket.send_str(json_transcription) self.client.scratch_buffer.clear() self.client.increment_file_counter() - - self.processing_flag = False \ No newline at end of file + + self.processing_flag = False diff --git a/src/buffering_strategy/buffering_strategy_factory.py b/src/buffering_strategy/buffering_strategy_factory.py index 8d3f452..6f8131c 100644 --- a/src/buffering_strategy/buffering_strategy_factory.py +++ b/src/buffering_strategy/buffering_strategy_factory.py @@ -1,5 +1,6 @@ from .buffering_strategies import SilenceAtEndOfChunk + class BufferingStrategyFactory: """ A factory class for creating instances of different buffering strategies. diff --git a/src/client.py b/src/client.py index a735d9e..95ab70c 100644 --- a/src/client.py +++ b/src/client.py @@ -1,5 +1,6 @@ from src.buffering_strategy.buffering_strategy_factory import BufferingStrategyFactory + class Client: """ Represents a client connected to the VoiceStreamAI server. @@ -16,26 +17,29 @@ class Client: sampling_rate (int): The sampling rate of the audio data in Hz. samples_width (int): The width of each audio sample in bits. 
""" + def __init__(self, client_id, sampling_rate, samples_width): self.client_id = client_id self.buffer = bytearray() self.scratch_buffer = bytearray() - self.config = {"language": None, - "processing_strategy": "silence_at_end_of_chunk", - "processing_args": { - "chunk_length_seconds": 5, - "chunk_offset_seconds": 0.1 - } - } + self.config = { + "language": None, + "processing_strategy": "silence_at_end_of_chunk", + "processing_args": {"chunk_length_seconds": 5, "chunk_offset_seconds": 0.1}, + } self.file_counter = 0 self.total_samples = 0 self.sampling_rate = sampling_rate self.samples_width = samples_width - self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy(self.config['processing_strategy'], self, **self.config['processing_args']) + self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy( + self.config["processing_strategy"], self, **self.config["processing_args"] + ) def update_config(self, config_data): self.config.update(config_data) - self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy(self.config['processing_strategy'], self, **self.config['processing_args']) + self.buffering_strategy = BufferingStrategyFactory.create_buffering_strategy( + self.config["processing_strategy"], self, **self.config["processing_args"] + ) def append_audio_data(self, audio_data): self.buffer.extend(audio_data) @@ -49,6 +53,6 @@ def increment_file_counter(self): def get_file_name(self): return f"{self.client_id}_{self.file_counter}.wav" - + def process_audio(self, websocket, vad_pipeline, asr_pipeline): self.buffering_strategy.process_audio(websocket, vad_pipeline, asr_pipeline) diff --git a/src/main.py b/src/main.py index fd9f8fa..89b0841 100644 --- a/src/main.py +++ b/src/main.py @@ -2,20 +2,54 @@ import asyncio import json -from .server import Server +from src.server import Server from src.asr.asr_factory import ASRFactory from src.vad.vad_factory import VADFactory + def parse_args(): - parser = argparse.ArgumentParser(description="VoiceStreamAI Server: Real-time audio transcription using self-hosted Whisper and WebSocket") - parser.add_argument("--vad-type", type=str, default="pyannote", help="Type of VAD pipeline to use (e.g., 'pyannote')") - parser.add_argument("--vad-args", type=str, default='{"auth_token": "huggingface_token"}', help="JSON string of additional arguments for VAD pipeline") - parser.add_argument("--asr-type", type=str, default="faster_whisper", help="Type of ASR pipeline to use (e.g., 'whisper')") - parser.add_argument("--asr-args", type=str, default='{"model_size": "large-v3"}', help="JSON string of additional arguments for ASR pipeline") - parser.add_argument("--host", type=str, default="127.0.0.1", help="Host for the WebSocket server") - parser.add_argument("--port", type=int, default=8765, help="Port for the WebSocket server") + parser = argparse.ArgumentParser( + description="VoiceStreamAI Server: Real-time audio transcription using self-hosted Whisper and WebSocket" + ) + parser.add_argument( + "--vad-type", + type=str, + default="pyannote", + help="Type of VAD pipeline to use (e.g., 'pyannote')", + ) + parser.add_argument( + "--vad-args", + type=str, + default='{"auth_token": "huggingface_token"}', + help="JSON string of additional arguments for VAD pipeline", + ) + parser.add_argument( + "--asr-type", + type=str, + default="faster_whisper", + help="Type of ASR pipeline to use (e.g., 'whisper')", + ) + parser.add_argument( + "--asr-args", + type=str, + default='{"model_size": "large-v3"}', + 
help="JSON string of additional arguments for ASR pipeline", + ) + parser.add_argument( + "--host", type=str, default="127.0.0.1", help="Host for the WebSocket server" + ) + parser.add_argument( + "--port", type=int, default=80, help="Port for the WebSocket server" + ) + parser.add_argument( + "--static-path", + type=str, + default="./static", + help="Port for the WebSocket server", + ) return parser.parse_args() + def main(): args = parse_args() @@ -29,10 +63,19 @@ def main(): vad_pipeline = VADFactory.create_vad_pipeline(args.vad_type, **vad_args) asr_pipeline = ASRFactory.create_asr_pipeline(args.asr_type, **asr_args) - server = Server(vad_pipeline, asr_pipeline, host=args.host, port=args.port, sampling_rate=16000, samples_width=2) + server = Server( + vad_pipeline, + asr_pipeline, + host=args.host, + port=args.port, + sampling_rate=16000, + samples_width=2, + static_files_path=args.static_path, + ) asyncio.get_event_loop().run_until_complete(server.start()) asyncio.get_event_loop().run_forever() + if __name__ == "__main__": main() diff --git a/src/server.py b/src/server.py index a9028d4..ab2031d 100644 --- a/src/server.py +++ b/src/server.py @@ -1,11 +1,10 @@ -import websockets +from aiohttp import web import uuid import json -import asyncio -from src.audio_utils import save_audio_to_file from src.client import Client + class Server: """ Represents the WebSocket server for handling real-time audio transcription. @@ -23,7 +22,17 @@ class Server: samples_width (int): The width of each audio sample in bits. connected_clients (dict): A dictionary mapping client IDs to Client objects. """ - def __init__(self, vad_pipeline, asr_pipeline, host='localhost', port=8765, sampling_rate=16000, samples_width=2): + + def __init__( + self, + vad_pipeline, + asr_pipeline, + host="localhost", + port=80, + sampling_rate=16000, + samples_width=2, + static_files_path="./static", + ): self.vad_pipeline = vad_pipeline self.asr_pipeline = asr_pipeline self.host = host @@ -31,39 +40,57 @@ def __init__(self, vad_pipeline, asr_pipeline, host='localhost', port=8765, samp self.sampling_rate = sampling_rate self.samples_width = samples_width self.connected_clients = {} + self.static_files_path = static_files_path + self.app = web.Application() + self.setup_routes() + + def setup_routes(self): + self.app.router.add_get("/ws", self.websocket_handler) + self.app.router.add_get("/", self.index_handler) + self.app.router.add_static( + "/static/", path=self.static_files_path, name="static" + ) - async def handle_audio(self, client, websocket): - while True: - message = await websocket.recv() + async def index_handler(self, request): + return web.FileResponse(path=f"{self.static_files_path}/index.html") - if isinstance(message, bytes): - client.append_audio_data(message) - elif isinstance(message, str): - config = json.loads(message) - if config.get('type') == 'config': - client.update_config(config['data']) - continue - else: - print(f"Unexpected message type from {client.client_id}") + def setup_routes(self): + self.app.router.add_get("/ws", self.websocket_handler) + self.app.router.add_get("/", self.serve_index) + self.app.router.add_static("/", path=self.static_files_path, name="static") - # this is synchronous, any async operation is in BufferingStrategy - client.process_audio(websocket, self.vad_pipeline, self.asr_pipeline) + async def serve_index(self, request): + return web.FileResponse(path=f"{self.static_files_path}/index.html") + async def websocket_handler(self, request): + ws = web.WebSocketResponse() + await 
ws.prepare(request) - async def handle_websocket(self, websocket, path): client_id = str(uuid.uuid4()) client = Client(client_id, self.sampling_rate, self.samples_width) self.connected_clients[client_id] = client - print(f"Client {client_id} connected") + print(f"Client {client_id} connected.") + async for msg in ws: + if msg.type == web.WSMsgType.TEXT: + message_text = msg.data + if message_text == "close": + await ws.close() + else: + # Handle textual WebSocket messages + config = json.loads(message_text) + if config.get("type") == "config": + client.update_config(config["data"]) + elif msg.type == web.WSMsgType.BINARY: + # Handle binary WebSocket messages + client.append_audio_data(msg.data) + client.process_audio(ws, self.vad_pipeline, self.asr_pipeline) + elif msg.type == web.WSMsgType.ERROR: + print(f"WebSocket connection closed with exception {ws.exception()}") - try: - await self.handle_audio(client, websocket) - except websockets.ConnectionClosed as e: - print(f"Connection with {client_id} closed: {e}") - finally: - del self.connected_clients[client_id] + print(f"Client {client_id} disconnected.") + del self.connected_clients[client_id] + return ws def start(self): - print("Websocket server ready to accept connections") - return websockets.serve(self.handle_websocket, self.host, self.port) + web.run_app(self.app, host=self.host, port=self.port) diff --git a/src/static/index.html b/src/static/index.html new file mode 100644 index 0000000..c23be91 --- /dev/null +++ b/src/static/index.html @@ -0,0 +1,147 @@ + + + + + + Audio Stream to WebSocket Server + + + + +

Transcribe a Web Audio Stream with PyAnnote + Whisper

+
+
+ + +
+
+ + +
+
+
+ + +
+
+ + +
+
+
+ + +
+ +
+ + +
+
+
WebSocket: Not Connected
+
Detected Language: Undefined
+
Last Processing Time: Undefined
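With the server now routing WebSocket traffic through aiohttp at `/ws`, the endpoint can be exercised without a browser. A minimal sketch using `aiohttp` (which this PR adds to `requirements.txt`); the WAV filename and the 44-byte canonical header offset are assumptions, and the server is assumed to be listening on the new default port 80:

```python
import asyncio
import json

import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("ws://localhost/ws") as ws:
            # The first frame must be the textual config message the server parses.
            await ws.send_str(json.dumps({
                "type": "config",
                "data": {
                    "sampleRate": 16000,
                    "bufferSize": 4096,
                    "channels": 1,
                    "language": None,
                    "processing_strategy": "silence_at_end_of_chunk",
                    "processing_args": {
                        "chunk_length_seconds": 5,
                        "chunk_offset_seconds": 0.1,
                    },
                },
            }))
            # Binary frames carry raw Int16 PCM; skip the 44-byte WAV header.
            with open("sample_16khz_mono.wav", "rb") as f:  # hypothetical file
                await ws.send_bytes(f.read()[44:])
            # A transcription arrives once enough audio has accumulated for the
            # buffering strategy to trigger and VAD/ASR have finished.
            msg = await ws.receive()
            print(msg.data)

asyncio.run(main())
```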
+ + diff --git a/src/static/utils.js b/src/static/utils.js new file mode 100644 index 0000000..bf362b3 --- /dev/null +++ b/src/static/utils.js @@ -0,0 +1,256 @@ +/** + * VoiceStreamAI Client - WebSocket-based real-time transcription + * + */ + +let websocket; +let context; +let processor; +let globalStream; + +const bufferSize = 4096; +let isRecording = false; + +function initWebSocket() { + const websocketAddress = document.getElementById("websocketAddress").value; + chunk_length_seconds = document.getElementById("chunk_length_seconds").value; + chunk_offset_seconds = document.getElementById("chunk_offset_seconds").value; + const selectedLanguage = document.getElementById("languageSelect").value; + language = selectedLanguage !== "multilingual" ? selectedLanguage : null; + + if (!websocketAddress) { + console.log("WebSocket address is required."); + return; + } + + websocket = new WebSocket(websocketAddress); + websocket.onopen = () => { + console.log("WebSocket connection established"); + document.getElementById("webSocketStatus").textContent = "Connected"; + document.getElementById("startButton").disabled = false; + }; + websocket.onclose = (event) => { + console.log("WebSocket connection closed", event); + document.getElementById("webSocketStatus").textContent = "Not Connected"; + document.getElementById("startButton").disabled = true; + document.getElementById("stopButton").disabled = true; + }; + websocket.onmessage = (event) => { + console.log("Message from server:", event.data); + const transcript_data = JSON.parse(event.data); + updateTranscription(transcript_data); + }; +} + +function updateTranscription(transcript_data) { + const transcriptionDiv = document.getElementById("transcription"); + const languageDiv = document.getElementById("detected_language"); + + if (transcript_data.words && transcript_data.words.length > 0) { + // Append words with color based on their probability + // biome-ignore lint/complexity/noForEach: + transcript_data.words.forEach((wordData) => { + const span = document.createElement("span"); + const probability = wordData.probability; + span.textContent = `${wordData.word} `; + + // Set the color based on the probability + if (probability > 0.9) { + span.style.color = "green"; + } else if (probability > 0.6) { + span.style.color = "orange"; + } else { + span.style.color = "red"; + } + + transcriptionDiv.appendChild(span); + }); + + // Add a new line at the end + transcriptionDiv.appendChild(document.createElement("br")); + } else { + // Fallback to plain text + transcriptionDiv.textContent += `${transcript_data.text}\n`; + } + + // Update the language information + if (transcript_data.language && transcript_data.language_probability) { + languageDiv.textContent = `${ + transcript_data.language + } (${transcript_data.language_probability.toFixed(2)})`; + } + + // Update the processing time, if available + const processingTimeDiv = document.getElementById("processing_time"); + if (transcript_data.processing_time) { + processingTimeDiv.textContent = `Processing time: ${transcript_data.processing_time.toFixed( + 2, + )} seconds`; + } +} + +function startRecording() { + if (isRecording) return; + isRecording = true; + + const AudioContext = window.AudioContext || window.webkitAudioContext; + context = new AudioContext(); + + navigator.mediaDevices + .getUserMedia({ audio: true }) + .then((stream) => { + globalStream = stream; + const input = context.createMediaStreamSource(stream); + processor = context.createScriptProcessor(bufferSize, 1, 1); + 
processor.onaudioprocess = (e) => processAudio(e); + input.connect(processor); + processor.connect(context.destination); + + sendAudioConfig(); + }) + .catch((error) => console.error("Error accessing microphone", error)); + + // Disable start button and enable stop button + document.getElementById("startButton").disabled = true; + document.getElementById("stopButton").disabled = false; +} + +function stopRecording() { + if (!isRecording) return; + isRecording = false; + + if (globalStream) { + // biome-ignore lint/complexity/noForEach: + globalStream.getTracks().forEach((track) => track.stop()); + } + if (processor) { + processor.disconnect(); + processor = null; + } + if (context) { + // biome-ignore lint/suspicious/noAssignInExpressions: + context.close().then(() => (context = null)); + } + document.getElementById("startButton").disabled = false; + document.getElementById("stopButton").disabled = true; +} + +function sendAudioConfig() { + const selectedStrategy = document.getElementById( + "bufferingStrategySelect", + ).value; + let processingArgs = {}; + + if (selectedStrategy === "silence_at_end_of_chunk") { + processingArgs = { + chunk_length_seconds: Number.parseFloat( + document.getElementById("chunk_length_seconds").value, + ), + chunk_offset_seconds: Number.parseFloat( + document.getElementById("chunk_offset_seconds").value, + ), + }; + } + + const audioConfig = { + type: "config", + data: { + sampleRate: context.sampleRate, + bufferSize: bufferSize, + channels: 1, // Assuming mono channel + language: language, + processing_strategy: selectedStrategy, + processing_args: processingArgs, + }, + }; + + websocket.send(JSON.stringify(audioConfig)); +} + +function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) { + if (inputSampleRate === outputSampleRate) { + return buffer; + } + const sampleRateRatio = inputSampleRate / outputSampleRate; + const newLength = Math.round(buffer.length / sampleRateRatio); + const result = new Float32Array(newLength); + let offsetResult = 0; + let offsetBuffer = 0; + while (offsetResult < result.length) { + const nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio); + let accum = 0; + let count = 0; + for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) { + accum += buffer[i]; + count++; + } + result[offsetResult] = accum / count; + offsetResult++; + offsetBuffer = nextOffsetBuffer; + } + return result; +} + +function processAudio(e) { + const inputSampleRate = context.sampleRate; + const outputSampleRate = 16000; // Target sample rate + + const left = e.inputBuffer.getChannelData(0); + const downsampledBuffer = downsampleBuffer( + left, + inputSampleRate, + outputSampleRate, + ); + const audioData = convertFloat32ToInt16(downsampledBuffer); + + if (websocket && websocket.readyState === WebSocket.OPEN) { + websocket.send(audioData); + } +} + +function convertFloat32ToInt16(buffer) { + let l = buffer.length; + const buf = new Int16Array(l); + while (l--) { + buf[l] = Math.min(1, buffer[l]) * 0x7fff; + } + return buf.buffer; +} + +function toggleBufferingStrategyPanel() { + const selectedStrategy = document.getElementById( + "bufferingStrategySelect", + ).value; + if (selectedStrategy === "silence_at_end_of_chunk") { + const panel = document.getElementById( + "silence_at_end_of_chunk_options_panel", + ); + panel.classList.remove("hidden"); + } else { + const panel = document.getElementById( + "silence_at_end_of_chunk_options_panel", + ); + panel.classList.add("hidden"); + } +} + +function getWebSocketUrl() { + 
if ( + window.location.protocol !== "https:" && + window.location.protocol !== "http:" + ) + return null; + const wsProtocol = window.location.protocol === "https:" ? "wss" : "ws"; + const wsUrl = `${wsProtocol}://${window.location.host}/ws`; + return wsUrl; +} + +// // Initialize WebSocket on page load +// window.onload = initWebSocket; + +window.onload = () => { + const url = getWebSocketUrl(); + document.getElementById("websocketAddress").value = + url ?? "ws://localhost/ws"; + initWebSocket(); +}; diff --git a/src/vad/pyannote_vad.py b/src/vad/pyannote_vad.py index 758e919..7469a6b 100644 --- a/src/vad/pyannote_vad.py +++ b/src/vad/pyannote_vad.py @@ -1,7 +1,6 @@ from os import remove import os -from pyannote.core import Segment from pyannote.audio import Model from pyannote.audio.pipelines import VoiceActivityDetection @@ -22,23 +21,35 @@ def __init__(self, **kwargs): model_name (str): The model name for Pyannote. auth_token (str, optional): Authentication token for Hugging Face. """ - - model_name = kwargs.get('model_name', "pyannote/segmentation") - auth_token = os.environ.get('PYANNOTE_AUTH_TOKEN') + model_name = kwargs.get("model_name", "pyannote/segmentation") + + auth_token = os.environ.get("PYANNOTE_AUTH_TOKEN") if not auth_token: - auth_token = kwargs.get('auth_token') - + auth_token = kwargs.get("auth_token") + if auth_token is None: - raise ValueError("Missing required env var in PYANNOTE_AUTH_TOKEN or argument in --vad-args: 'auth_token'") - - pyannote_args = kwargs.get('pyannote_args', {"onset": 0.5, "offset": 0.5, "min_duration_on": 0.3, "min_duration_off": 0.3}) + raise ValueError( + "Missing required env var in PYANNOTE_AUTH_TOKEN or argument in --vad-args: 'auth_token'" + ) + + pyannote_args = kwargs.get( + "pyannote_args", + { + "onset": 0.5, + "offset": 0.5, + "min_duration_on": 0.1, + "min_duration_off": 0.1, + }, + ) self.model = Model.from_pretrained(model_name, use_auth_token=auth_token) self.vad_pipeline = VoiceActivityDetection(segmentation=self.model) self.vad_pipeline.instantiate(pyannote_args) async def detect_activity(self, client): - audio_file_path = await save_audio_to_file(client.scratch_buffer, client.get_file_name()) + audio_file_path = await save_audio_to_file( + client.scratch_buffer, client.get_file_name() + ) vad_results = self.vad_pipeline(audio_file_path) remove(audio_file_path) vad_segments = [] diff --git a/src/vad/vad_factory.py b/src/vad/vad_factory.py index 600864c..e20d3f9 100644 --- a/src/vad/vad_factory.py +++ b/src/vad/vad_factory.py @@ -1,5 +1,6 @@ from .pyannote_vad import PyannoteVAD + class VADFactory: """ Factory for creating instances of VAD systems. 
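For context on the VAD change above: `pyannote_vad.py` tightens the default `min_duration_on`/`min_duration_off` from 0.3 s to 0.1 s, so shorter utterances are kept and only sub-100 ms pauses are bridged. A sketch of what the pipeline instantiation now amounts to (the token string is a placeholder):

```python
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection

model = Model.from_pretrained("pyannote/segmentation", use_auth_token="hf_xxx")  # placeholder token
vad_pipeline = VoiceActivityDetection(segmentation=model)
vad_pipeline.instantiate({
    "onset": 0.5,
    "offset": 0.5,
    "min_duration_on": 0.1,   # was 0.3: keep speech regions as short as 100 ms
    "min_duration_off": 0.1,  # was 0.3: only bridge pauses shorter than 100 ms
})
```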
diff --git a/test/asr/test_asr.py b/test/asr/test_asr.py index d9fafcf..b385c75 100644 --- a/test/asr/test_asr.py +++ b/test/asr/test_asr.py @@ -4,51 +4,65 @@ import asyncio from sentence_transformers import SentenceTransformer, util from pydub import AudioSegment -import argparse from src.asr.asr_factory import ASRFactory from src.client import Client + class TestWhisperASR(unittest.TestCase): @classmethod def setUpClass(cls): # Use an environment variable to get the ASR model type - cls.asr_type = os.getenv('ASR_TYPE', 'whisper') + cls.asr_type = os.getenv("ASR_TYPE", "whisper") def setUp(self): self.asr = ASRFactory.create_asr_pipeline(self.asr_type) - self.annotations_path = os.path.join(os.path.dirname(__file__), "../audio_files/annotations.json") + self.annotations_path = os.path.join( + os.path.dirname(__file__), "../audio_files/annotations.json" + ) self.client = Client("test_client", 16000, 2) # Example client - self.similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') + self.similarity_model = SentenceTransformer( + "sentence-transformers/all-MiniLM-L6-v2" + ) def load_annotations(self): - with open(self.annotations_path, 'r') as file: + with open(self.annotations_path, "r") as file: return json.load(file) def get_audio_segment(self, file_path, start, end): - with open(file_path, 'rb') as file: + with open(file_path, "rb") as file: audio = AudioSegment.from_file(file, format="wav") - return audio[start * 1000:end * 1000] # pydub works in milliseconds + return audio[start * 1000 : end * 1000] # pydub works in milliseconds def test_transcribe_segments(self): annotations = self.load_annotations() for audio_file, data in annotations.items(): - audio_file_path = os.path.join(os.path.dirname(__file__), f"../audio_files/{audio_file}") + audio_file_path = os.path.join( + os.path.dirname(__file__), f"../audio_files/{audio_file}" + ) similarities = [] for segment in data["segments"]: - audio_segment = self.get_audio_segment(audio_file_path, segment["start"], segment["end"]) + audio_segment = self.get_audio_segment( + audio_file_path, segment["start"], segment["end"] + ) self.client.scratch_buffer = bytearray(audio_segment.raw_data) - self.client.config['language'] = None + self.client.config["language"] = None transcription = asyncio.run(self.asr.transcribe(self.client))["text"] - embedding_1 = self.similarity_model.encode(transcription.lower().strip(), convert_to_tensor=True) - embedding_2 = self.similarity_model.encode(segment["transcription"].lower().strip(), convert_to_tensor=True) + embedding_1 = self.similarity_model.encode( + transcription.lower().strip(), convert_to_tensor=True + ) + embedding_2 = self.similarity_model.encode( + segment["transcription"].lower().strip(), convert_to_tensor=True + ) similarity = util.pytorch_cos_sim(embedding_1, embedding_2).item() similarities.append(similarity) - print(f"\nSegment from '{audio_file}' ({segment['start']}-{segment['end']}s):") + print( + f"\nSegment from '{audio_file}' ({segment['start']}-{segment['end']}s):" + ) print(f"Expected: {segment['transcription']}") print(f"Actual: {transcription}") print(f"Similarity: {similarity}") @@ -60,7 +74,10 @@ def test_transcribe_segments(self): print(f"\nAverage similarity for '{audio_file}': {avg_similarity}") # Assert that the average similarity is above the threshold - self.assertGreaterEqual(avg_similarity, 0.7) # Adjust the threshold as needed + self.assertGreaterEqual( + avg_similarity, 0.7 + ) # Adjust the threshold as needed + -if __name__ == '__main__': +if 
__name__ == "__main__": unittest.main() diff --git a/test/requirements.txt b/test/requirements.txt new file mode 100644 index 0000000..d4ad60d --- /dev/null +++ b/test/requirements.txt @@ -0,0 +1,3 @@ +pydub +sentence_transformers +websockets \ No newline at end of file diff --git a/test/server/test_server.py b/test/server/test_server.py index 23b1734..0155fff 100644 --- a/test/server/test_server.py +++ b/test/server/test_server.py @@ -12,6 +12,7 @@ from src.vad.vad_factory import VADFactory from src.asr.asr_factory import ASRFactory + class TestServer(unittest.TestCase): """ Test suite for testing the Server class responsible for real-time audio transcription. @@ -27,11 +28,12 @@ class TestServer(unittest.TestCase): test_server_response: Tests the server's response accuracy by comparing received and expected transcriptions. load_annotations: Loads transcription annotations for comparison with server responses. """ + @classmethod def setUpClass(cls): # Use an environment variable to get the ASR model type - cls.asr_type = os.getenv('ASR_TYPE', 'faster_whisper') - cls.vad_type = os.getenv('VAD_TYPE', 'pyannote') + cls.asr_type = os.getenv("ASR_TYPE", "faster_whisper") + cls.vad_type = os.getenv("VAD_TYPE", "pyannote") def setUp(self): """ @@ -42,10 +44,16 @@ def setUp(self): """ self.vad_pipeline = VADFactory.create_vad_pipeline(self.vad_type) self.asr_pipeline = ASRFactory.create_asr_pipeline(self.asr_type) - self.server = Server(self.vad_pipeline, self.asr_pipeline, host='127.0.0.1', port=8767) - self.annotations_path = os.path.join(os.path.dirname(__file__), "../audio_files/annotations.json") + self.server = Server( + self.vad_pipeline, self.asr_pipeline, host="127.0.0.1", port=8767 + ) + self.annotations_path = os.path.join( + os.path.dirname(__file__), "../audio_files/annotations.json" + ) self.received_transcriptions = [] - self.similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') + self.similarity_model = SentenceTransformer( + "sentence-transformers/all-MiniLM-L6-v2" + ) async def receive_transcriptions(self, websocket): """ @@ -57,9 +65,11 @@ async def receive_transcriptions(self, websocket): try: while True: transcription_str = await websocket.recv() - transcription = json.loads(transcription_str) - self.received_transcriptions.append(transcription['text']) - print(f"Received transcription: {transcription['text']}, processing time: {transcription['processing_time']}") + transcription = json.loads(transcription_str) + self.received_transcriptions.append(transcription["text"]) + print( + f"Received transcription: {transcription['text']}, processing time: {transcription['processing_time']}" + ) except websockets.exceptions.ConnectionClosed: pass # Expected when server closes the connection @@ -72,20 +82,20 @@ async def mock_client(self, audio_file): Args: audio_file (str): Path to the audio file to be sent to the server. 
""" - uri = "ws://127.0.0.1:8767" + uri = "ws://127.0.0.1/ws" async with websockets.connect(uri) as websocket: # Start receiving transcriptions in a separate task receive_task = asyncio.create_task(self.receive_transcriptions(websocket)) # Stream the entire audio file in chunks - with open(audio_file, 'rb') as file: + with open(audio_file, "rb") as file: audio = AudioSegment.from_file(file, format="wav") - + for i in range(0, len(audio), 250): # 4000 samples = 250 ms at 16000 Hz - chunk = audio[i:i+250] + chunk = audio[i : i + 250] await websocket.send(chunk.raw_data) await asyncio.sleep(0.25) # Wait for the chunk duration - + # Stream 10 seconds of silence silence = AudioSegment.silent(duration=10000) await websocket.send(silence.raw_data) @@ -107,17 +117,27 @@ def test_server_response(self): annotations = self.load_annotations() for audio_file_name, data in annotations.items(): - audio_file_path = os.path.join(os.path.dirname(__file__), f"../audio_files/{audio_file_name}") + audio_file_path = os.path.join( + os.path.dirname(__file__), f"../audio_files/{audio_file_name}" + ) # Run the mock client for each audio file - asyncio.get_event_loop().run_until_complete(self.mock_client(audio_file_path)) + asyncio.get_event_loop().run_until_complete( + self.mock_client(audio_file_path) + ) # Compare received transcriptions with expected transcriptions - expected_transcriptions = ' '.join([seg["transcription"] for seg in data['segments']]) - received_transcriptions = ' '.join(self.received_transcriptions) - - embedding_1 = self.similarity_model.encode(expected_transcriptions.lower().strip(), convert_to_tensor=True) - embedding_2 = self.similarity_model.encode(received_transcriptions.lower().strip(), convert_to_tensor=True) + expected_transcriptions = " ".join( + [seg["transcription"] for seg in data["segments"]] + ) + received_transcriptions = " ".join(self.received_transcriptions) + + embedding_1 = self.similarity_model.encode( + expected_transcriptions.lower().strip(), convert_to_tensor=True + ) + embedding_2 = self.similarity_model.encode( + received_transcriptions.lower().strip(), convert_to_tensor=True + ) similarity = util.pytorch_cos_sim(embedding_1, embedding_2).item() # Print summary before assertion @@ -137,8 +157,9 @@ def load_annotations(self): Returns: dict: A dictionary containing expected transcriptions for test audio files. 
""" - with open(self.annotations_path, 'r') as file: + with open(self.annotations_path, "r") as file: return json.load(file) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/test/vad/test_pyannote_vad.py b/test/vad/test_pyannote_vad.py index 807750a..22a98f8 100644 --- a/test/vad/test_pyannote_vad.py +++ b/test/vad/test_pyannote_vad.py @@ -8,50 +8,72 @@ from src.vad.pyannote_vad import PyannoteVAD from src.client import Client + class TestPyannoteVAD(unittest.TestCase): def setUp(self): self.vad = PyannoteVAD() - self.annotations_path = os.path.join(os.path.dirname(__file__), "../audio_files/annotations.json") + self.annotations_path = os.path.join( + os.path.dirname(__file__), "../audio_files/annotations.json" + ) self.client = Client("test_client", 16000, 2) # Example client def load_annotations(self): - with open(self.annotations_path, 'r') as file: + with open(self.annotations_path, "r") as file: return json.load(file) def test_detect_activity(self): annotations = self.load_annotations() for audio_file, data in annotations.items(): - audio_file_path = os.path.join(os.path.dirname(__file__), f"../audio_files/{audio_file}") + audio_file_path = os.path.join( + os.path.dirname(__file__), f"../audio_files/{audio_file}" + ) for annotated_segment in data["segments"]: # Load the specific audio segment for VAD - audio_segment = self.get_audio_segment(audio_file_path, annotated_segment["start"], annotated_segment["end"]) + audio_segment = self.get_audio_segment( + audio_file_path, + annotated_segment["start"], + annotated_segment["end"], + ) self.client.scratch_buffer = bytearray(audio_segment.raw_data) vad_results = asyncio.run(self.vad.detect_activity(self.client)) # Adjust VAD-detected times by adding the start time of the annotated segment - adjusted_vad_results = [{"start": segment["start"] + annotated_segment["start"], - "end": segment["end"] + annotated_segment["start"]} - for segment in vad_results] + adjusted_vad_results = [ + { + "start": segment["start"] + annotated_segment["start"], + "end": segment["end"] + annotated_segment["start"], + } + for segment in vad_results + ] - detected_segments = [segment for segment in adjusted_vad_results if - segment["start"] <= annotated_segment["start"] + 1.0 and - segment["end"] <= annotated_segment["end"] + 2.0] + detected_segments = [ + segment + for segment in adjusted_vad_results + if segment["start"] <= annotated_segment["start"] + 1.0 + and segment["end"] <= annotated_segment["end"] + 2.0 + ] # Print formatted information about the test - print(f"\nTesting segment from '{audio_file}': Annotated Start: {annotated_segment['start']}, Annotated End: {annotated_segment['end']}") + print( + f"\nTesting segment from '{audio_file}': Annotated Start: {annotated_segment['start']}, Annotated End: {annotated_segment['end']}" + ) print(f"VAD segments: {adjusted_vad_results}") print(f"Overlapping, Detected segments: {detected_segments}") # Assert that at least one detected segment meets the condition - self.assertTrue(len(detected_segments) > 0, "No detected segment matches the annotated segment") + self.assertTrue( + len(detected_segments) > 0, + "No detected segment matches the annotated segment", + ) def get_audio_segment(self, file_path, start, end): - with open(file_path, 'rb') as file: + with open(file_path, "rb") as file: audio = AudioSegment.from_file(file, format="wav") - return audio[start * 1000:end * 1000] # pydub works in milliseconds + return audio[start * 1000 : end * 1000] # pydub works in milliseconds 
+ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()
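The test modules above configure themselves through environment variables (`ASR_TYPE`, `VAD_TYPE`, and `PYANNOTE_AUTH_TOKEN` for the pyannote model). A sketch of driving the whole suite programmatically; the token value and the presence of annotated files under `test/audio_files/` are assumptions:

```python
import os
import unittest

os.environ["ASR_TYPE"] = "faster_whisper"     # or "whisper"
os.environ["VAD_TYPE"] = "pyannote"
os.environ["PYANNOTE_AUTH_TOKEN"] = "hf_xxx"  # placeholder Hugging Face token

# Discover and run test/asr, test/vad and test/server.
suite = unittest.defaultTestLoader.discover("test", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```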