feat(parler-tts): Add new backend
Signed-off-by: Ettore Di Giacinto <[email protected]>
mudler committed Apr 13, 2024
1 parent 4e74560 commit dfc152b
Showing 11 changed files with 280 additions and 4 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
@@ -15,7 +15,7 @@ ARG TARGETVARIANT

ENV BUILD_TYPE=${BUILD_TYPE}
ENV DEBIAN_FRONTEND=noninteractive
-ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"
+ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"

ARG GO_TAGS="stablediffusion tinydream tts"

@@ -275,6 +275,9 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
make -C backend/python/transformers-musicgen \
; fi
+RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
+make -C backend/python/parler-tts \
+; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
make -C backend/python/coqui \
; fi
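The EXTERNAL_GRPC_BACKENDS value above is a comma-separated list of name:launcher-script pairs, and this commit registers the new backend by appending one more pair. As a rough illustration of the format (LocalAI parses this variable in its own Go code; the Python sketch below is only a demonstration of the layout):

import os

# Illustrative only: decompose the comma-separated "name:script" pairs in
# EXTERNAL_GRPC_BACKENDS; LocalAI's real parser lives in its Go code.
spec = os.environ.get("EXTERNAL_GRPC_BACKENDS", "")
backends = dict(entry.split(":", 1) for entry in spec.split(",") if entry)

# After this commit the map gains:
# backends["parler-tts"] == "/build/backend/python/parler-tts/run.sh"
print(backends.get("parler-tts"))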
13 changes: 11 additions & 2 deletions Makefile
@@ -439,10 +439,10 @@ protogen-go-clean:
$(RM) bin/*

.PHONY: protogen-python
-protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen sentencetransformers-protogen transformers-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen
+protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen

.PHONY: protogen-python-clean
-protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean
+protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean

.PHONY: autogptq-protogen
autogptq-protogen:
@@ -524,6 +524,14 @@ transformers-protogen:
transformers-protogen-clean:
$(MAKE) -C backend/python/transformers protogen-clean

+.PHONY: parler-tts-protogen
+parler-tts-protogen:
+$(MAKE) -C backend/python/parler-tts protogen
+
+.PHONY: parler-tts-protogen-clean
+parler-tts-protogen-clean:
+$(MAKE) -C backend/python/parler-tts protogen-clean

.PHONY: transformers-musicgen-protogen
transformers-musicgen-protogen:
$(MAKE) -C backend/python/transformers-musicgen protogen
@@ -560,6 +568,7 @@ prepare-extra-conda-environments: protogen-python
$(MAKE) -C backend/python/sentencetransformers
$(MAKE) -C backend/python/transformers
$(MAKE) -C backend/python/transformers-musicgen
+$(MAKE) -C backend/python/parler-tts
$(MAKE) -C backend/python/vall-e-x
$(MAKE) -C backend/python/exllama
$(MAKE) -C backend/python/petals
2 changes: 2 additions & 0 deletions backend/python/common-env/transformers/transformers-nvidia.yml
@@ -120,4 +120,6 @@ dependencies:
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
+- descript-audio-codec
+- git+https://github.com/huggingface/parler-tts.git@10016fb0300c0dc31a0fb70e26f3affee7b62f16
prefix: /opt/conda/envs/transformers
2 changes: 2 additions & 0 deletions backend/python/common-env/transformers/transformers-rocm.yml
@@ -108,4 +108,6 @@ dependencies:
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
+- descript-audio-codec
+- git+https://github.com/huggingface/parler-tts.git@10016fb0300c0dc31a0fb70e26f3affee7b62f16
prefix: /opt/conda/envs/transformers
2 changes: 2 additions & 0 deletions backend/python/common-env/transformers/transformers.yml
@@ -112,4 +112,6 @@ dependencies:
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
+- descript-audio-codec
+- git+https://github.com/huggingface/parler-tts.git@10016fb0300c0dc31a0fb70e26f3affee7b62f16
prefix: /opt/conda/envs/transformers
25 changes: 25 additions & 0 deletions backend/python/parler-tts/Makefile
@@ -0,0 +1,25 @@
.PHONY: parler-tts
parler-tts: protogen
$(MAKE) -C ../common-env/transformers

.PHONY: run
run: protogen
@echo "Running parler-tts..."
bash run.sh
@echo "parler-tts run."

.PHONY: test
test: protogen
@echo "Testing parler-tts..."
bash test.sh
@echo "parler-tts tested."

.PHONY: protogen
protogen: backend_pb2_grpc.py backend_pb2.py

.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py

backend_pb2_grpc.py backend_pb2.py:
python3 -m grpc_tools.protoc -I../.. --python_out=. --grpc_python_out=. backend.proto
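For reference, the protogen rule regenerates the gRPC stubs from backend.proto two directories up. A sketch of the same codegen driven from Python (assuming the grpcio-tools package, which ships the grpc_tools.protoc module, is installed in the environment):

from grpc_tools import protoc  # provided by the grpcio-tools package

# Sketch of the Makefile's backend_pb2* rule: regenerate the message and
# stub modules from backend.proto.
protoc.main([
    "protoc",               # argv[0] placeholder, ignored by the compiler
    "-I../..",
    "--python_out=.",       # writes backend_pb2.py
    "--grpc_python_out=.",  # writes backend_pb2_grpc.py
    "backend.proto",
])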
125 changes: 125 additions & 0 deletions backend/python/parler-tts/parler_tts_server.py
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Extra gRPC server for ParlerTTSForConditionalGeneration models.
"""
from concurrent import futures

import argparse
import signal
import sys
import os

import time
import backend_pb2
import backend_pb2_grpc

import grpc

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import torch

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If PYTHON_GRPC_MAX_WORKERS is specified in the environment, use it; otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
A gRPC servicer for the backend service.
This class implements the gRPC methods for the backend service, including Health, LoadModel, and TTS.
"""
def Health(self, request, context):
"""
A gRPC method that returns the health status of the backend service.
Args:
request: A HealthRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Reply object that contains the health status of the backend service.
"""
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

def LoadModel(self, request, context):
"""
A gRPC method that loads a model into memory.
Args:
request: A LoadModelRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Result object that contains the result of the LoadModel operation.
"""
model_name = request.Model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
try:
self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

return backend_pb2.Result(message="Model loaded successfully", success=True)

def TTS(self, request, context):
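"""
A gRPC method that synthesizes speech for request.text, using request.voice as the
Parler-TTS description prompt, and writes the resulting WAV file to request.dst.
Args:
request: A TTSRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Result object that contains the result of the TTS operation.
"""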
model_name = request.model
voice = request.voice
if voice == "":
voice = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
if model_name == "":
return backend_pb2.Result(success=False, message="request.model is required")
try:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
input_ids = self.tokenizer(voice, return_tensors="pt").input_ids.to(device)
prompt_input_ids = self.tokenizer(request.text, return_tensors="pt").input_ids.to(device)

generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()
print("[parler-tts] TTS generated!", file=sys.stderr)
sf.write(request.dst, audio_arr, self.model.config.sampling_rate)
print("[parler-tts] TTS saved to", request.dst, file=sys.stderr)
print("[parler-tts] TTS for", file=sys.stderr)
print(request, file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)


def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print("[parler-tts] Server started. Listening on: " + address, file=sys.stderr)

# Define the signal handler function
def signal_handler(sig, frame):
print("[parler-tts] Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)

# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()
print(f"[parler-tts] startup: {args}", file=sys.stderr)
serve(args.addr)
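As a usage sketch, the server can be driven with the generated stubs. The model name mirrors the one used in the bundled tests, the request fields mirror those the TTS handler reads above, and the output path is hypothetical:

import grpc
import backend_pb2
import backend_pb2_grpc

# Minimal client sketch; assumes the generated backend_pb2* modules are
# importable and the server above is listening on localhost:50051.
with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)
    load = stub.LoadModel(backend_pb2.ModelOptions(Model="parler-tts/parler_tts_mini_v0.1"))
    assert load.success, load.message
    # voice is optional: the server substitutes its default description prompt
    result = stub.TTS(backend_pb2.TTSRequest(
        model="parler-tts/parler_tts_mini_v0.1",
        text="Hey, how are you doing today?",
        dst="/tmp/parler-tts-out.wav",  # hypothetical output path
    ))
    print(result.success, result.message)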
16 changes: 16 additions & 0 deletions backend/python/parler-tts/run.sh
@@ -0,0 +1,16 @@
#!/bin/bash

##
## A bash script wrapper that runs the parler-tts server with conda

echo "Launching gRPC server for parler-tts"

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python "$DIR/parler_tts_server.py" "$@"
11 changes: 11 additions & 0 deletions backend/python/parler-tts/test.sh
@@ -0,0 +1,11 @@
#!/bin/bash
##
## A bash script wrapper that runs the parler-tts tests with conda

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest "$DIR/test_parler.py"
81 changes: 81 additions & 0 deletions backend/python/parler-tts/test_parler.py
@@ -0,0 +1,81 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
"""
TestBackendServicer is the class that tests the gRPC service
"""
def setUp(self):
"""
This method sets up the gRPC service by starting the server
"""
self.service = subprocess.Popen(["python3", "parler_tts_server.py", "--addr", "localhost:50051"])
time.sleep(10)

def tearDown(self) -> None:
"""
This method tears down the gRPC service by terminating the server
"""
self.service.terminate()
self.service.wait()

def test_server_startup(self):
"""
This method tests if the server starts up successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()

def test_load_model(self):
"""
This method tests if the model is loaded successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="parler-tts/parler_tts_mini_v0.1"))
self.assertTrue(response.success)
self.assertEqual(response.message, "Model loaded successfully")
except Exception as err:
print(err)
self.fail("LoadModel service failed")
finally:
self.tearDown()

def test_tts(self):
"""
This method tests if TTS audio is generated successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="parler-tts/parler_tts_mini_v0.1"))
self.assertTrue(response.success)
tts_request = backend_pb2.TTSRequest(text="Hey, how are you doing today?")
tts_response = stub.TTS(tts_request)
self.assertIsNotNone(tts_response)
except Exception as err:
print(err)
self.fail("TTS service failed")
finally:
self.tearDown()
2 changes: 1 addition & 1 deletion backend/python/transformers-musicgen/run.sh
@@ -8,7 +8,7 @@ echo "Launching gRPC server for transformers-musicgen"
export PATH=$PATH:/opt/conda/bin

# Activate conda environment
-source activate transformers-musicgen
+source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
