Skip to content

Commit

Permalink
Merge pull request #20 from cbsiamlg/feature/fix-whisper-argument
Browse files Browse the repository at this point in the history
Feature/fix whisper argument
  • Loading branch information
beatgeek authored Dec 13, 2023
2 parents dda3c48 + ec065fc commit 91fce8c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 39 deletions.
5 changes: 2 additions & 3 deletions app/faster_whisper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
model_converter(model_name, model_path)

if torch.cuda.is_available():
-model = WhisperModel(model_path, device="cuda", compute_type="float16")
+model = WhisperModel(model_path, device="cuda", compute_type="float32")
else:
model = WhisperModel(model_path, device="cpu", compute_type="int8")
model_lock = Lock()
Expand All @@ -47,7 +47,6 @@ def transcribe(
segments = []
text = ""
i = 0
-logging.info(f"Options: {options_dict}")
segment_generator, info = model.transcribe(audio, beam_size=5, **options_dict)
for segment in segment_generator:
segments.append(segment)
Expand All @@ -59,7 +58,7 @@ def transcribe(
}

outputFile = StringIO()
-write_result(result, outputFile, output, word_timestamps=word_timestamps)
+write_result(result, outputFile, output)
outputFile.seek(0)

return outputFile
Expand Down
57 changes: 21 additions & 36 deletions app/faster_whisper/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
from typing import TextIO
-import logging
from ctranslate2.converters.transformers import TransformersConverter


Expand Down Expand Up @@ -114,45 +113,31 @@ def write_result(self, result: dict, file: TextIO):
class WriteJSON(ResultWriter):
extension: str = "json"

-def write_result(self, result: dict, file: TextIO, word_timestamps: bool):
-formatted_result = format_json(result, word_timestamps)
+def write_result(self, result: dict, file: TextIO):
+formatted_result = format_json(result)
json.dump(formatted_result, file, indent=2)


-def format_json(json_file, word_timestamps):
+def format_json(json_file):
text = json_file['text']
-if word_timestamps:
-segments = [{
-'id': 0,
-'seek': 0,
-'start': segment[2],
-'end': segment[3],
-'text': segment[4],
-'tokens': segment[5],
-'temperature': segment[6],
-'avg_logprob': segment[7],
-'compression_ratio': segment[8],
-'no_speech_prob': segment[9],
-'words': [{
-'word': word[2],
-'start': word[0],
-'end': word[1],
-'probability': word[3]
-} for word in segment[10]]
-} for segment in json_file['segments']]
-else:
-segments = [{
-'id': 0,
-'seek': 0,
-'start': segment[2],
-'end': segment[3],
-'text': segment[4],
-'tokens': segment[5],
-'temperature': segment[6],
-'avg_logprob': segment[7],
-'compression_ratio': segment[8],
-'no_speech_prob': segment[9],
-} for segment in json_file['segments']]
+segments = [{
+'id': 0,
+'seek': 0,
+'start': segment[2],
+'end': segment[3],
+'text': segment[4],
+'tokens': segment[5],
+'temperature': segment[6],
+'avg_logprob': segment[7],
+'compression_ratio': segment[8],
+'no_speech_prob': segment[9],
+'words': [{
+'word': word[2],
+'start': word[0],
+'end': word[1],
+'probability': word[3]
+} for word in segment[10]]
+} for segment in json_file['segments']]
output = {
"text": text,
"segments": segments,
Expand Down

0 comments on commit 91fce8c

Please sign in to comment.