-
Notifications
You must be signed in to change notification settings - Fork 0
/
translate.py
66 lines (55 loc) · 2.68 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import os
import time
from gpt_subtitle_translator.constants import TOKENS_PER_CHUNK, DEFAULT_MODEL, MAX_RETRIES, DEFAULT_TEMPERATURE
from gpt_subtitle_translator.models.claude import Claude
from gpt_subtitle_translator.models.gemini import Gemini
from gpt_subtitle_translator.models.gpt import GPT
from gpt_subtitle_translator.subtitle_translator import SubtitleTranslator, TranslationError
from gpt_subtitle_translator.logger import logger
import chardet
def get_output_filename(input_filename):
    """Build the path of the translated subtitle file for *input_filename*.

    The output is placed in the same directory as the input and is named
    ``<stem>_<unix-seconds>_translated.srt``, where ``<stem>`` is the input
    file name without its final extension and ``<unix-seconds>`` is the
    current time truncated to whole seconds.

    Args:
        input_filename: Path of the subtitle file being translated.

    Returns:
        The path (as a string) the translated subtitles should be written to.
    """
    token = int(time.time())
    directory, filename = os.path.split(input_filename)
    # splitext removes only the last extension, so dotted names such as
    # "My.Movie.2020.srt" keep their full stem instead of collapsing to "My".
    stem, _extension = os.path.splitext(filename)
    return os.path.join(directory, f"{stem}_{token}_translated.srt")
def get_model(model_name):
    """Instantiate the model wrapper that matches *model_name*.

    Names beginning with ``"gpt"`` map to ``GPT`` and names beginning with
    ``"gemini"`` map to ``Gemini``; every other name is handed to ``Claude``.

    Args:
        model_name: Model identifier; it is passed unchanged to the
            selected wrapper's constructor.

    Returns:
        A ``GPT``, ``Gemini`` or ``Claude`` instance.
    """
    # Checked in order; the first matching prefix wins.
    prefix_to_backend = (
        ("gpt", GPT),
        ("gemini", Gemini),
    )
    for prefix, backend in prefix_to_backend:
        if model_name.startswith(prefix):
            return backend(model_name)
    # No known prefix matched: fall back to Claude.
    return Claude(model_name)
def main():
    """Command-line entry point: translate one subtitle file and save it.

    Parses the CLI options, reads the input file using the encoding reported
    by ``chardet``, runs the text through ``SubtitleTranslator`` and writes
    the translated result as UTF-8 to the path returned by
    ``get_output_filename``. The accumulated API cost is logged whether or
    not the translation succeeds.

    If no input file is given, the parser reports a usage error and exits.
    If the translator raises ``TranslationError``, the error is logged and
    no output file is written.
    """
    parser = argparse.ArgumentParser(description='Translate a transcript file.')
    parser.add_argument('file', help='The transcript file to translate.', nargs='?', default='')
    parser.add_argument('-l', '--language', type=str, default="English", help='Language to translate to.')
    parser.add_argument('-t', '--threads', type=int, default=1, help='Number of threads to use.')
    parser.add_argument('-temp', '--temperature', type=float, default=DEFAULT_TEMPERATURE, help='Temperature for generation.')
    parser.add_argument('-s', '--chunk_size', type=int, default=TOKENS_PER_CHUNK, help='Number of tokens per chunk.')
    parser.add_argument('-m', '--model', type=str, default=DEFAULT_MODEL, help='Model to use.')
    parser.add_argument('-r', '--retries', type=int, default=MAX_RETRIES, help='Number of retries.')
    args = parser.parse_args()

    # The positional is declared optional with an empty default, so guard
    # against it here; otherwise open('') fails with an opaque
    # FileNotFoundError instead of a usage message.
    if not args.file:
        parser.error('the following arguments are required: file')

    # First pass: sniff the encoding from the raw bytes.
    with open(args.file, 'rb') as f:
        encoding = chardet.detect(f.read())['encoding']

    # Second pass: read the subtitles as text using the detected encoding.
    with open(args.file, 'r', encoding=encoding) as f:
        srt_data = f.read()

    filename = get_output_filename(args.file)
    model = get_model(args.model)
    translator = SubtitleTranslator(
        model=model,
        lang=args.language,
        num_threads=args.threads,
        tokens_per_chunk=args.chunk_size,
        max_retries=args.retries,
        temperature=args.temperature,
    )

    try:
        result = translator.translate_subtitles(srt_data)
    except TranslationError as e:
        # Nothing to write on failure; the finally clause still logs the cost.
        logger.error(e)
        return
    finally:
        logger.info(f"Total API cost: ${model.get_total_cost():.5f}")

    with open(filename, 'w', encoding='utf-8') as f:
        f.write(result)
    logger.info(f"Translated subtitles with {args.model}, file written to {filename}")
# Run the CLI only when this file is executed as a script, not when imported.
if "__main__" == __name__:
    main()