Merge pull request #25 from cbsiamlg/feature/AMLG-7129-liveness-health-update

feature/AMLG-7129-liveness-health-update
beatgeek authored Mar 7, 2024
2 parents d9ad7da + 48a034b commit ea9f10f
Showing 10 changed files with 827 additions and 644 deletions.
3 changes: 3 additions & 0 deletions app/faster_whisper/core.py
@@ -13,6 +13,7 @@
WriteVTT,
WriteTSV,
WriteJSON,
WriteRawJSON,
)
from faster_whisper import WhisperModel

@@ -85,6 +86,8 @@ def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
WriteTSV(ResultWriter).write_result(result, file=file)
elif output == "json":
WriteJSON(ResultWriter).write_result(result, file=file)
elif output == "raw_json":
WriteRawJSON(ResultWriter).write_result(result, file=file)
elif output == "txt":
WriteTXT(ResultWriter).write_result(result, file=file)
else:
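For reference, the new raw_json output can be exercised end to end once the service is running. A minimal client sketch (the host, port, and sample.wav file name are assumptions for illustration, not part of this change):

import requests  # assumption: the requests package is installed client-side

# Hypothetical local deployment; adjust host/port to match your setup.
url = "http://localhost:9000/asr"

with open("sample.wav", "rb") as audio:  # placeholder input file
    response = requests.post(
        url,
        params={"task": "transcribe", "output": "raw_json"},
        files={"audio_file": ("sample.wav", audio, "audio/wav")},
    )

response.raise_for_status()
print(response.text)  # the payload serialized by WriteRawJSON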
76 changes: 52 additions & 24 deletions app/faster_whisper/utils.py
@@ -60,6 +60,32 @@ def write_result(self, result: dict, file: TextIO):
print(segment.text.strip(), file=file, flush=True)


class WriteRawJSON(ResultWriter):
    extension: str = "raw_json"

    def write_result(self, result: dict, file: TextIO):
        result["segments"] = self.format_segments(result)
        json.dump(result, file, indent=2)

    def format_segments(self, output):
        # Map each faster-whisper segment tuple onto a dict; the fields
        # live at fixed tuple indices (2-9).
        segments = [
            {
                "id": 0,
                "seek": 0,
                "start": segment[2],
                "end": segment[3],
                "text": segment[4],
                "tokens": segment[5],
                "temperature": segment[6],
                "avg_logprob": segment[7],
                "compression_ratio": segment[8],
                "no_speech_prob": segment[9],
            }
            for segment in output["segments"]
        ]
        return segments


class WriteVTT(ResultWriter):
extension: str = "vtt"

@@ -119,28 +145,30 @@ def write_result(self, result: dict, file: TextIO):


def format_json(json_file):
text = json_file['text']
segments = [{
'id': 0,
'seek': 0,
'start': segment[2],
'end': segment[3],
'text': segment[4],
'tokens': segment[5],
'temperature': segment[6],
'avg_logprob': segment[7],
'compression_ratio': segment[8],
'no_speech_prob': segment[9],
'words': [{
'word': word[2],
'start': word[0],
'end': word[1],
'probability': word[3]
} for word in segment[10]]
} for segment in json_file['segments']]
output = {
"text": text,
"segments": segments,
"language": json_file["language"]
}
text = json_file["text"]
segments = [
{
"id": 0,
"seek": 0,
"start": segment[2],
"end": segment[3],
"text": segment[4],
"tokens": segment[5],
"temperature": segment[6],
"avg_logprob": segment[7],
"compression_ratio": segment[8],
"no_speech_prob": segment[9],
"words": [
{
"word": word[2],
"start": word[0],
"end": word[1],
"probability": word[3],
}
for word in segment[10]
],
}
for segment in json_file["segments"]
]
output = {"text": text, "segments": segments, "language": json_file["language"]}
return output
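To make the tuple layout these helpers depend on concrete (segment fields at indices 2-9, word tuples at index 10), here is a small illustrative sketch with fabricated values; the import path is an assumption:

from faster_whisper.utils import format_json  # assumed import path

# Fabricated example data, only to show the expected index layout.
fake_word = (0.0, 0.4, "hello", 0.99)  # (start, end, word, probability)
fake_segment = (
    None, None,        # indices 0-1 are not read by format_json
    0.0, 0.4,          # start, end
    "hello", [50364],  # text, tokens
    0.0, -0.25,        # temperature, avg_logprob
    1.1, 0.02,         # compression_ratio, no_speech_prob
    [fake_word],       # word-level timestamps (index 10)
)
payload = {"text": "hello", "segments": [fake_segment], "language": "en"}
print(format_json(payload)["segments"][0]["words"][0]["word"])  # -> hello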
20 changes: 20 additions & 0 deletions app/tests/test_webservice.py
@@ -0,0 +1,20 @@
import pytest
from starlette.testclient import TestClient
from webservice import app


@pytest.fixture
def client():
return TestClient(app)


def test_liveness(client):
response = client.get("/liveness")
assert response.status_code == 200
assert response.json() == {"status": "ok"}


def test_readiness(client):
response = client.get("/readiness")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
158 changes: 111 additions & 47 deletions app/webservice.py
@@ -6,6 +6,7 @@
import numpy as np
import ffmpeg
from fastapi import FastAPI, File, UploadFile, Query, applications
from fastapi import Response, status
from fastapi.responses import StreamingResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.openapi.docs import get_swagger_ui_html
@@ -21,27 +22,25 @@
else:
from .openai_whisper.core import transcribe, language_detection

SAMPLE_RATE=16000
LANGUAGE_CODES=sorted(list(tokenizer.LANGUAGES.keys()))
SAMPLE_RATE = 16000
LANGUAGE_CODES = sorted(list(tokenizer.LANGUAGES.keys()))

projectMetadata = importlib.metadata.metadata('whisper-asr-webservice')
projectMetadata = importlib.metadata.metadata("whisper-asr-webservice")
app = FastAPI(
title=projectMetadata['Name'].title().replace('-', ' '),
description=projectMetadata['Summary'],
version=projectMetadata['Version'],
contact={
"url": projectMetadata['Home-page']
},
title=projectMetadata["Name"].title().replace("-", " "),
description=projectMetadata["Summary"],
version=projectMetadata["Version"],
contact={"url": projectMetadata["Home-page"]},
swagger_ui_parameters={"defaultModelsExpandDepth": -1},
license_info={
"name": "MIT License",
"url": projectMetadata['License']
}
license_info={"name": "MIT License", "url": projectMetadata["License"]},
)

assets_path = os.getcwd() + "/swagger-ui-assets"
if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
if path.exists(assets_path + "/swagger-ui.css") and path.exists(
assets_path + "/swagger-ui-bundle.js"
):
app.mount("/assets", StaticFiles(directory=assets_path), name="static")

def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args,
@@ -50,76 +49,141 @@ def swagger_monkey_patch(*args, **kwargs):
swagger_css_url="/assets/swagger-ui.css",
swagger_js_url="/assets/swagger-ui-bundle.js",
)

applications.get_swagger_ui_html = swagger_monkey_patch


@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
"""
Redirects to the documentation page.
"""
return "/docs"


@app.post("/asr", tags=["Endpoints"])
def asr(
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
task: Union[str, None] = Query(
default="transcribe", enum=["transcribe", "translate"]
),
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
initial_prompt: Union[str, None] = Query(default=None),
audio_file: UploadFile = File(...),
encode : bool = Query(default=True, description="Encode audio first through ffmpeg"),
output : Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
word_timestamps : bool = Query(
default=False,
description="World level timestamps",
include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False)
)
encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
output: Union[str, None] = Query(
default="txt", enum=["txt", "vtt", "srt", "tsv", "json", "raw_json"]
),
word_timestamps: bool = Query(
default=False,
description="World level timestamps",
include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False),
),
):

start = time.time()
result = transcribe(load_audio(audio_file.file, encode), task, language, initial_prompt, word_timestamps, output)
result = transcribe(
load_audio(audio_file.file, encode),
task,
language,
initial_prompt,
word_timestamps,
output,
)
end = time.time()
end_time = end - start
logger.info(f"Transcription took {end_time} seconds")
content_disposition = f'attachment; filename="{audio_file.filename}.{output}"'
return StreamingResponse(
result,
media_type="text/plain",
result,
media_type="text/plain",
headers={
'Asr-Engine': ASR_ENGINE,
'Content-Disposition': f'attachment; filename="{audio_file.filename}.{output}"'
})
"Asr-Engine": ASR_ENGINE,
"Content-Disposition": content_disposition,
},
)


@app.post("/detect-language", tags=["Endpoints"])
def detect_language(
audio_file: UploadFile = File(...),
encode : bool = Query(default=True, description="Encode audio first through ffmpeg")
encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
):
"""
Endpoint for language detection.
Detects the language of the audio file.
Parameters:
- audio_file (UploadFile): The audio file for language detection.
- encode (bool, optional): Whether to encode the audio first via ffmpeg.
Defaults to True.
Returns:
- dict: A dictionary containing the detected language and language code.
"""
detected_lang_code = language_detection(load_audio(audio_file.file, encode))
return { "detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code" : detected_lang_code }
return {
"detected_language": tokenizer.LANGUAGES[detected_lang_code],
"language_code": detected_lang_code,
}
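As with /asr, the detection endpoint is easiest to see with a small client sketch (host, port, and file name are again assumptions):

import requests  # assumption: installed client-side

with open("sample.wav", "rb") as audio:  # placeholder input file
    response = requests.post(
        "http://localhost:9000/detect-language",
        files={"audio_file": ("sample.wav", audio, "audio/wav")},
    )

print(response.json())  # e.g. {"detected_language": "english", "language_code": "en"}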


def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
"""
Open an audio file object and read as mono waveform, resampling as necessary.
Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object
Parameters
----------
file: BinaryIO
The audio file like object
encode: Boolean
If true, encode audio stream to WAV before sending to whisper
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
Open an audio file object and read as mono waveform,
resampling as necessary.
Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
to accept a file object
Parameters:
- file (BinaryIO): The audio file object.
- encode (bool, optional): Whether to encode audio stream to WAV
before sending to whisper.
Defaults to True.
- sr (int, optional): The sample rate to resample the audio if necessary.
Defaults to SAMPLE_RATE.
Returns:
- np.ndarray: A NumPy array containing the audio waveform,
in float32 dtype.
"""
if encode:
try:
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI and the
# ffmpeg-python package to be installed.
out, _ = (
ffmpeg.input("pipe:", threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
.run(
cmd="ffmpeg",
capture_stdout=True,
capture_stderr=True,
input=file.read(),
)
)
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
error_message = f"Failed to load audio: {e.stderr.decode()}"
raise RuntimeError(error_message) from e
else:
out = file.read()

return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
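For reference, a minimal sketch of calling load_audio directly (assumes the ffmpeg CLI is on PATH and that a sample.wav file exists):

with open("sample.wav", "rb") as f:  # placeholder file name
    waveform = load_audio(f, encode=True)

print(waveform.dtype, waveform.shape)  # float32 mono samples at 16 kHz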


@app.get("/readiness/", status_code=status.HTTP_200_OK)
def readiness_check():
"""
Health readiness endpoint for external monitoring systems.
Returns the status as "ok".
"""
return {"status": "ok"}


@app.get("/liveness/", status_code=status.HTTP_200_OK)
def liveness_check():
"""
Health liveness check endpoint for internal monitoring systems.
Returns the status as "ok".
"""
return {"status": "ok"}
8 changes: 8 additions & 0 deletions infrastructure/whisper/NOTES.txt
@@ -0,0 +1,8 @@
Thank you for installing AMLG {{ .Chart.Name }}.

Your release is named {{ .Release.Name }}.

To learn more about the release, try:

$ helm status {{ .Release.Name }}
$ helm get all {{ .Release.Name }}
18 changes: 18 additions & 0 deletions infrastructure/whisper/templates/whisper-deployment.yaml
@@ -41,5 +41,23 @@ spec:
value: faster_whisper
- name: ASR_MODEL
value: medium.en
livenessProbe:
httpGet:
path: /liveness
port: {{ .Values.whisperService.port }}
initialDelaySeconds: 200
timeoutSeconds: 2
periodSeconds: 30
failureThreshold: 1
readinessProbe:
httpGet:
path: /readiness
port: {{ .Values.whisperService.port }}
initialDelaySeconds: 200
timeoutSeconds: 2
periodSeconds: 10
failureThreshold: 3
nodeSelector:
cloud.google.com/gke-nodepool: "splice-xcd-gpu-t4"
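One subtlety worth noting: the probe paths above have no trailing slash, while the routes are registered as /liveness/ and /readiness/. FastAPI answers the slashless path with a 307 redirect, and kubelet counts any status in the 200-399 range as probe success, so the probes still pass. A sketch reproducing that check locally (host and port are assumptions):

import http.client

# Hypothetical local service on port 9000; kubelet treats 200-399 as success.
conn = http.client.HTTPConnection("localhost", 9000, timeout=2)
conn.request("GET", "/liveness")  # the same slashless path the probe uses
status = conn.getresponse().status
print("probe", "passed" if 200 <= status < 400 else "failed", status)
conn.close()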


