diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/README.md b/dockerfiles/triton/samples/test_models/speech_recognition/README.md
new file mode 100644
index 0000000..7177f99
--- /dev/null
+++ b/dockerfiles/triton/samples/test_models/speech_recognition/README.md
@@ -0,0 +1,38 @@
+## TRITON SERVER AND CLIENT FOR WHISPER ASR MODELS
+
+### INSTALLATION
+```
+pip install git+https://github.com/huggingface/optimum-habana.git
+```
+```
+export HF_HOME=/data/huggingface_cache
+export TRANSFORMERS_CACHE=$HF_HOME/models
+export HF_HUB_CACHE=$HF_HOME/hub
+export HF_TOKEN="huggingface token"
+```
+
+### STARTING THE TRITON SERVER
+```
+cd Setup_and_Install/dockerfiles/triton/samples/test_models/speech_recognition
+pip install -r requirements.txt
+tritonserver --model-repository model_repo --log-verbose=5
+```
+
+### RUNNING THE TRITON ASYNC HTTP CLIENT FOR A CONCURRENCY TEST
+```
+cd Setup_and_Install/dockerfiles/triton/samples/test_models/speech_recognition
+
+python simple_http_client_async.py <num_concurrent_reqs> <audio_file>
+
+Example: python simple_http_client_async.py 50 "sample_audio/en2.wav"
+```
+
+Notes:
+
+Batched inputs are supported. The maximum batch size is controlled by the `max_batch_size` field in `config.pbtxt`.
+
+Dynamic batching is supported; refer to the `dynamic_batching` section in `config.pbtxt`.
+
+The warmup logic is still being tested. A bucketing strategy similar to the one used in vllm-fork might yield better performance.
+
+Testing of further Triton-server-based performance optimisations is in progress.
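+### CHECKING SERVER READINESS
+
+Before running the client, you can sanity-check that the server and the model are up. The snippet below is a minimal sketch using the synchronous `tritonclient` HTTP API; it assumes the server is local and listening on Triton's default HTTP port (8000):
+```
+import tritonclient.http as httpclient
+
+# Assumes the default HTTP endpoint; adjust the URL if the server runs elsewhere.
+client = httpclient.InferenceServerClient(url="localhost:8000")
+print("server ready:", client.is_server_ready())
+print("model ready:", client.is_model_ready("whisper_large_v2"))
+```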
diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/en1.wav b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/en1.wav
new file mode 100644
index 0000000..67d6fd6
Binary files /dev/null and b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/en1.wav differ
diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/model.py b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/model.py
new file mode 100644
index 0000000..d7463f6
--- /dev/null
+++ b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/model.py
@@ -0,0 +1,198 @@
+import json
+import logging
+import os
+from time import time
+
+import numpy as np
+
+# triton_python_backend_utils is available in every Triton Python model. You
+# need to use this module to create inference requests and responses. It also
+# contains some utility functions for extracting information from model_config
+# and converting Triton input/output types to numpy types.
+import triton_python_backend_utils as pb_utils
+from utils import initialize_model_n_processor, read_audio
+
+# Extra kwargs forwarded to model.generate(); Whisper generation options
+# (e.g. language/task) could be set here.
+gen_kwargs = {}
+
+
+# Simple namespace holding the model and runtime settings.
+class habana_args:
+    device = "hpu"
+    model_name_or_path = "openai/whisper-large-v2"
+    audio_file = "en1.wav"
+    token = None
+    bf16 = True
+    fp8 = False  # fp8 is not supported yet; see initialize_model_n_processor in utils.py
+    use_hpu_graphs = True
+    seed = 42
+    batch_size = -1
+    model_revision = "main"
+    sampling_rate = 16000
+    global_rank = 0
+    world_size = 1
+    local_rank = 0
+    verbose_workers = False  # referenced by override_prints in utils.py
+
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+class TritonPythonModel:
+    """Your Python model must use the same class name. Every Python model
+    that is created must have "TritonPythonModel" as the class name.
+    """
+
+    def initialize(self, args):
+        """`initialize` is called only once when the model is being loaded.
+        Implementing `initialize` is optional. This function allows
+        the model to initialize any state associated with this model.
+
+        Parameters
+        ----------
+        args : dict
+          Both keys and values are strings. The dictionary keys and values are:
+          * model_config: A JSON string containing the model configuration
+          * model_instance_kind: A string containing model instance kind
+          * model_instance_device_id: A string containing model instance device ID
+          * model_repository: Model repository path
+          * model_version: Model version
+          * model_name: Model name
+        """
+        print("Initializing")
+        self.model, self.processor = initialize_model_n_processor(habana_args, logger)
+        self.device = self.model.device
+        self.model_dtype = self.model.dtype
+        self.sampling_rate = habana_args.sampling_rate
+
+        # Run a test sample during initialisation; this also triggers HPU graph
+        # compilation so the first real request is not penalised.
+        cur_dir = os.path.dirname(os.path.abspath(__file__))
+        input_speech_arr, sampling_rate = read_audio(
+            os.path.join(cur_dir, habana_args.audio_file)
+        )
+        for i in range(1):
+            t1 = time()
+            out_transcript = self.infer_transcript(
+                input_speech_arr, habana_args.sampling_rate
+            )
+            t2 = time()
+            print(f"Test inference time: {t2 - t1}secs {out_transcript}")
+
+        print("Initialize finished")
+        self.model_config = model_config = json.loads(args["model_config"])
+
+        # Get OUTPUT0 configuration
+        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0")
+
+        # Convert Triton types to numpy types
+        self.output0_dtype = pb_utils.triton_string_to_numpy(
+            output0_config["data_type"]
+        )
+
+    def infer_transcript(self, audio_batch, sampling_rate=16000):
+        t1 = time()
+        input_features = self.processor(
+            audio_batch, sampling_rate=sampling_rate, return_tensors="pt"
+        ).input_features.to(self.device)
+        predicted_ids = self.model.generate(
+            input_features.to(self.model_dtype), **gen_kwargs
+        )
+        transcription = self.processor.batch_decode(
+            predicted_ids, skip_special_tokens=True
+        )
+        t2 = time()
+        print(f"Time for {len(transcription)} samples: {t2 - t1}secs")
+        return transcription
+
+    def batched_inference(self, requests):
+        request_batch = []
+        for request in requests:
+            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
+            request_batch.append(in_0.as_numpy())
+
+        # NOTE: stacking assumes every request in the batch carries audio of
+        # the same length; mixed-length batches would need padding first.
+        request_batch = np.array(request_batch).squeeze()
+        print(
+            f"AUDIO BATCHED INPUT SIZE: {request_batch.shape} INPUT TYPE: {type(request_batch)}"
+        )
+
+        out_0 = self.infer_transcript(request_batch, habana_args.sampling_rate)
+
+        return out_0
+
+    def execute(self, requests):
+        """`execute` MUST be implemented in every Python model. `execute`
+        receives a list of pb_utils.InferenceRequest as its only
+        argument. This function is called when an inference request is made
+        for this model. Depending on the batching configuration (e.g. dynamic
+        batching) used, `requests` may contain multiple requests. Every
+        Python model must create one pb_utils.InferenceResponse for every
+        pb_utils.InferenceRequest in `requests`. If there is an error, you can
+        set the error argument when creating a pb_utils.InferenceResponse.
+
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse.
+          The length of this list must be the same as `requests`.
+        """
+
+        responses = []
+
+        print(f"NUM REQUESTS {len(requests)}")
+
+        if (
+            len(requests) > 1
+        ):  # More than one request received: batch them and infer at once
+            out_0_batched = self.batched_inference(requests)
+            responses = []
+            for i in range(len(requests)):
+                # Create OUTPUT tensors
+                out_tensor_0 = pb_utils.Tensor(
+                    "OUTPUT0", np.array(out_0_batched[i], dtype=self.output0_dtype)
+                )
+                inference_response = pb_utils.InferenceResponse(
+                    output_tensors=[out_tensor_0]
+                )
+                responses.append(inference_response)
+        else:  # Single-sample inference
+            # Every Python backend must iterate over every one of the requests
+            # and create a pb_utils.InferenceResponse for each of them.
+            for request in requests:
+                # Get INPUTS
+                in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
+                input_speech_arr = in_0.as_numpy()
+                print(
+                    f"AUDIO INPUT SIZE: {input_speech_arr.shape} INPUT TYPE: {type(input_speech_arr)}"
+                )
+
+                out_0 = self.infer_transcript(
+                    input_speech_arr, habana_args.sampling_rate
+                )
+
+                # Create OUTPUT tensors. You need pb_utils.Tensor objects to create pb_utils.InferenceResponse.
+                out_tensor_0 = pb_utils.Tensor(
+                    "OUTPUT0", np.array(out_0, dtype=self.output0_dtype)
+                )
+
+                # Create InferenceResponse.
+                # pb_utils.InferenceResponse(output_tensors=..., TritonError("An error occurred"))
+                inference_response = pb_utils.InferenceResponse(
+                    output_tensors=[out_tensor_0]
+                )
+                responses.append(inference_response)
+
+        # Return a list of pb_utils.InferenceResponse; its length must match that of `requests`.
+        return responses
+
+    def finalize(self):
+        """`finalize` is called only once when the model is being unloaded.
+        Implementing `finalize` is OPTIONAL. This function allows
+        the model to perform any necessary clean-ups before exit.
+ """ + print("Cleaning up...") diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/utils.py b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/utils.py new file mode 100644 index 0000000..ef919c7 --- /dev/null +++ b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/1/utils.py @@ -0,0 +1,146 @@ +import glob +import os +import shutil +import time + +import soundfile +import torch +from habana_frameworks.torch.hpu import wrap_in_hpu_graph +from optimum.habana.checkpoint_utils import get_repo_root +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +from optimum.habana.utils import check_optimum_habana_min_version, set_seed +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline +from transformers.utils import check_min_version + + +def read_audio(audio_file_path): + audio_array, sample_rate = soundfile.read(audio_file_path) + return audio_array, sample_rate + + +def override_print(enable): + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if force or enable: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def override_logger(logger, enable): + logger_info = logger.info + + def info(*args, **kwargs): + force = kwargs.pop("force", False) + if force or enable: + logger_info(*args, **kwargs) + + logger.info = info + + +def count_hpu_graphs(): + return len(glob.glob(".graph_dumps/*PreGraph*")) + + +def override_prints(enable, logger): + override_print(enable) + override_logger(logger, enable) + + +def setup_env(args): + """ + Need to test periodically if any breaking change is introduced in Optimum Habana , Transformers + Might work with lower versions as well but not tested + """ + check_min_version("4.45.2") + check_optimum_habana_min_version("1.14.0.dev0") + + if args.global_rank == 0: + os.environ.setdefault("GRAPH_VISUALIZATION", "true") + shutil.rmtree(".graph_dumps", ignore_errors=True) + + if args.world_size > 0: + os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0") + os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true") + + # Tweak generation so that it runs faster on Gaudi + adapt_transformers_to_gaudi() + + +def setup_device(args): + if args.device == "hpu": + import habana_frameworks.torch.core as htcore + return torch.device(args.device) + + +def setup_distributed_model(args, model_dtype, model_kwargs, logger): + """ + TO BE IMPLEMENTED + """ + raise Exception("Distributed model using Deepspeed yet to be implemented") + return + + +def load_model( + model_name_or_path, model_dtype, model_kwargs, logger, use_hpu_graphs=True +): + logger.info(f"Loading model : {model_name_or_path}") + + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_name_or_path, + torch_dtype=model_dtype, + low_cpu_mem_usage=False, + use_safetensors=True, + ) # , attn_implementation="sdpa", + # **model_kwargs) + model = model.eval().to("hpu") + if use_hpu_graphs: + model = wrap_in_hpu_graph(model) + return model + + +def load_processor(model_name_or_path, logger): + logger.info(f"Loading processor : {model_name_or_path}") + processor = AutoProcessor.from_pretrained(model_name_or_path) + return processor + + +def initialize_model_n_processor(args, logger): + init_start = time.perf_counter() + override_prints(args.global_rank == 0 or args.verbose_workers, logger) + setup_env(args) + setup_device(args) + 
+    set_seed(args.seed)
+    get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token)
+
+    if args.bf16:
+        model_dtype = torch.bfloat16
+    elif args.fp8:
+        raise NotImplementedError("fp8 precision is not supported yet. Please try bf16")
+    else:
+        model_dtype = torch.float
+        args.attn_softmax_bf16 = False
+
+    model_kwargs = {
+        "revision": args.model_revision,
+        "token": args.token,
+    }
+
+    model_name_or_path = args.model_name_or_path
+    model = load_model(
+        model_name_or_path, model_dtype, model_kwargs, logger, args.use_hpu_graphs
+    )
+    processor = load_processor(model_name_or_path, logger)
+
+    init_end = time.perf_counter()
+    logger.info(f"Args: {args}")
+    logger.info(
+        f"device: {args.device}, n_hpu: {args.world_size}, dtype: {model_dtype}"
+    )
+    logger.info(f"Model initialization took {(init_end - init_start):.3f}s")
+
+    return model, processor
diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/config.pbtxt b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/config.pbtxt
new file mode 100644
index 0000000..0b34d36
--- /dev/null
+++ b/dockerfiles/triton/samples/test_models/speech_recognition/model_repo/whisper_large_v2/config.pbtxt
@@ -0,0 +1,76 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
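+#
+# Notes on the serving configuration below (defaults chosen by this sample;
+# tune as needed):
+# - max_batch_size caps how many requests Triton will fuse into one batch.
+# - dynamic_batching.preferred_batch_size asks the scheduler to aim for
+#   batches of 12; max_queue_delay_microseconds bounds how long a request may
+#   wait for batch-mates (5 us here, i.e. effectively no waiting).
+# - model_warmup sends zero-filled FP32 inputs at batch sizes 1 and 12 when
+#   the model loads, so graphs are compiled before real traffic arrives.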
+
+name: "whisper_large_v2"
+backend: "python"
+max_batch_size: 12
+input [
+  {
+    name: "INPUT0"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+output [
+  {
+    name: "OUTPUT0"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+
+# KIND_CPU refers to the Python backend host process; the model itself is
+# placed on HPU inside model.py.
+instance_group [{ kind: KIND_CPU }]
+
+dynamic_batching {
+  preferred_batch_size: [12]
+  max_queue_delay_microseconds: 5
+}
+
+model_warmup [
+  {
+    batch_size: 1
+    inputs {
+      key: "INPUT0"
+      value: {
+        data_type: TYPE_FP32
+        dims: [ 128 ]
+        zero_data: true
+      }
+    }
+  },
+  {
+    batch_size: 12,
+    inputs: {
+      key: "INPUT0"
+      value: {
+        data_type: TYPE_FP32
+        dims: [128]
+        zero_data: true
+      }
+    }
+  }
+]
diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/requirements.txt b/dockerfiles/triton/samples/test_models/speech_recognition/requirements.txt
new file mode 100644
index 0000000..0b3ef7e
--- /dev/null
+++ b/dockerfiles/triton/samples/test_models/speech_recognition/requirements.txt
@@ -0,0 +1,3 @@
+soundfile
+librosa
+tritonclient[http]==2.41.0
\ No newline at end of file
diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/sample_audio/en2.wav b/dockerfiles/triton/samples/test_models/speech_recognition/sample_audio/en2.wav
new file mode 100644
index 0000000..f16a2f0
Binary files /dev/null and b/dockerfiles/triton/samples/test_models/speech_recognition/sample_audio/en2.wav differ
diff --git a/dockerfiles/triton/samples/test_models/speech_recognition/simple_http_client_async.py b/dockerfiles/triton/samples/test_models/speech_recognition/simple_http_client_async.py
new file mode 100644
index 0000000..81ae94b
--- /dev/null
+++ b/dockerfiles/triton/samples/test_models/speech_recognition/simple_http_client_async.py
@@ -0,0 +1,76 @@
+import asyncio
+import os
+import sys
+import time
+
+import numpy as np
+import soundfile
+import tritonclient.http.aio as httpclient
+from tritonclient.utils import np_to_triton_dtype
+
+http_port = 8000
+model_name = "whisper_large_v2"
+timeout = 10000  # connection timeout for the async client, in seconds
+
+
+def read_audio(audio_file_path):
+    audio_array, sample_rate = soundfile.read(audio_file_path)
+    return audio_array, sample_rate
+
+
+def create_random_audio_arr(duration_sec=30, sampling_rate=16000):
+    total_samples = duration_sec * sampling_rate
+    audio_array = np.random.rand(total_samples)
+    return audio_array, sampling_rate
+
+
+async def infer_http_async(audio_file):
+    async with httpclient.InferenceServerClient(
+        url=f"localhost:{http_port}", conn_timeout=timeout
+    ) as client:
+        audio_arr, sampling_rate = read_audio(audio_file)
+        # audio_arr, sampling_rate = create_random_audio_arr(duration_sec=12, sampling_rate=16000)
+        audio_arr = audio_arr.astype("float32").reshape(1, -1)
+        # Min-max normalise the waveform to [0, 1]
+        audio_arr = (audio_arr - audio_arr.min()) / (audio_arr.max() - audio_arr.min())
+
+        ## INPUT0
+        input_audio_arr = httpclient.InferInput(
+            "INPUT0", audio_arr.shape, np_to_triton_dtype(audio_arr.dtype)
+        )
+        input_audio_arr.set_data_from_numpy(audio_arr)
+
+        ## OUTPUT
+        output_text = httpclient.InferRequestedOutput("OUTPUT0")
+
+        query_response = await client.infer(
+            model_name=model_name, inputs=[input_audio_arr], outputs=[output_text]
+        )
+
+        print(query_response.as_numpy("OUTPUT0"))
+
+
+async def infer_http_concurrent(audio_file, num_concurrent_reqs=1):
+    print(f"\n\n ============= Running {audio_file} {num_concurrent_reqs} times ...")
+    tasks = []
+    for i in range(num_concurrent_reqs):
+        print(f"Run {i+1}")
+        tasks.append(infer_http_async(audio_file))
+    start = time.time()
+    await asyncio.gather(*tasks)
+    end = time.time()
+    print(f"Time taken: {end - start} secs\n")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print("Usage: python simple_http_client_async.py <num_concurrent_reqs> <audio_file>")
+        sys.exit(1)
+
+    num_concurrent_reqs = int(sys.argv[1])
+    audio_file = sys.argv[2]  # e.g. "sample_audio/en2.wav"
+
+    t1 = time.time()
+    asyncio.run(infer_http_concurrent(audio_file, num_concurrent_reqs))
+    print(f"Total time taken by {os.getpid()}: {time.time() - t1} secs \n")
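+
+# Tip: the first timed run after server start can be skewed by one-time graph
+# compilation on the server side (model.py already runs one test sample at
+# initialisation, but cold caches can still add latency). For steadier
+# numbers, a single untimed warmup request before the measured batch is a
+# reasonable pattern, e.g.:
+#
+#   asyncio.run(infer_http_concurrent(audio_file, 1))   # warmup, untimed
+#   t1 = time.time()
+#   asyncio.run(infer_http_concurrent(audio_file, num_concurrent_reqs))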