-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: EthanD <[email protected]> Co-authored-by: EthanD <[email protected]>
- Loading branch information
1 parent
3671e55
commit 5ed8913
Showing
38 changed files
with
2,284 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Image that clones CosyVoice, installs its Python requirements, generates the
# gRPC stubs from cosyvoice.proto and adds the FastAPI server script.
FROM dockerhub.icu/pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /opt/CosyVoice

# Make /tmp world-writable while keeping the sticky bit (1777, not 777), switch
# apt to the USTC mirror, install the tools needed for the clone, and drop the
# apt package lists afterwards so they do not bloat the layer.
RUN chmod 1777 /tmp \
    && sed -i 's@//.*archive.ubuntu.com@//mirrors.ustc.edu.cn@g' /etc/apt/sources.list \
    && apt-get update -y \
    && apt-get -y install git unzip git-lfs \
    && rm -rf /var/lib/apt/lists/*
RUN git lfs install && git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# Python 3.10 is used here because no image providing both python3.8 and
# torch2.0.1-cu118 could be found.
# The trailing slash makes it explicit that the destination is the cloned directory.
COPY ./requirements.txt CosyVoice/
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
COPY fastapi/server.py CosyVoice/runtime/python/fastapi/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import argparse | ||
import logging | ||
import requests | ||
|
||
def saveResponse(path, response):
    """Write the raw body of an HTTP response to ``path``.

    Args:
        path: Destination file path; an existing file is overwritten.
        response: Object exposing the body as bytes via ``.content``.
    """
    # Binary mode: the body is written exactly as received.
    with open(path, 'wb') as file:
        file.write(response.content)
|
||
def main():
    """Send one synthesis request to the HTTP API and save the returned audio.

    Reads the module-level ``args`` namespace. ``args.mode`` selects the
    endpoint; any mode other than 'sft', 'zero_shot' or 'cross_lingual' is
    sent to the instruct endpoint. The response body is written to
    ``args.tts_wav``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status, so an
            error body is never saved as if it were audio.
    """
    api = args.api_base
    payload = {'tts': args.tts_text}
    needs_prompt_audio = False

    if args.mode == 'sft':
        url = api + "/api/inference/sft"
        payload['role'] = args.spk_id
    elif args.mode == 'zero_shot':
        url = api + "/api/inference/zero-shot"
        payload['prompt'] = args.prompt_text
        needs_prompt_audio = True
    elif args.mode == 'cross_lingual':
        url = api + "/api/inference/cross-lingual"
        needs_prompt_audio = True
    else:
        url = api + "/api/inference/instruct"
        payload['role'] = args.spk_id
        payload['instruct'] = args.instruct_text

    if needs_prompt_audio:
        # Keep the prompt file open only for the duration of the upload.
        with open(args.prompt_wav, 'rb') as prompt_file:
            files = [('audio', ('prompt_audio.wav', prompt_file, 'application/octet-stream'))]
            response = requests.request("POST", url, data=payload, files=files)
    else:
        response = requests.request("POST", url, data=payload)

    response.raise_for_status()
    saveResponse(args.tts_wav, response)
    # logging uses %-style lazy formatting; "{}" placeholders are not substituted.
    logging.info("Response save to %s", args.tts_wav)
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the request options into the module-level
    # ``args`` that main() reads, then issue the request.
    parser = argparse.ArgumentParser()
    parser.add_argument('--api_base',
                        type=str,
                        default='http://127.0.0.1:50000')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
                        help='request mode')
    parser.add_argument('--tts_text',
                        type=str,
                        default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
    parser.add_argument('--spk_id',
                        type=str,
                        default='中文男')
    parser.add_argument('--prompt_text',
                        type=str,
                        default='希望你以后能够做的比我还好呦。')
    parser.add_argument('--prompt_wav',
                        type=str,
                        default='../../../zero_shot_prompt.wav')
    parser.add_argument('--instruct_text',
                        type=str,
                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
    parser.add_argument('--tts_wav',
                        type=str,
                        default='loushiming.mp3')
    args = parser.parse_args()
    # Without a configured handler the root logger drops INFO records, so the
    # "saved" message from main() would never be shown.
    logging.basicConfig(level=logging.INFO)
    prompt_sr, target_sr = 16000, 22050
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
# Set inference model | ||
# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct | ||
# For development | ||
# fastapi dev --port 6006 fastapi_server.py | ||
# For production deployment | ||
# fastapi run --port 6006 fastapi_server.py | ||
|
||
import os | ||
import sys | ||
import io,time | ||
from fastapi import FastAPI, Request, Response, File, UploadFile, Form, Body | ||
from fastapi.responses import HTMLResponse | ||
from fastapi.middleware.cors import CORSMiddleware #引入 CORS中间件模块 | ||
from contextlib import asynccontextmanager | ||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.append('{}/../../..'.format(ROOT_DIR)) | ||
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) | ||
from cosyvoice.cli.cosyvoice import CosyVoice | ||
from cosyvoice.utils.file_utils import load_wav | ||
import numpy as np | ||
import torch | ||
import torchaudio | ||
import logging | ||
from pydantic import BaseModel | ||
logging.getLogger('matplotlib').setLevel(logging.WARNING) | ||
|
||
class LaunchFailed(Exception): | ||
pass | ||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the CosyVoice model once at application start-up.

    The model directory comes from the ``MODEL_DIR`` environment variable and
    falls back to ``pretrained_models/CosyVoice-300M-SFT``. The loaded model
    is stored on ``app.cosyvoice`` for the request handlers.

    Raises:
        LaunchFailed: If ``MODEL_DIR`` is set to an empty string.
    """
    model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
    if not model_dir:
        raise LaunchFailed("MODEL_DIR environment must set")
    # logging uses %-style lazy formatting; "{}" placeholders are not substituted.
    logging.info("MODEL_DIR is %s", model_dir)
    app.cosyvoice = CosyVoice(model_dir)
    # Speaker ids accepted by the SFT endpoints.
    logging.info("Avaliable speakers %s", app.cosyvoice.list_avaliable_spks())
    yield
|
||
app = FastAPI(lifespan=lifespan)

# Origins allowed to call the API cross-origin. "*" allows every origin; it can
# be replaced with a list of specific origins.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,   # origins permitted to make cross-origin requests
    allow_credentials=True,
    allow_methods=["*"],     # HTTP methods permitted cross-origin (GET, POST, PUT, ...)
    allow_headers=["*"])     # request headers permitted cross-origin
|
||
def buildResponse(output, sample_rate=22050):
    """Encode synthesized audio as MP3 and wrap it in an HTTP response.

    Args:
        output: Audio passed straight to ``torchaudio.save``
            (presumably a (channels, samples) tensor — TODO confirm).
        sample_rate: Sample rate handed to the encoder. Defaults to 22050,
            the value that was previously hard-coded.

    Returns:
        A ``Response`` carrying the MP3 bytes with media type ``audio/mpeg``.
    """
    buffer = io.BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="mp3")
    # getvalue() returns the whole buffer regardless of the stream position.
    return Response(content=buffer.getvalue(), media_type="audio/mpeg")
|
||
@app.post("/api/inference/sft")
@app.get("/api/inference/sft")
async def sft(tts: str = Form(), role: str = Form()):
    """Synthesize ``tts`` with the speaker ``role`` and return it as MP3.

    Args:
        tts: Text to synthesize (form field).
        role: Speaker id passed to ``inference_sft`` (form field).
    """
    # process_time() counts CPU time of this process, not wall-clock time.
    start = time.process_time()
    output = app.cosyvoice.inference_sft(tts, role)
    end = time.process_time()
    # logging uses %-style lazy formatting; "{}" placeholders are not substituted.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
|
||
class SpeechRequest(BaseModel):
    """JSON body accepted by the ``/v1/audio/speech`` endpoint."""

    # Required by the schema; the handler in this file does not read it.
    model: str
    # Text to synthesize.
    input: str
    # Speaker id forwarded to the SFT inference call.
    voice: str
|
||
@app.post("/v1/audio/speech")
async def sft(request: Request, speech_request: SpeechRequest):
    """Synthesize speech from a JSON body and return it as MP3.

    Uses ``speech_request.input`` as the text and ``speech_request.voice`` as
    the speaker id; ``speech_request.model`` and ``request`` are not used.

    NOTE(review): this function shares its name with the form-based
    ``/api/inference/sft`` handler, so the module-level name ``sft`` ends up
    bound to this one. Both routes are still registered by their decorators.
    """
    # process_time() counts CPU time of this process, not wall-clock time.
    start = time.process_time()
    # Read the validated fields directly instead of copying them into a dict.
    output = app.cosyvoice.inference_sft(speech_request.input, speech_request.voice)
    end = time.process_time()
    # logging uses %-style lazy formatting; "{}" placeholders are not substituted.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
|
||
@app.post("/api/inference/zero-shot")
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
    """Zero-shot synthesis: speak ``tts`` using an uploaded prompt recording.

    Args:
        tts: Text to synthesize (form field).
        prompt: Prompt text passed to ``inference_zero_shot`` (form field;
            presumably the transcript of ``audio`` — TODO confirm).
        audio: Uploaded prompt recording.
    """
    # process_time() counts CPU time of this process, not wall-clock time.
    start = time.process_time()
    # presumably load_wav resamples to 16 kHz — TODO confirm against its definition.
    prompt_speech = load_wav(audio.file, 16000)
    # Round-trip through int16 bytes and back to float, i.e. the prompt is
    # quantized to 16 bits before inference.
    prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2**15)

    output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
    end = time.process_time()
    # logging uses %-style lazy formatting; "{}" placeholders are not substituted.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
|
||
@app.post("/api/inference/cross-lingual")
async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
    """Cross-lingual synthesis: speak ``tts`` using an uploaded prompt recording.

    Args:
        tts: Text to synthesize (form field).
        audio: Uploaded prompt recording.
    """
    # process_time() counts CPU time of this process, not wall-clock time.
    start = time.process_time()
    # presumably load_wav resamples to 16 kHz — TODO confirm against its definition.
    prompt_speech = load_wav(audio.file, 16000)
    # Round-trip through int16 bytes and back to float, i.e. the prompt is
    # quantized to 16 bits before inference.
    prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2**15)

    output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
    end = time.process_time()
    # logging uses %-style lazy formatting; "{}" placeholders are not substituted.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
|
||
@app.post("/api/inference/instruct")
@app.get("/api/inference/instruct")
async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
    """Instruction-guided synthesis of ``tts`` with the speaker ``role``.

    Args:
        tts: Text to synthesize (form field).
        role: Speaker id passed to ``inference_instruct`` (form field).
        instruct: Instruction text passed to ``inference_instruct`` (form field).
    """
    # process_time() counts CPU time of this process, not wall-clock time.
    start = time.process_time()
    output = app.cosyvoice.inference_instruct(tts, role, instruct)
    end = time.process_time()
    # logging uses %-style lazy formatting; "{}" placeholders are not substituted.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
|
||
@app.get("/api/roles")
async def roles():
    """Return the speaker ids reported by the loaded model under the key ``roles``."""
    return {"roles": app.cosyvoice.list_avaliable_spks()}
|
||
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a minimal HTML landing page that links to the generated API docs."""
    return """
<!DOCTYPE html>
<html lang=zh-cn>
<head>
<meta charset=utf-8>
<title>Api information</title>
</head>
<body>
Get the supported tones from the Roles API first, then enter the tones and textual content in the TTS API for synthesis. <a href='./docs'>Documents of API</a>
</body>
</html>
"""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
import sys | ||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.append('{}/../../..'.format(ROOT_DIR)) | ||
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) | ||
import logging | ||
import argparse | ||
import torchaudio | ||
import cosyvoice_pb2 | ||
import cosyvoice_pb2_grpc | ||
import grpc | ||
import torch | ||
import numpy as np | ||
from cosyvoice.utils.file_utils import load_wav | ||
|
||
|
||
def main():
    """Send one inference request over gRPC and save the returned audio.

    Reads the module-level ``args`` namespace and ``target_sr``. ``args.mode``
    selects which member of the request oneof is filled; any mode other than
    'sft', 'zero_shot' or 'cross_lingual' builds an instruct request. The
    returned bytes are interpreted as int16 samples and written to
    ``args.tts_wav`` at ``target_sr``.
    """
    with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel:
        stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
        request = cosyvoice_pb2.Request()
        if args.mode == 'sft':
            logging.info('send sft request')
            sft_request = cosyvoice_pb2.sftRequest()
            sft_request.spk_id = args.spk_id
            sft_request.tts_text = args.tts_text
            request.sft_request.CopyFrom(sft_request)
        elif args.mode == 'zero_shot':
            logging.info('send zero_shot request')
            zero_shot_request = cosyvoice_pb2.zeroshotRequest()
            zero_shot_request.tts_text = args.tts_text
            zero_shot_request.prompt_text = args.prompt_text
            prompt_speech = load_wav(args.prompt_wav, 16000)
            # Prompt audio travels as raw int16 sample bytes.
            zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
            request.zero_shot_request.CopyFrom(zero_shot_request)
        elif args.mode == 'cross_lingual':
            logging.info('send cross_lingual request')
            cross_lingual_request = cosyvoice_pb2.crosslingualRequest()
            cross_lingual_request.tts_text = args.tts_text
            prompt_speech = load_wav(args.prompt_wav, 16000)
            # Prompt audio travels as raw int16 sample bytes.
            cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
            request.cross_lingual_request.CopyFrom(cross_lingual_request)
        else:
            logging.info('send instruct request')
            instruct_request = cosyvoice_pb2.instructRequest()
            instruct_request.tts_text = args.tts_text
            instruct_request.spk_id = args.spk_id
            instruct_request.instruct_text = args.instruct_text
            request.instruct_request.CopyFrom(instruct_request)

        response = stub.Inference(request)
        logging.info('save response to {}'.format(args.tts_wav))
        # np.array() copies the read-only frombuffer view before torch wraps it.
        tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
        torchaudio.save(args.tts_wav, tts_speech, target_sr)
        logging.info('get response')
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the request options into the module-level
    # ``args`` that main() reads, then issue the request.
    parser = argparse.ArgumentParser()
    # NOTE(review): 0.0.0.0 is a bind address; as a connect target it does not
    # work on every platform — confirm whether 127.0.0.1 was intended.
    parser.add_argument('--host',
                        type=str,
                        default='0.0.0.0')
    # The default is an int so it matches type=int without relying on argparse
    # converting a string default.
    parser.add_argument('--port',
                        type=int,
                        default=50000)
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
                        help='request mode')
    parser.add_argument('--tts_text',
                        type=str,
                        default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
    parser.add_argument('--spk_id',
                        type=str,
                        default='中文女')
    parser.add_argument('--prompt_text',
                        type=str,
                        default='希望你以后能够做的比我还好呦。')
    parser.add_argument('--prompt_wav',
                        type=str,
                        default='../../../zero_shot_prompt.wav')
    parser.add_argument('--instruct_text',
                        type=str,
                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
    parser.add_argument('--tts_wav',
                        type=str,
                        default='demo.wav')
    args = parser.parse_args()
    # prompt_sr is not read in this file; target_sr is the rate main() saves at.
    prompt_sr, target_sr = 16000, 22050
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
syntax = "proto3";

package cosyvoice;

option go_package = "protos/";

// Speech synthesis service.
//
// Message and oneof names are kept exactly as published because generated
// code (e.g. cosyvoice_pb2.sftRequest) is referenced by existing clients.
service CosyVoice {
  // Runs one synthesis request and returns the synthesized audio.
  rpc Inference(Request) returns (Response) {}
}

// Envelope for a synthesis request; at most one payload member is set.
message Request {
  oneof RequestPayload {
    // Synthesis with a named speaker.
    sftRequest sft_request = 1;
    // Synthesis guided by a prompt recording and its prompt text.
    zeroshotRequest zero_shot_request = 2;
    // Synthesis guided by a prompt recording only.
    crosslingualRequest cross_lingual_request = 3;
    // Synthesis with a named speaker plus an instruction text.
    instructRequest instruct_request = 4;
  }
}

// Named-speaker synthesis.
message sftRequest {
  // Speaker id.
  string spk_id = 1;
  // Text to synthesize.
  string tts_text = 2;
}

// Prompt-guided synthesis with prompt text.
message zeroshotRequest {
  // Text to synthesize.
  string tts_text = 1;
  // Text accompanying the prompt recording.
  string prompt_text = 2;
  // Prompt recording as raw 16-bit integer samples; the bundled client loads
  // it at 16000 Hz before encoding.
  bytes prompt_audio = 3;
}

// Prompt-guided synthesis without prompt text.
message crosslingualRequest {
  // Text to synthesize.
  string tts_text = 1;
  // Prompt recording as raw 16-bit integer samples; the bundled client loads
  // it at 16000 Hz before encoding.
  bytes prompt_audio = 2;
}

// Named-speaker synthesis steered by an instruction.
message instructRequest {
  // Text to synthesize.
  string tts_text = 1;
  // Speaker id.
  string spk_id = 2;
  // Instruction text describing how the speech should be delivered.
  string instruct_text = 3;
}

// Result of a synthesis request.
message Response {
  // Synthesized audio; the bundled client decodes it as 16-bit integer
  // samples and saves it at 22050 Hz.
  bytes tts_audio = 1;
}
Oops, something went wrong.