From e238c26b867abdccd724489e24bdca2d3fad3ff4 Mon Sep 17 00:00:00 2001
From: quic-shagun
Date: Fri, 11 Apr 2025 03:14:56 -0700
Subject: [PATCH] feat: Add option to pass QAICInferenceSession to TextGeneration

Signed-off-by: quic-shagun
---
 .../generation/text_generation_inference.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 2dd485a5e..f76437ab5 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -322,6 +322,8 @@ def cloud_ai_100_exec_kv(
     automation=False,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
+    session: Optional[QAICInferenceSession] = None,
+    print_latency_stats: bool = True,
 ):
     """
     This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +344,8 @@ def cloud_ai_100_exec_kv(
         :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
         :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
         :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
+        :session (QAICInferenceSession): Pre-initialized QAICInferenceSession; if provided, it is reused instead of creating a new session from ``qpc_path``. ``Defaults to None``.
+        :print_latency_stats (bool): If True, prints latency statistics after generation. ``Defaults to True``.
 
     Returns:
         :CloudAI100ExecInfo: Object holding execution output and performance details.
@@ -372,6 +376,7 @@ def cloud_ai_100_exec_kv(
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
         is_tlm=is_tlm,
+        session=session,
     )
     if full_batch_size is None:
         exec_info = [
@@ -395,8 +400,8 @@ def cloud_ai_100_exec_kv(
         exec_info = generate_text.generate(
             prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
         )
-
-    print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
+    if print_latency_stats:
+        print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
 
     return exec_info
 
@@ -411,13 +416,16 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: Optional[int] = None,
+        session: Optional[QAICInferenceSession] = None,
     ) -> None:
         self._ctx_len = ctx_len
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
 
         # Load QPC
-        self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)
+        self._session = (
+            session if session else QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)
+        )
 
         # Fetch the variables from the QPC
         self._vocab_size = self._fetch_vocab_size()  # Fetch Vocab size
@@ -905,9 +913,10 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: bool = False,
+        session: Optional[QAICInferenceSession] = None,
     ) -> None:
         self._qaic_model = QEffTextGenerationBase(
-            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
+            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm, session
         )
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
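
Usage sketch (illustrative, not part of the patch): with this change a caller can
build one QAICInferenceSession up front and reuse it across calls, instead of
letting cloud_ai_100_exec_kv construct a fresh session each time. The tokenizer
name, QPC path, and device ids below are placeholders; the cloud_infer import
path is an assumption about where QAICInferenceSession lives in this repo, and
the session is constructed positionally the same way this file does internally.

    from transformers import AutoTokenizer

    from QEfficient.generation.cloud_infer import QAICInferenceSession
    from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    qpc_path = "/path/to/compiled/qpc"  # placeholder path to the compiled binary

    # Create the session once (placeholder device id 0), mirroring how
    # QEffTextGenerationBase builds one when no session is passed in.
    session = QAICInferenceSession(qpc_path, [0])

    for prompt in ["My name is", "The capital of France is"]:
        exec_info = cloud_ai_100_exec_kv(
            tokenizer,
            qpc_path,
            prompt=prompt,
            device_id=[0],
            session=session,  # new: reuse the pre-initialized session
            print_latency_stats=False,  # new: silence the per-call latency printout
        )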