Skip to content

Commit 9e4a4b5

Browse files
authored Oct 29, 2024
feat: support live TTS of fish audio (#555)
* feat: support live TTS of fish audio Signed-off-by: Frost Ming <[email protected]>
1 parent 7096933 commit 9e4a4b5

File tree

8 files changed

+273
-202
lines changed

8 files changed

+273
-202
lines changed
 

‎pdm.lock

+58-73
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎requirements.txt

+5-4
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ edge-tts==6.1.10
2222
exceptiongroup==1.2.0; python_version < "3.11"
2323
frozenlist==1.4.1
2424
google-ai-generativelanguage==0.6.10
25-
google-api-core==2.15.0
25+
google-api-core[grpc]==2.15.0
2626
google-api-python-client==2.125.0
2727
google-auth==2.26.1
2828
google-auth-httplib2==0.2.0
@@ -37,7 +37,8 @@ grpcio-status==1.60.0
3737
h11==0.14.0
3838
httpcore==1.0.5
3939
httplib2==0.22.0
40-
httpx==0.27.2
40+
httpx-ws==0.6.2
41+
httpx[socks]==0.27.2
4142
idna==3.7
4243
jiter==0.5.0
4344
jsonpatch==1.33
@@ -83,13 +84,13 @@ socksio==1.0.0
8384
soupsieve==2.5
8485
sqlalchemy==2.0.25
8586
tenacity==8.2.3
86-
tetos==0.3.1
87+
tetos==0.4.1
8788
tqdm==4.66.1
8889
typing-extensions==4.12.2
8990
typing-inspect==0.9.0
9091
uritemplate==4.1.1
9192
urllib3==2.1.0
9293
websocket-client==1.8.0
93-
websockets==12.0
94+
wsproto==1.2.0
9495
yarl==1.14.0
9596
zhipuai==2.1.5.20230904

‎xiaogpt/tts/__init__.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from xiaogpt.tts.base import TTS
2+
from xiaogpt.tts.file import TetosFileTTS
3+
from xiaogpt.tts.live import TetosLiveTTS
24
from xiaogpt.tts.mi import MiTTS
3-
from xiaogpt.tts.tetos import TetosTTS
45

5-
__all__ = ["TTS", "TetosTTS", "MiTTS"]
6+
__all__ = ["TTS", "TetosFileTTS", "MiTTS", "TetosLiveTTS"]

‎xiaogpt/tts/base.py

+1-90
Original file line number | Diff line number | Diff line change
@@ -2,20 +2,10 @@
22

33
import abc
44
import asyncio
5-
import functools
65
import json
76
import logging
8-
import os
9-
import random
10-
import socket
11-
import tempfile
12-
import threading
13-
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
14-
from pathlib import Path
157
from typing import TYPE_CHECKING, AsyncIterator
168

17-
from xiaogpt.utils import get_hostname
18-
199
if TYPE_CHECKING:
2010
from typing import TypeVar
2111

@@ -46,7 +36,7 @@ async def wait_for_duration(self, duration: float) -> None:
4636
break
4737
await asyncio.sleep(1)
4838

49-
async def get_if_xiaoai_is_playing(self):
39+
async def get_if_xiaoai_is_playing(self) -> bool:
5040
playing_info = await self.mina_service.player_get_status(self.device_id)
5141
# WTF xiaomi api
5242
is_playing = (
@@ -59,82 +49,3 @@ async def get_if_xiaoai_is_playing(self):
5949
async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
6050
"""Synthesize speech from a stream of text."""
6151
raise NotImplementedError
62-
63-
64-
class HTTPRequestHandler(SimpleHTTPRequestHandler):
65-
def log_message(self, format, *args):
66-
logger.debug(f"{self.address_string()} - {format}", *args)
67-
68-
def log_error(self, format, *args):
69-
logger.error(f"{self.address_string()} - {format}", *args)
70-
71-
def copyfile(self, source, outputfile):
72-
try:
73-
super().copyfile(source, outputfile)
74-
except (socket.error, ConnectionResetError, BrokenPipeError):
75-
# ignore this or TODO find out why the error later
76-
pass
77-
78-
79-
class AudioFileTTS(TTS):
80-
"""A TTS model that generates audio files locally and plays them via URL."""
81-
82-
def __init__(
83-
self, mina_service: MiNAService, device_id: str, config: Config
84-
) -> None:
85-
super().__init__(mina_service, device_id, config)
86-
self.dirname = tempfile.TemporaryDirectory(prefix="xiaogpt-tts-")
87-
self._start_http_server()
88-
89-
@abc.abstractmethod
90-
async def make_audio_file(self, lang: str, text: str) -> tuple[Path, float]:
91-
"""Synthesize speech from text and save it to a file.
92-
Return the file path and the duration of the audio in seconds.
93-
The file path must be relative to the self.dirname.
94-
"""
95-
raise NotImplementedError
96-
97-
async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
98-
queue: asyncio.Queue[tuple[str, float]] = asyncio.Queue()
99-
finished = asyncio.Event()
100-
101-
async def worker():
102-
async for text in text_stream:
103-
path, duration = await self.make_audio_file(lang, text)
104-
url = f"http://{self.hostname}:{self.port}/{path.name}"
105-
await queue.put((url, duration))
106-
finished.set()
107-
108-
task = asyncio.create_task(worker())
109-
110-
while True:
111-
try:
112-
url, duration = queue.get_nowait()
113-
except asyncio.QueueEmpty:
114-
if finished.is_set():
115-
break
116-
else:
117-
await asyncio.sleep(0.1)
118-
continue
119-
logger.debug("Playing URL %s (%s seconds)", url, duration)
120-
await asyncio.gather(
121-
self.mina_service.play_by_url(self.device_id, url, _type=1),
122-
self.wait_for_duration(duration),
123-
)
124-
await task
125-
126-
def _start_http_server(self):
127-
# set the port range
128-
port_range = range(8050, 8090)
129-
# get a random port from the range
130-
self.port = int(os.getenv("XIAOGPT_PORT", random.choice(port_range)))
131-
# create the server
132-
handler = functools.partial(HTTPRequestHandler, directory=self.dirname.name)
133-
httpd = ThreadingHTTPServer(("", self.port), handler)
134-
# start the server in a new thread
135-
server_thread = threading.Thread(target=httpd.serve_forever)
136-
server_thread.daemon = True
137-
server_thread.start()
138-
139-
self.hostname = get_hostname()
140-
logger.info(f"Serving on {self.hostname}:{self.port}")

‎xiaogpt/tts/file.py

+103
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,103 @@
1+
import asyncio
2+
import functools
3+
import os
4+
import random
5+
import socket
6+
import tempfile
7+
import threading
8+
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
9+
from pathlib import Path
10+
from typing import AsyncIterator
11+
12+
from miservice import MiNAService
13+
14+
from xiaogpt.config import Config
15+
from xiaogpt.tts.base import TTS, logger
16+
from xiaogpt.utils import get_hostname
17+
18+
19+
class HTTPRequestHandler(SimpleHTTPRequestHandler):
    """Static-file request handler that reports through the TTS ``logger``.

    The stock handler writes its access and error lines to stderr; this
    subclass redirects both to the module logger instead.
    """

    def log_message(self, format, *args):
        """Log a regular request line at DEBUG level, prefixed by the client address."""
        logger.debug(f"{self.address_string()} - {format}", *args)

    def log_error(self, format, *args):
        """Log an error line at ERROR level, prefixed by the client address."""
        logger.error(f"{self.address_string()} - {format}", *args)

    def copyfile(self, source, outputfile):
        """Copy ``source`` to ``outputfile``, tolerating a dropped connection.

        A failure while sending the body is treated as best-effort: it is
        recorded at DEBUG level and not re-raised.
        """
        try:
            super().copyfile(source, outputfile)
        except OSError as exc:
            # ``socket.error`` is an alias of OSError, and ConnectionResetError /
            # BrokenPipeError are subclasses of it, so this catches exactly what
            # the previous three-item tuple did.
            # TODO: find out why the client drops the connection.
            logger.debug("Connection dropped while sending audio file: %s", exc)
32+
33+
34+
class TetosFileTTS(TTS):
    """A TTS model that generates audio files locally and plays them via URL."""

    def __init__(
        self, mina_service: MiNAService, device_id: str, config: Config
    ) -> None:
        """Build the tetos speaker, then start serving the audio directory.

        Raises:
            ValueError: if the speaker class rejects ``config.tts_options``
                with a ``TypeError`` (for example a missing required option).
        """
        from tetos import get_speaker

        super().__init__(mina_service, device_id, config)

        assert config.tts and config.tts != "mi"
        speaker_cls = get_speaker(config.tts)
        try:
            self.speaker = speaker_cls(**config.tts_options)
        except TypeError as e:
            raise ValueError(f"{e}. Please add them via `tts_options` config") from e

        # Create the directory and bind the port only once the speaker exists,
        # so an invalid configuration does not leave a listening server thread
        # and a temporary directory behind.
        self.dirname = tempfile.TemporaryDirectory(prefix="xiaogpt-tts-")
        self._start_http_server()

    async def make_audio_file(self, lang: str, text: str) -> tuple[Path, float]:
        """Synthesize ``text`` into a new ``.mp3`` file inside ``self.dirname``.

        Returns:
            The path of the generated file and the value returned by
            ``self.speaker.synthesize`` (used by the caller as the duration).
        """
        # Only the unique file name is needed; close the handle straight away
        # so that no file descriptor is leaked for every synthesized sentence.
        with tempfile.NamedTemporaryFile(
            suffix=".mp3", mode="wb", delete=False, dir=self.dirname.name
        ) as output_file:
            output_name = output_file.name
        duration = await self.speaker.synthesize(text, output_name, lang=lang)
        return Path(output_name), duration

    async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
        """Synthesize each text of ``text_stream`` and play the results in order.

        A background task turns texts into audio files while this coroutine
        plays the files already produced. An exception raised by the
        background task is re-raised here once the queued audio has been played.
        """
        # ``None`` marks the end of the stream.
        queue: "asyncio.Queue[tuple[str, float] | None]" = asyncio.Queue()

        async def worker() -> None:
            try:
                async for text in text_stream:
                    path, duration = await self.make_audio_file(lang, text)
                    url = f"http://{self.hostname}:{self.port}/{path.name}"
                    await queue.put((url, duration))
            finally:
                # Always signal the end, even on failure or cancellation;
                # otherwise the consumer below would wait forever.
                queue.put_nowait(None)

        task = asyncio.create_task(worker())

        while True:
            item = await queue.get()
            if item is None:
                break
            url, duration = item
            logger.debug("Playing URL %s (%s seconds)", url, duration)
            await asyncio.gather(
                self.mina_service.play_by_url(self.device_id, url, _type=1),
                self.wait_for_duration(duration),
            )
        # Propagates any exception raised inside the worker.
        await task

    def _start_http_server(self):
        """Serve ``self.dirname`` over HTTP from a daemon thread.

        The port is taken from the ``XIAOGPT_PORT`` environment variable when
        set, otherwise picked at random from 8050-8089. Sets ``self.port`` and
        ``self.hostname``.
        """
        # set the port range
        port_range = range(8050, 8090)
        # get a random port from the range
        self.port = int(os.getenv("XIAOGPT_PORT", random.choice(port_range)))
        # create the server
        handler = functools.partial(HTTPRequestHandler, directory=self.dirname.name)
        httpd = ThreadingHTTPServer(("", self.port), handler)
        # start the server in a new thread
        server_thread = threading.Thread(target=httpd.serve_forever)
        server_thread.daemon = True
        server_thread.start()

        self.hostname = get_hostname()
        logger.info(f"Serving on {self.hostname}:{self.port}")

‎xiaogpt/tts/live.py

+98
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,98 @@
1+
import asyncio
2+
import os
3+
import queue
4+
import random
5+
import threading
6+
import uuid
7+
from functools import lru_cache
8+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
9+
from typing import AsyncIterator
10+
11+
from miservice import MiNAService
12+
13+
from xiaogpt.config import Config
14+
from xiaogpt.tts.base import TTS, logger
15+
from xiaogpt.utils import get_hostname
16+
17+
18+
@lru_cache(maxsize=64)
def get_queue(key: str) -> queue.Queue[bytes]:
    """Return the byte queue registered under ``key``, creating it on first use.

    The ``lru_cache`` is what makes repeated calls with the same key hand back
    the same queue object; only the 64 most recently used keys are retained.
    """
    chunk_queue: queue.Queue[bytes] = queue.Queue()
    return chunk_queue
21+
22+
23+
class HTTPRequestHandler(BaseHTTPRequestHandler):
    """Stream the audio chunks queued under the requested key as ``audio/mpeg``."""

    def do_GET(self):
        """Answer a GET by forwarding queued chunks until the end marker.

        The last path segment selects the queue via ``get_queue``. Chunks are
        written to the client as they arrive; an empty ``bytes`` object ends
        the response. If the client disconnects mid-stream, streaming stops
        and the event is logged at DEBUG level.
        """
        self.send_response(200)
        self.send_header("Content-type", "audio/mpeg")
        self.end_headers()
        key = self.path.split("/")[-1]
        # Named ``chunks`` so the ``queue`` module is not shadowed.
        chunks = get_queue(key)
        try:
            while True:
                chunk = chunks.get()
                if chunk == b"":
                    break
                self.wfile.write(chunk)
        except ConnectionError as exc:
            # The player closed the connection (e.g. playback was stopped);
            # there is nobody left to stream to, so just stop quietly.
            logger.debug("Client disconnected while streaming %s: %s", key, exc)

    def log_message(self, format, *args):
        """Log a regular request line at DEBUG level, prefixed by the client address."""
        logger.debug(f"{self.address_string()} - {format}", *args)

    def log_error(self, format, *args):
        """Log an error line at ERROR level, prefixed by the client address."""
        logger.error(f"{self.address_string()} - {format}", *args)
41+
42+
43+
class TetosLiveTTS(TTS):
    """A TTS model that generates audio in real-time."""

    def __init__(
        self, mina_service: MiNAService, device_id: str, config: Config
    ) -> None:
        """Build the tetos speaker, check it can stream, then start the server.

        Raises:
            ValueError: if the speaker class rejects ``config.tts_options``
                with a ``TypeError``, or if the speaker has no ``live``
                attribute.
        """
        from tetos import get_speaker

        super().__init__(mina_service, device_id, config)

        assert config.tts and config.tts != "mi"
        speaker_cls = get_speaker(config.tts)
        try:
            self.speaker = speaker_cls(**config.tts_options)
        except TypeError as e:
            raise ValueError(f"{e}. Please add them via `tts_options` config") from e
        if not hasattr(self.speaker, "live"):
            raise ValueError(f"{config.tts} Speaker does not support live synthesis")

        # Bind the port only after validation succeeded, so a rejected
        # configuration does not leave a listening server thread behind.
        self._start_http_server()

    async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
        """Stream live-synthesized audio for ``text_stream`` to the device.

        A background task pushes the chunks yielded by ``self.speaker.live``
        into a queue identified by a fresh UUID; the device is asked to play
        the URL ending in that UUID. The coroutine returns once the device
        reports it is no longer playing and the background task has finished.
        An exception raised by the background task is re-raised here.
        """
        key = str(uuid.uuid4())
        chunks = get_queue(key)

        async def worker() -> None:
            try:
                async for chunk in self.speaker.live(text_stream, lang):
                    # An empty chunk is the end-of-stream marker on the HTTP
                    # side, so never forward one in the middle of the audio.
                    if chunk:
                        chunks.put(chunk)
            finally:
                # Always terminate the stream, even if synthesis failed;
                # otherwise the HTTP handler thread would block forever.
                chunks.put(b"")

        task = asyncio.create_task(worker())
        await self.mina_service.play_by_url(
            self.device_id, f"http://{self.hostname}:{self.port}/{key}", _type=1
        )

        while await self.get_if_xiaoai_is_playing():
            await asyncio.sleep(1)
        # Propagates any exception raised inside the worker.
        await task

    def _start_http_server(self):
        """Run the streaming HTTP server in a daemon thread.

        The port is taken from the ``XIAOGPT_PORT`` environment variable when
        set, otherwise picked at random from 8050-8089. Sets ``self.port`` and
        ``self.hostname``.
        """
        # set the port range
        port_range = range(8050, 8090)
        # get a random port from the range
        self.port = int(os.getenv("XIAOGPT_PORT", random.choice(port_range)))
        # create the server
        httpd = ThreadingHTTPServer(("", self.port), HTTPRequestHandler)
        # start the server in a new thread
        server_thread = threading.Thread(target=httpd.serve_forever)
        server_thread.daemon = True
        server_thread.start()

        self.hostname = get_hostname()
        logger.info(f"Serving on {self.hostname}:{self.port}")

‎xiaogpt/tts/tetos.py

-31
This file was deleted.

‎xiaogpt/xiaogpt.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,8 @@
2323
WAKEUP_KEYWORD,
2424
Config,
2525
)
26-
from xiaogpt.tts import TTS, MiTTS, TetosTTS
26+
from xiaogpt.tts import TTS, MiTTS, TetosFileTTS
27+
from xiaogpt.tts.live import TetosLiveTTS
2728
from xiaogpt.utils import detect_language, parse_cookie_string
2829

2930
EOF = object()
@@ -260,8 +261,10 @@ async def do_tts(self, value):
260261
def tts(self) -> TTS:
261262
if self.config.tts == "mi":
262263
return MiTTS(self.mina_service, self.device_id, self.config)
264+
elif self.config.tts == "fish":
265+
return TetosLiveTTS(self.mina_service, self.device_id, self.config)
263266
else:
264-
return TetosTTS(self.mina_service, self.device_id, self.config)
267+
return TetosFileTTS(self.mina_service, self.device_id, self.config)
265268

266269
async def wait_for_tts_finish(self):
267270
while True:

0 commit comments

Comments
 (0)
Please sign in to comment.