
Commit 6e893ed

moving the load_tensorrt_llm to dynamo/utils.py
1 parent 1e54bbf commit 6e893ed

File tree

3 files changed: +127 -126 lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py (-124)
@@ -3,9 +3,6 @@
 import functools
 import logging
 import os
-import shutil
-import subprocess
-import sys
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

 import numpy as np
@@ -16,7 +13,6 @@
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
-from torch_tensorrt._enums import Platform
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -1006,123 +1002,3 @@ def args_bounds_check(
     args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
 ) -> Any:
     return args[i] if len(args) > i and args[i] is not None else replacement
-
-
-def download_plugin_lib_path(py_version: str, platform: str) -> str:
-    plugin_lib_path = None
-
-    # Downloading TRT-LLM lib
-    # TODO: check how to fix the 0.18.0 hardcode below
-    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
-    file_name = f"tensorrt_llm-0.18.0-{py_version}-{py_version}-{platform}.whl"
-    download_url = base_url + file_name
-    cmd = ["wget", download_url]
-    if not (os.path.exists(file_name)):
-        try:
-            subprocess.run(cmd, check=True)
-            _LOGGER.debug("Download succeeded and TRT-LLM wheel is now present")
-        except subprocess.CalledProcessError as e:
-            _LOGGER.error(
-                "Download failed (file not found or connection issue). Error code:",
-                e.returncode,
-            )
-        except FileNotFoundError:
-            _LOGGER.error("wget is required but not found. Please install wget.")
-
-    # Proceeding with the unzip of the wheel file
-    # This will exist if the filename was already downloaded
-    if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
-        plugin_lib_path = "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
-    else:
-        try:
-            import zipfile
-        except:
-            raise ImportError(
-                "zipfile module is required but not found. Please install zipfile"
-            )
-        with zipfile.ZipFile(file_name, "r") as zip_ref:
-            zip_ref.extractall(".")  # Extract to a folder named 'tensorrt_llm'
-            plugin_lib_path = (
-                "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
-            )
-    return plugin_lib_path
-
-
-def load_tensorrt_llm() -> bool:
-    """
-    Attempts to load the TensorRT-LLM plugin and initialize it.
-    Either the env variable TRTLLM_PLUGINS_PATH can specify the path
-    Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
-
-    Returns:
-        bool: True if the plugin was successfully loaded and initialized, False otherwise.
-    """
-    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-    if not plugin_lib_path:
-        # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
-        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
-            "1",
-            "true",
-            "yes",
-            "on",
-        )
-        if not use_trtllm_plugin:
-            _LOGGER.warning(
-                "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
-            )
-            return False
-        else:
-            # this is used as the default py version
-            py_version = f"cp312"
-            platform = Platform.current_platform()
-
-            platform = str(platform).lower()
-            plugin_lib_path = download_plugin_lib_path(py_version, platform)
-
-    try:
-        # Load the shared TRT-LLM file
-        handle = ctypes.CDLL(plugin_lib_path)
-        _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
-    except OSError as e_os_error:
-        if "libmpi" in str(e_os_error):
-            _LOGGER.warning(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
-                f"The dependency libmpi.so is missing. "
-                f"Please install the packages libmpich-dev and libopenmpi-dev.",
-                exc_info=e_os_error,
-            )
-        else:
-            _LOGGER.warning(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                f"Ensure the path is correct and the library is compatible",
-                exc_info=e_os_error,
-            )
-        return False
-
-    try:
-        # Configure plugin initialization arguments
-        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-        handle.initTrtLlmPlugins.restype = ctypes.c_bool
-    except AttributeError as e_plugin_unavailable:
-        _LOGGER.warning(
-            "Unable to initialize the TensorRT-LLM plugin library",
-            exc_info=e_plugin_unavailable,
-        )
-        return False
-
-    try:
-        # Initialize the plugin
-        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
-            _LOGGER.info("TensorRT-LLM plugin successfully initialized")
-            return True
-        else:
-            _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
-            return False
-    except Exception as e_initialization_error:
-        _LOGGER.warning(
-            "Exception occurred during TensorRT-LLM plugin library initialization",
-            exc_info=e_initialization_error,
-        )
-        return False
-    return False
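
(The two helpers deleted above are re-added essentially verbatim to py/torch_tensorrt/dynamo/utils.py below; the only textual difference is that the conversion module's _LOGGER becomes that file's module-level logger.)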

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py (+1 -1)
@@ -11,11 +11,11 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
-from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
+from torch_tensorrt.dynamo.utils import load_tensorrt_llm

 _LOGGER: logging.Logger = logging.getLogger(__name__)

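The only change in this file is the import path: load_tensorrt_llm now comes from torch_tensorrt.dynamo.utils rather than torch_tensorrt.dynamo.conversion.converter_utils. A minimal sketch of the call site after this commit, assuming (as the imports above suggest, though the diff does not show it) that registration of the fused NCCL converters is gated on the plugin loading:

    # Sketch: the relocated import; the helper's behavior is unchanged.
    from torch_tensorrt.dynamo.utils import load_tensorrt_llm

    if load_tensorrt_llm():
        # Register converters that depend on TRT-LLM plugins, e.g. the
        # fused NCCL all-gather / reduce-scatter ops imported above.
        ...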
py/torch_tensorrt/dynamo/utils.py (+126 -1)
@@ -1,7 +1,12 @@
 from __future__ import annotations

+import ctypes
 import gc
 import logging
+import os
+import shutil
+import subprocess
+import sys
 import warnings
 from dataclasses import fields, replace
 from enum import Enum
@@ -14,7 +19,7 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import dtype
+from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
@@ -812,3 +817,123 @@ def is_tegra_platform() -> bool:
     if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
         return True
     return False
+
+
+def download_plugin_lib_path(py_version: str, platform: str) -> str:
+    plugin_lib_path = None
+
+    # Downloading TRT-LLM lib
+    # TODO: check how to fix the 0.18.0 hardcode below
+    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+    file_name = f"tensorrt_llm-0.18.0-{py_version}-{py_version}-{platform}.whl"
+    download_url = base_url + file_name
+    cmd = ["wget", download_url]
+    if not (os.path.exists(file_name)):
+        try:
+            subprocess.run(cmd, check=True)
+            logger.debug("Download succeeded and TRT-LLM wheel is now present")
+        except subprocess.CalledProcessError as e:
+            logger.error(
+                "Download failed (file not found or connection issue). Error code:",
+                e.returncode,
+            )
+        except FileNotFoundError:
+            logger.error("wget is required but not found. Please install wget.")
+
+    # Proceeding with the unzip of the wheel file
+    # This will exist if the filename was already downloaded
+    if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
+        plugin_lib_path = "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
+    else:
+        try:
+            import zipfile
+        except:
+            raise ImportError(
+                "zipfile module is required but not found. Please install zipfile"
+            )
+        with zipfile.ZipFile(file_name, "r") as zip_ref:
+            zip_ref.extractall(".")  # Extract to a folder named 'tensorrt_llm'
+            plugin_lib_path = (
+                "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
+            )
+    return plugin_lib_path
+
+
+def load_tensorrt_llm() -> bool:
+    """
+    Attempts to load the TensorRT-LLM plugin and initialize it.
+    Either the env variable TRTLLM_PLUGINS_PATH can specify the path
+    Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
+
+    Returns:
+        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+    """
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    if not plugin_lib_path:
+        # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
+            logger.warning(
+                "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
+            )
+            return False
+        else:
+            # this is used as the default py version
+            py_version = f"cp312"
+            platform = Platform.current_platform()
+
+            platform = str(platform).lower()
+            plugin_lib_path = download_plugin_lib_path(py_version, platform)
+
+    try:
+        # Load the shared TRT-LLM file
+        handle = ctypes.CDLL(plugin_lib_path)
+        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
+    except OSError as e_os_error:
+        if "libmpi" in str(e_os_error):
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"The dependency libmpi.so is missing. "
+                f"Please install the packages libmpich-dev and libopenmpi-dev.",
+                exc_info=e_os_error,
+            )
+        else:
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
+                f"Ensure the path is correct and the library is compatible",
+                exc_info=e_os_error,
+            )
+        return False
+
+    try:
+        # Configure plugin initialization arguments
+        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+        handle.initTrtLlmPlugins.restype = ctypes.c_bool
+    except AttributeError as e_plugin_unavailable:
+        logger.warning(
+            "Unable to initialize the TensorRT-LLM plugin library",
+            exc_info=e_plugin_unavailable,
+        )
+        return False
+
+    try:
+        # Initialize the plugin
+        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
+        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
+            logger.info("TensorRT-LLM plugin successfully initialized")
+            return True
+        else:
+            logger.warning("TensorRT-LLM plugin library failed in initialization")
+            return False
+    except Exception as e_initialization_error:
+        logger.warning(
+            "Exception occurred during TensorRT-LLM plugin library initialization",
+            exc_info=e_initialization_error,
+        )
+        return False
+    return False
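
As the docstring states, load_tensorrt_llm() resolves the plugin library from one of two environment variables: TRTLLM_PLUGINS_PATH (an explicit path to the shared library) or USE_TRTLLM_PLUGINS (opt-in download of the pinned TRT-LLM 0.18.0 wheel). A hedged usage sketch; the path in Option 1 is hypothetical, not a real install location:

    import os

    # Option 1: point at an already-installed plugin build (hypothetical path):
    # os.environ["TRTLLM_PLUGINS_PATH"] = "/opt/trtllm/libnvinfer_plugin_tensorrt_llm.so"

    # Option 2: opt in to downloading the TRT-LLM 0.18.0 wheel via wget:
    os.environ["USE_TRTLLM_PLUGINS"] = "1"  # "1", "true", "yes", or "on" all enable it

    from torch_tensorrt.dynamo.utils import load_tensorrt_llm

    if load_tensorrt_llm():
        print("TRT-LLM plugins loaded and initialized")
    else:
        print("TRT-LLM plugins unavailable; check the logged warnings")

Note that Option 2 shells out to wget and extracts the wheel into the current working directory, so it requires network access and write permission there.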
