From 934ca8d70da05b0407b54cf1d0ec48b47004203a Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Fri, 22 Nov 2024 13:35:41 -0800 Subject: [PATCH 1/3] [PoC] Remove need for template values in configs, support building TRTLLM engine on model load if none found, add env vars for conveniently configuring engine and tokenizers from a single location --- .../ensemble/config.pbtxt | 2 +- .../postprocessing/1/model.py | 11 +- .../postprocessing/config.pbtxt | 8 +- .../preprocessing/1/model.py | 13 ++- .../preprocessing/config.pbtxt | 10 +- .../tensorrt_llm/1/model.py | 108 +++++++++++++++--- .../tensorrt_llm/config.pbtxt | 16 +-- .../tensorrt_llm_bls/config.pbtxt | 6 +- 8 files changed, 131 insertions(+), 43 deletions(-) diff --git a/all_models/inflight_batcher_llm/ensemble/config.pbtxt b/all_models/inflight_batcher_llm/ensemble/config.pbtxt index dd552480..c4863adc 100644 --- a/all_models/inflight_batcher_llm/ensemble/config.pbtxt +++ b/all_models/inflight_batcher_llm/ensemble/config.pbtxt @@ -26,7 +26,7 @@ name: "ensemble" platform: "ensemble" -max_batch_size: ${triton_max_batch_size} +max_batch_size: 256 input [ { name: "text_input" diff --git a/all_models/inflight_batcher_llm/postprocessing/1/model.py b/all_models/inflight_batcher_llm/postprocessing/1/model.py index b233f6c4..2c40df1c 100644 --- a/all_models/inflight_batcher_llm/postprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/postprocessing/1/model.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json +import os import numpy as np import triton_python_backend_utils as pb_utils @@ -53,8 +54,14 @@ def initialize(self, args): """ # Parse model configs model_config = json.loads(args['model_config']) - tokenizer_dir = model_config['parameters']['tokenizer_dir'][ - 'string_value'] + # Support tokenizer dir from env var for central location + tokenizer_dir = os.environ.get( + "TRTLLM_ENGINE_DIR", + model_config['parameters']['tokenizer_dir']['string_value']) + if not tokenizer_dir: + raise pb_utils.TritonModelException( + f"No tokenizer directory set. Please set TRTLLM_ENGINE_DIR env var or 'tokenizer_dir' config field to the directory containing engines and tokenizers." 
+ ) skip_special_tokens = model_config['parameters'].get( 'skip_special_tokens') diff --git a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt index 9f3655ef..babf34a7 100644 --- a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt +++ b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt @@ -26,7 +26,7 @@ name: "postprocessing" backend: "python" -max_batch_size: ${triton_max_batch_size} +max_batch_size: 256 dynamic_batching {} input [ { @@ -48,13 +48,15 @@ output [ } ] +# TODO: env var parameters { key: "tokenizer_dir" value: { - string_value: "${tokenizer_dir}" + string_value: "" } } +# TODO: Lookup how its filled today parameters { key: "skip_special_tokens" value: { @@ -64,7 +66,7 @@ parameters { instance_group [ { - count: ${postprocessing_instance_count} + count: 8 kind: KIND_CPU } ] diff --git a/all_models/inflight_batcher_llm/preprocessing/1/model.py b/all_models/inflight_batcher_llm/preprocessing/1/model.py index ec60537b..c2ca2f54 100755 --- a/all_models/inflight_batcher_llm/preprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/preprocessing/1/model.py @@ -55,8 +55,14 @@ def initialize(self, args): """ # Parse model configs model_config = json.loads(args['model_config']) - tokenizer_dir = model_config['parameters']['tokenizer_dir'][ - 'string_value'] + # Support tokenizer dir from env var for central location + tokenizer_dir = os.environ.get( + "TRTLLM_ENGINE_DIR", + model_config['parameters']['tokenizer_dir']['string_value']) + if not tokenizer_dir: + raise pb_utils.TritonModelException( + f"No tokenizer directory set. Please set TRTLLM_ENGINE_DIR env var or 'tokenizer_dir' config field to the directory containing engines and tokenizers." + ) add_special_tokens = model_config['parameters'].get( 'add_special_tokens') @@ -662,9 +668,8 @@ def __init__(self, import requests import torch from PIL import Image - from torch.utils.dlpack import from_dlpack - from tensorrt_llm._utils import str_dtype_to_torch + from torch.utils.dlpack import from_dlpack # create method for loading image from urls self.load_images_from_urls = lambda img_urls: [ diff --git a/all_models/inflight_batcher_llm/preprocessing/config.pbtxt b/all_models/inflight_batcher_llm/preprocessing/config.pbtxt index cdc04257..2c156e48 100755 --- a/all_models/inflight_batcher_llm/preprocessing/config.pbtxt +++ b/all_models/inflight_batcher_llm/preprocessing/config.pbtxt @@ -26,7 +26,7 @@ name: "preprocessing" backend: "python" -max_batch_size: ${triton_max_batch_size} +max_batch_size: 256 input [ { name: "QUERY" @@ -177,10 +177,11 @@ output [ } ] +# TODO: Use shared env var parameters { key: "tokenizer_dir" value: { - string_value: "${tokenizer_dir}" + string_value: "" } } @@ -198,10 +199,11 @@ parameters { } } +# TODO: Shared env var parameters: { key: "gpt_model_path" value: { - string_value: "${engine_dir}" + string_value: "" } } @@ -214,7 +216,7 @@ parameters: { instance_group [ { - count: ${preprocessing_instance_count} + count: 8 kind: KIND_CPU } ] diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py index 423dde6b..64ef2893 100755 --- a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py +++ b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py @@ -3,17 +3,19 @@ import os import sys import time +from pathlib import Path from random import randint from threading import Lock, Thread import numpy as np +import tensorrt_llm.bindings.executor as 
trtllm +import tensorrt_llm.logger as logger import torch import triton_python_backend_utils as pb_utils from torch import from_numpy from torch.utils.dlpack import from_dlpack -import tensorrt_llm.bindings.executor as trtllm -import tensorrt_llm.logger as logger +from tensorrt_llm import LLM, BuildConfig def mpi_comm(): @@ -321,7 +323,7 @@ def convert_request(request, exclude_input_from_output, decoupled): # if request doesn't specify exclude_input_from_output, try to use the parameter output_config.exclude_input_from_output = ( exclude_input_from_output - if exclude_input_from_output is not None else false) + if exclude_input_from_output is not None else False) else: output_config.exclude_input_from_output = req_exclude_input_from_output @@ -791,27 +793,25 @@ def initialize(self, args): * model_version: Model version * model_name: Model name """ - model_config = json.loads(args['model_config']) - gpt_model_path = get_parameter(model_config, "gpt_model_path") - if get_parameter(model_config, "enable_trt_overlap", bool): + self.model_config = json.loads(args['model_config']) + if get_parameter(self.model_config, "enable_trt_overlap", bool): raise pb_utils.TritonModelException( f"enable_trt_overlap=true is not supported.") self.exclude_input_from_output = get_parameter( - model_config, "exclude_input_in_output", bool) - executor_config = self.get_executor_config(model_config) - self.executor = trtllm.Executor(gpt_model_path, - trtllm.ModelType.DECODER_ONLY, - executor_config) + self.model_config, "exclude_input_in_output", bool) self.decoupled = pb_utils.using_decoupled_model_transaction_policy( - model_config) + self.model_config) self.cancellation_check_period_ms = get_parameter( - model_config, "cancellation_check_period_ms", int) or 100 + self.model_config, "cancellation_check_period_ms", int) or 100 self.stats_check_period_ms = get_parameter( - model_config, "stats_check_period_ms", int) or 100 + self.model_config, "stats_check_period_ms", int) or 100 + + # Setup and initialize executor + self.init_engine() self.create_metrics(args["model_name"], args["model_version"], - is_v1_model=executor_config.batching_type == + is_v1_model=self.executor_config.batching_type == trtllm.BatchingType.STATIC) self.triton_user_id_to_req_ids = {} self.triton_req_id_to_req_ids = {} @@ -830,6 +830,78 @@ def initialize(self, args): # In leader mode, worker ranks will wait here until leader is done. self.executor.shutdown() + def init_engine(self): + engine_dir: Path = self.get_engine_dir() + engines = [engine for engine in engine_dir.glob("*.engine")] + # Build engine if not found + if not engines: + pb_utils.Logger.log_info(f"No engine(s) found at {engine_dir}.") + model_id: str = os.environ.get("TRTLLM_MODEL") + if not model_id: + raise pb_utils.TritonModelException( + f"Could not build engine because no TRTLLM_MODEL was specified." + ) + + self.build_engine(model_id, engine_dir) + + self.load_engine(engine_dir) + + def get_engine_dir(self) -> Path: + engine_dir: Path = Path( + os.environ.get("TRTLLM_ENGINE_DIR", + get_parameter(self.model_config, "gpt_model_path"))) + if not engine_dir: + raise pb_utils.TritonModelException( + f"No engine directory set. Please set TRTLLM_ENGINE_DIR env var or 'gpt_model_path' config field to the directory containing engines and tokenizers." 
+ ) + + if not engine_dir.exists(): + pb_utils.Logger.log_info( + f"{engine_dir} does not exist, so it will be created.") + engine_dir.mkdir() + + if not engine_dir.is_dir(): + raise pb_utils.TritonModelException( + f"{engine_dir} is not a valid directory, please choose a valid directory." + ) + + return engine_dir + + def load_engine(self, engine_dir: Path): + self.executor_config = self.get_executor_config(self.model_config) + self.executor = trtllm.Executor(engine_dir, + trtllm.ModelType.DECODER_ONLY, + self.executor_config) + + def build_engine(self, model_id: str, engine_dir: Path): + """ + model_id: str + Local filepath to model weights or HuggingFace ID + """ + pb_utils.Logger.log_info(f"Building engine from {model_id}.") + + # TODO: Read from config,json if available + build_config = self.get_engine_build_config() + engine = LLM( + model_id, + build_config=build_config, + # TODO: Needed to avoid OOM here? + kv_cache_config=trtllm.KvCacheConfig(free_gpu_memory_fraction=0.1)) + pb_utils.Logger.log_info(f"Saving engine to {engine_dir}.") + engine.save(str(engine_dir)) + pb_utils.Logger.log_info(f"Saved engine to {engine_dir}.") + + def get_engine_build_config(self): + # NOTE: Given config.json, can read from 'build_config' section and from_dict + config = BuildConfig() + # TODO: Expose more build args to user + # TODO: Discuss LLM API BuildConfig defaults + # NOTE: Using some defaults from trtllm-build because LLM API defaults are too low + #config.max_input_len = 1024 + #config.max_seq_len = 8192 + #config.max_batch_size = 256 + return config + def handle_stop_request(self, triton_user_id, response_sender): if triton_user_id is None or triton_user_id == "": response_sender.send( @@ -877,9 +949,7 @@ def execute(self, requests): triton_req_ids = [] for request in requests: - triton_user_id = request.request_id() - response_sender = request.get_response_sender() stop = get_input_scalar_by_name(request, 'stop') @@ -899,9 +969,11 @@ def execute(self, requests): except Exception as e: response_sender.send( pb_utils.InferenceResponse(error=pb_utils.TritonError( - f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'" + f"An error occurred when processing the input values for request id {triton_user_id}, the error was '{e}'" )), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + # TODO: Remove + raise e else: for batch_index, converted_req in enumerate( converted_reqs): diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt index ace52600..9060ed09 100644 --- a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt +++ b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt @@ -25,17 +25,16 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
name: "tensorrt_llm" -backend: "${triton_backend}" -max_batch_size: ${triton_max_batch_size} +backend: "python" +max_batch_size: 256 model_transaction_policy { - decoupled: ${decoupled_mode} + decoupled: True } dynamic_batching { - preferred_batch_size: [ ${triton_max_batch_size} ] - max_queue_delay_microseconds: ${max_queue_delay_microseconds} - default_queue_policy: { max_queue_size: ${max_queue_size} } + max_queue_delay_microseconds: 100 + default_queue_policy: { max_queue_size: 0 } } input [ @@ -433,7 +432,7 @@ output [ instance_group [ { count: 1 - kind : KIND_CPU + kind : KIND_MODEL } ] parameters: { @@ -451,9 +450,10 @@ parameters: { parameters: { key: "gpt_model_type" value: { - string_value: "${batching_strategy}" + string_value: "inflight_fused_batching" } } +# TODO: env var parameters: { key: "gpt_model_path" value: { diff --git a/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt b/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt index 4d1bccdd..3e1b9899 100644 --- a/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt +++ b/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt @@ -26,10 +26,10 @@ name: "tensorrt_llm_bls" backend: "python" -max_batch_size: ${triton_max_batch_size} +max_batch_size: 256 model_transaction_policy { - decoupled: ${decoupled_mode} + decoupled: True } input [ @@ -330,7 +330,7 @@ parameters: { instance_group [ { - count: ${bls_instance_count} + count: 1 kind : KIND_CPU } ] From 4ad7815d3f9a54dd7b80e0d1dfa0ddfd6b492432 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Fri, 22 Nov 2024 14:54:47 -0800 Subject: [PATCH 2/3] Remove kvcache config, add TRTLLM_TOKENIZER env var support --- .../postprocessing/1/model.py | 38 ++++++++++---- .../preprocessing/1/model.py | 50 +++++++++++++------ .../tensorrt_llm/1/model.py | 7 +-- 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/all_models/inflight_batcher_llm/postprocessing/1/model.py b/all_models/inflight_batcher_llm/postprocessing/1/model.py index 2c40df1c..68a335de 100644 --- a/all_models/inflight_batcher_llm/postprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/postprocessing/1/model.py @@ -53,17 +53,11 @@ def initialize(self, args): * model_name: Model name """ # Parse model configs - model_config = json.loads(args['model_config']) + self.model_config = json.loads(args['model_config']) # Support tokenizer dir from env var for central location - tokenizer_dir = os.environ.get( - "TRTLLM_ENGINE_DIR", - model_config['parameters']['tokenizer_dir']['string_value']) - if not tokenizer_dir: - raise pb_utils.TritonModelException( - f"No tokenizer directory set. Please set TRTLLM_ENGINE_DIR env var or 'tokenizer_dir' config field to the directory containing engines and tokenizers." - ) + tokenizer_dir = self.get_tokenizer_dir() - skip_special_tokens = model_config['parameters'].get( + skip_special_tokens = self.model_config['parameters'].get( 'skip_special_tokens') if skip_special_tokens is not None: skip_special_tokens_str = skip_special_tokens[ @@ -94,12 +88,36 @@ def initialize(self, args): # Parse model output configs output_config = pb_utils.get_output_config_by_name( - model_config, "OUTPUT") + self.model_config, "OUTPUT") # Convert Triton types to numpy types self.output_dtype = pb_utils.triton_string_to_numpy( output_config['data_type']) + def get_tokenizer_dir(self): + # Manual override of tokenizer. This is to support common case/models + # when engine/tokenizer are downloaded on demand at model load time. 
+ tokenizer_dir = os.environ.get("TRTLLM_TOKENIZER") + + # If no override, use tokenizer co-located with engine + if not tokenizer_dir: + tokenizer_dir = os.environ.get("TRTLLM_ENGINE_DIR") + + # If no env var used at all, use tokenizer dir defined in config.pbtxt + # This is for backwards compatibility but is the most tedious to set + # and keep aligned in each location. + if not tokenizer_dir: + tokenizer_dir = self.model_config['parameters']['tokenizer_dir'][ + 'string_value'] + + # If no method of setting tokenizer worked, fail. + if not tokenizer_dir: + raise pb_utils.TritonModelException( + f"No tokenizer directory set. Please set TRTLLM_ENGINE_DIR env var or 'tokenizer_dir' config field to the directory containing engines and tokenizers." + ) + + return tokenizer_dir + def execute(self, requests): """`execute` must be implemented in every Python model. `execute` function receives a list of pb_utils.InferenceRequest as the only diff --git a/all_models/inflight_batcher_llm/preprocessing/1/model.py b/all_models/inflight_batcher_llm/preprocessing/1/model.py index c2ca2f54..92acfc7a 100755 --- a/all_models/inflight_batcher_llm/preprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/preprocessing/1/model.py @@ -54,21 +54,15 @@ def initialize(self, args): * model_name: Model name """ # Parse model configs - model_config = json.loads(args['model_config']) + self.model_config = json.loads(args['model_config']) # Support tokenizer dir from env var for central location - tokenizer_dir = os.environ.get( - "TRTLLM_ENGINE_DIR", - model_config['parameters']['tokenizer_dir']['string_value']) - if not tokenizer_dir: - raise pb_utils.TritonModelException( - f"No tokenizer directory set. Please set TRTLLM_ENGINE_DIR env var or 'tokenizer_dir' config field to the directory containing engines and tokenizers." - ) + tokenizer_dir = self.get_tokenizer_dir() - add_special_tokens = model_config['parameters'].get( + add_special_tokens = self.model_config['parameters'].get( 'add_special_tokens') - visual_model_path = model_config['parameters']['visual_model_path'][ - 'string_value'] - max_num_images = model_config['parameters'].get('max_num_images') + visual_model_path = self.model_config['parameters'][ + 'visual_model_path']['string_value'] + max_num_images = self.model_config['parameters'].get('max_num_images') if max_num_images is not None: max_num_images_str = max_num_images['string_value'] @@ -139,7 +133,7 @@ def initialize(self, args): 'llava', 'blip2-opt', 'vila', 'mllama' ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, vila and mllama. Got {self.model_type}." 
- llm_model_path = model_config['parameters']['gpt_model_path'][ + llm_model_path = self.model_config['parameters']['gpt_model_path'][ 'string_value'] llm_model_path = os.path.join(llm_model_path, 'config.json') with open(llm_model_path, 'r') as f: @@ -150,7 +144,7 @@ def initialize(self, args): self.vision_preprocessor = VisionPreProcessor( self.model_type, AutoProcessor.from_pretrained(tokenizer_dir), - model_config) + self.model_config) # Parse model output configs and convert Triton types to numpy types output_names = [ @@ -165,7 +159,7 @@ def initialize(self, args): input_name.lower() + "_dtype", pb_utils.triton_string_to_numpy( pb_utils.get_input_config_by_name( - model_config, input_name)['data_type'])) + self.model_config, input_name)['data_type'])) for output_name in output_names: setattr( @@ -173,7 +167,31 @@ def initialize(self, args): output_name.lower() + "_dtype", pb_utils.triton_string_to_numpy( pb_utils.get_output_config_by_name( - model_config, output_name)['data_type'])) + self.model_config, output_name)['data_type'])) + + def get_tokenizer_dir(self): + # Manual override of tokenizer. This is to support common case/models + # when engine/tokenizer are downloaded on demand at model load time. + tokenizer_dir = os.environ.get("TRTLLM_TOKENIZER") + + # If no override, use tokenizer co-located with engine + if not tokenizer_dir: + tokenizer_dir = os.environ.get("TRTLLM_ENGINE_DIR") + + # If no env var used at all, use tokenizer dir defined in config.pbtxt + # This is for backwards compatibility but is the most tedious to set + # and keep aligned in each location. + if not tokenizer_dir: + tokenizer_dir = self.model_config['parameters']['tokenizer_dir'][ + 'string_value'] + + # If no method of setting tokenizer worked, fail. + if not tokenizer_dir: + raise pb_utils.TritonModelException( + f"No tokenizer directory set. Please set TRTLLM_ENGINE_DIR env var or 'tokenizer_dir' config field to the directory containing engines and tokenizers." + ) + + return tokenizer_dir def _setup_ptable_shape(self, llm_model_config): max_prompt_embedding_table_size = llm_model_config['build_config'][ diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py index 64ef2893..cb3110c1 100755 --- a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py +++ b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py @@ -882,12 +882,7 @@ def build_engine(self, model_id: str, engine_dir: Path): # TODO: Read from config,json if available build_config = self.get_engine_build_config() - engine = LLM( - model_id, - build_config=build_config, - # TODO: Needed to avoid OOM here? 
-            kv_cache_config=trtllm.KvCacheConfig(free_gpu_memory_fraction=0.1))
-        pb_utils.Logger.log_info(f"Saving engine to {engine_dir}.")
+        engine = LLM(model_id, build_config=build_config)
         engine.save(str(engine_dir))
         pb_utils.Logger.log_info(f"Saved engine to {engine_dir}.")

From 1bea632cd36076c9903647d27e0cde6e22b1a211 Mon Sep 17 00:00:00 2001
From: Ryan McCormick
Date: Fri, 22 Nov 2024 15:36:15 -0800
Subject: [PATCH 3/3] Add placeholder for using engine build config from config.json, but comment it out because it can't be ingested directly

---
 .../tensorrt_llm/1/model.py | 28 +++++++++++++------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
index cb3110c1..73eebcdd 100755
--- a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
+++ b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
@@ -834,7 +834,10 @@ def init_engine(self):
         engine_dir: Path = self.get_engine_dir()
         engines = [engine for engine in engine_dir.glob("*.engine")]
         # Build engine if not found
-        if not engines:
+        if engines:
+            pb_utils.Logger.log_info(
+                f"Found existing engine(s) at {engine_dir}.")
+        else:
             pb_utils.Logger.log_info(f"No engine(s) found at {engine_dir}.")
             model_id: str = os.environ.get("TRTLLM_MODEL")
             if not model_id:
@@ -887,14 +890,23 @@ def build_engine(self, model_id: str, engine_dir: Path):
         pb_utils.Logger.log_info(f"Saved engine to {engine_dir}.")

     def get_engine_build_config(self):
-        # NOTE: Given config.json, can read from 'build_config' section and from_dict
+        # FIXME: Can't construct BuildConfig directly from **build_config
+        # If a config file exists with a build_config, use it.
+        #config_file = engine_dir / "config.json"
+        #if config_file.exists():
+        #    pb_utils.Logger.log_info(f"Found engine build config at {config_file}.")
+        #    with open(config_file) as f:
+        #        config_json = json.load(f)
+        #    build_config = config_json["build_config"]
+        #    pb_utils.Logger.log_info(f"Using build config: {build_config}")
+        #    config = BuildConfig(**build_config)
+        #else:
+        #    pb_utils.Logger.log_info(f"Using default build config.")
+        #    # Default config if no config file found
+        #    config = BuildConfig()
+
         config = BuildConfig()
-        # TODO: Expose more build args to user
-        # TODO: Discuss LLM API BuildConfig defaults
-        # NOTE: Using some defaults from trtllm-build because LLM API defaults are too low
-        #config.max_input_len = 1024
-        #config.max_seq_len = 8192
-        #config.max_batch_size = 256
+        pb_utils.Logger.log_info(f"Using default build config: {config}")
         return config

     def handle_stop_request(self, triton_user_id, response_sender):
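
Usage note (illustrative sketch, not part of the patches above): with these changes, the engine, tokenizer, and source model can be configured from a single location via environment variables instead of templating each config.pbtxt. The Python snippet below mirrors the lookup order implemented in get_tokenizer_dir() in the pre/postprocessing models; the paths and model ID shown are hypothetical examples, not values assumed by the backend.

    import os

    # Hypothetical example values; substitute your own paths / model ID.
    os.environ["TRTLLM_ENGINE_DIR"] = "/workspace/engines/my_model"    # engine dir; an engine is built here on model load if none is found
    os.environ["TRTLLM_MODEL"] = "/workspace/hf_models/my_model"       # local weights path or HuggingFace ID used to build an engine
    os.environ["TRTLLM_TOKENIZER"] = "/workspace/tokenizers/my_model"  # optional override if the tokenizer is not co-located with the engine

    def resolve_tokenizer_dir(config_tokenizer_dir: str = "") -> str:
        # Same precedence as get_tokenizer_dir():
        # TRTLLM_TOKENIZER > TRTLLM_ENGINE_DIR > 'tokenizer_dir' from config.pbtxt.
        tokenizer_dir = (os.environ.get("TRTLLM_TOKENIZER")
                         or os.environ.get("TRTLLM_ENGINE_DIR")
                         or config_tokenizer_dir)
        if not tokenizer_dir:
            raise ValueError("No tokenizer directory set.")
        return tokenizer_dir

The tensorrt_llm model resolves its engine directory the same way (TRTLLM_ENGINE_DIR, falling back to the gpt_model_path parameter) and, if that directory contains no *.engine files, builds one from TRTLLM_MODEL via the LLM API at load time.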