diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 9a04aba6de..52a9b47c12 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -281,6 +281,16 @@ void TRTEngine::enable_profiling() {
   exec_ctx->setProfiler(trt_engine_profiler.get());
 }
 
+void TRTEngine::set_profile_format(std::string format) {
+  if (format == "trex") {
+    this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
+  } else if (format == "perfetto") {
+    this->trt_engine_profiler->set_profile_format(TraceFormat::kPERFETTO);
+  } else {
+    TORCHTRT_THROW_ERROR("Invalid profile format: " + format);
+  }
+}
+
 std::string TRTEngine::get_engine_layer_info() {
   auto inspector = cuda_engine->createEngineInspector();
   return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
@@ -315,7 +325,7 @@ void TRTEngine::set_profiling_paths() {
   output_profile_path = std::filesystem::path{profile_path_prefix + "/" + name + "_output_profile.trace"}.string();
   enqueue_profile_path = std::filesystem::path{profile_path_prefix + "/" + name + "_enqueue_profile.trace"}.string();
   trt_engine_profile_path =
-      std::filesystem::path{profile_path_prefix + "/" + name + "_engine_exectuion_profile.trace"}.string();
+      std::filesystem::path{profile_path_prefix + "/" + name + "_engine_execution_profile.trace"}.string();
   cuda_graph_debug_path = std::filesystem::path{profile_path_prefix + "/" + name + "_cudagraph.dot"}.string();
 }
 
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
index 2db640b6b1..15d723ce4e 100644
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -147,6 +147,7 @@ struct TRTEngine : torch::CustomClassHolder {
   std::string to_str() const;
   static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
   void enable_profiling();
+  void set_profile_format(std::string profile_format);
   void disable_profiling();
   std::string get_engine_layer_info();
 
diff --git a/core/runtime/TRTEngineProfiler.cpp b/core/runtime/TRTEngineProfiler.cpp
index 8f7f0ac4e9..7d85ba82db 100644
--- a/core/runtime/TRTEngineProfiler.cpp
+++ b/core/runtime/TRTEngineProfiler.cpp
@@ -32,25 +32,40 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector<
   }
 }
 
+void TRTEngineProfiler::set_profile_format(TraceFormat format) {
+  this->profile_format = format;
+}
+
 void dump_trace(const std::string& path, const TRTEngineProfiler& value) {
   std::stringstream out;
   out << "[" << std::endl;
   double ts = 0.0;
+  double running_time = 0.0;
+  for (size_t i = 0; i < value.layer_names.size(); i++) {
+    auto layer_name = value.layer_names[i];
+    auto elem = value.profile.at(layer_name);
+    ts += elem.time;
+  }
   for (size_t i = 0; i < value.layer_names.size(); i++) {
     auto layer_name = value.layer_names[i];
     auto elem = value.profile.at(layer_name);
     out << "  {" << std::endl;
     out << "    \"name\": \"" << layer_name << "\"," << std::endl;
-    out << "    \"ph\": \"X\"," << std::endl;
-    out << "    \"ts\": " << ts * 1000 << "," << std::endl;
-    out << "    \"dur\": " << elem.time * 1000 << "," << std::endl;
-    out << "    \"tid\": 1," << std::endl;
-    out << "    \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
-    out << "    \"args\": {}" << std::endl;
+    if (value.profile_format == TraceFormat::kPERFETTO) {
+      out << "    \"ph\": \"X\"," << std::endl;
+      out << "    \"ts\": " << running_time * 1000 << "," << std::endl;
+      out << "    \"dur\": " << elem.time * 1000 << "," << std::endl;
+      out << "    \"tid\": 1," << std::endl;
+      out << "    \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
+      out << "    \"args\": {}" << std::endl;
+    } else { // kTREX
+      out << "    \"timeMs\": " << elem.time << "," << std::endl;
+      out << "    \"averageMs\": " << elem.time / elem.count << "," << std::endl;
+      out << "    \"percentage\": " << (elem.time * 100.0 / ts) << std::endl;
+    }
     out << "  }," << std::endl;
-
-    ts += elem.time;
+    running_time += elem.time;
   }
   out.seekp(-2, out.cur);
   out << "\n]" << std::endl;
diff --git a/core/runtime/TRTEngineProfiler.h b/core/runtime/TRTEngineProfiler.h
index 34a901165b..6691f2e81d 100644
--- a/core/runtime/TRTEngineProfiler.h
+++ b/core/runtime/TRTEngineProfiler.h
@@ -10,12 +10,14 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
+enum TraceFormat { kPERFETTO, kTREX };
+
 struct TRTEngineProfiler : public nvinfer1::IProfiler {
   struct Record {
     float time{0};
     int count{0};
   };
-
+  void set_profile_format(TraceFormat format);
   virtual void reportLayerTime(const char* layerName, float ms) noexcept;
   TRTEngineProfiler(
       const std::string& name,
@@ -27,6 +29,7 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler {
   std::string name;
   std::vector<std::string> layer_names;
   std::map<std::string, Record> profile;
+  TraceFormat profile_format = TraceFormat::kPERFETTO;
 };
 
 } // namespace runtime
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index cbe19b0af6..173ff8c35f 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -82,6 +82,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("__repr__", &TRTEngine::to_str)
         .def("__obj_flatten__", &TRTEngine::__obj_flatten__)
         .def("enable_profiling", &TRTEngine::enable_profiling)
+        .def("set_profile_format", &TRTEngine::set_profile_format)
         .def("disable_profiling", &TRTEngine::disable_profiling)
         .def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
         .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index 6fabdad633..607dca76bf 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -19,3 +19,4 @@
 from ._settings import CompilationSettings
 from ._SourceIR import SourceIR
 from ._tracer import trace
+from .debug._Debugger import Debugger
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index e14a449aed..d7092f1e0f 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -32,6 +33,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -43,7 +46,6 @@
     get_output_metadata,
     parse_graph_io,
     prepare_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -66,7 +68,6 @@ def cross_compile_for_windows(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -140,7 +141,6 @@ def cross_compile_for_windows(
         assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
-        debug (bool): Enable debuggable engine
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
         workspace_size (int): Maximum size of workspace given to TensorRT
@@ -187,8 +187,12 @@ def cross_compile_for_windows(
             f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
         )
 
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+    if kwargs.get("debug", False):
+        warnings.warn(
+            "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -299,7 +303,6 @@ def cross_compile_for_windows(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -401,7 +404,6 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -477,7 +479,6 @@ def compile(
         assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
-        debug (bool): Enable debuggable engine
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
         workspace_size (int): Maximum size of workspace given to TensorRT
@@ -520,8 +521,13 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+    if kwargs.get("debug", False):
+        warnings.warn(
+            "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -643,7 +649,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -718,12 +723,15 @@ def compile(
     return trt_gm
 
 
+@fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,
+    *,
+    _debugger_config: Optional[DebuggerConfig] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
 
@@ -747,7 +755,7 @@ def compile_module(
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        gm, settings.debug, settings.torch_executed_ops
+        gm, settings.torch_executed_ops
     )
 
     dryrun_tracker.total_ops_in_graph = total_ops
@@ -799,7 +807,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             logger.info("Partitioning the graph via the fast partitioner")
             partitioned_module, supported_ops = partitioning.fast_partition(
                 gm,
-                verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
@@ -820,7 +827,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         logger.info("Partitioning the graph via the global partitioner")
         partitioned_module, supported_ops = partitioning.global_partition(
             gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,
@@ -928,6 +934,41 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         trt_modules[name] = trt_module
 
+    if _debugger_config:
+
+        if _debugger_config.save_engine_profile:
+            if settings.use_python_runtime:
+                if _debugger_config.profile_format != "cudagraph":
+                    raise ValueError(
+                        "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only supports cudagraph visualization."
+                    )
+                else:
+                    trt_module.enable_profiling()
+            else:
+                if _debugger_config.profile_format == "cudagraph":
+                    raise ValueError(
+                        "Profiling with cudagraph can only be enabled when using the Python runtime. C++ runtime profiling only supports TREX/Perfetto visualization."
+                    )
+                else:
+                    path = os.path.join(
+                        _debugger_config.logging_dir,
+                        "engine_visualization_profile",
+                    )
+                    os.makedirs(path, exist_ok=True)
+                    trt_module.enable_profiling(
+                        profiling_results_dir=path,
+                        profile_format=_debugger_config.profile_format,
+                    )
+
+        if _debugger_config.save_layer_info:
+            with open(
+                os.path.join(
+                    _debugger_config.logging_dir, "engine_layer_info.json"
+                ),
+                "w",
+            ) as f:
+                f.write(trt_module.get_layer_info())
+
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
 
@@ -955,7 +996,6 @@ def convert_exported_program_to_serialized_trt_engine(
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
-    debug: bool = _defaults.DEBUG,
    assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1017,7 +1057,6 @@ def convert_exported_program_to_serialized_trt_engine(
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]
         enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
-        debug (bool): Whether to print out verbose debugging information
         workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
         min_block_size (int): Minimum number of operators per TRT-Engine Block
         torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
@@ -1057,8 +1096,12 @@ def convert_exported_program_to_serialized_trt_engine(
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+    if kwargs.get("debug", False):
+        warnings.warn(
+            "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -1142,7 +1185,6 @@ def convert_exported_program_to_serialized_trt_engine(
     compilation_options = {
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "enabled_precisions": enabled_precisions,
-        "debug": debug,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
         "torch_executed_ops": torch_executed_ops,
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 921cb37646..9863b00776 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -1,4 +1,5 @@
 import os
+import pwd
 import tempfile
 
 import torch
@@ -6,7 +7,6 @@
 from torch_tensorrt._enums import EngineCapability, dtype
 
 ENABLED_PRECISIONS = {dtype.f32}
-DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
 ASSUME_DYNAMIC_SHAPE_SUPPORT = False
@@ -57,6 +57,9 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
+DEBUG_LOGGING_DIR = os.path.join(
+    tempfile.gettempdir(), pwd.getpwuid(os.getuid())[0], "torch_tensorrt/debug_logs"
+)
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 15136a5170..7cf19e870e 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -42,7 +42,6 @@
     deallocate_module,
     get_model_device,
     get_torch_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -75,7 +74,6 @@ def construct_refit_mapping(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
     )
@@ -269,9 +267,6 @@ def refit_module_weights(
         not settings.immutable_weights
     ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False."
 
-    if settings.debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_tensorrt_device(settings.device)
     if arg_inputs:
         if not isinstance(arg_inputs, collections.abc.Sequence):
@@ -327,7 +322,6 @@ def refit_module_weights(
             logger.info("Partitioning the graph via the fast partitioner")
             new_partitioned_module, supported_ops = partitioning.fast_partition(
                 new_gm,
-                verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
@@ -347,7 +341,6 @@ def refit_module_weights(
             logger.info("Partitioning the graph via the global partitioner")
             new_partitioned_module, supported_ops = partitioning.global_partition(
                 new_gm,
-                verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 97c02f34fb..7ac77cccae 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -7,7 +7,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -101,7 +100,6 @@ class CompilationSettings:
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
index 78f7989777..5f4bdd0a8d 100644
--- a/py/torch_tensorrt/dynamo/_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -7,8 +7,8 @@
 import torch
 from torch.export import Dim, export
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import DEBUG, default_device
-from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
+from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
 
 logger = logging.getLogger(__name__)
 
@@ -70,10 +70,6 @@ def trace(
     if kwarg_inputs is None:
         kwarg_inputs = {}
 
-    debug = kwargs.get("debug", DEBUG)
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)
     torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index cef00f3a2a..ae8af28348 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -45,6 +45,8 @@
     get_trt_tensor,
     to_torch,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
@@ -70,21 +72,23 @@ class TRTInterpreterResult(NamedTuple):
     requires_output_allocator: bool
 
 
+@cls_supports_debugger
 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
     def __init__(
         self,
         module: torch.fx.GraphModule,
         input_specs: Sequence[Input],
-        logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
         engine_cache: Optional[BaseEngineCache] = None,
+        *,
+        _debugger_config: Optional[DebuggerConfig] = None,
     ):
         super().__init__(module)
 
         self.logger = TRT_LOGGER
         self.builder = trt.Builder(self.logger)
-
+        self._debugger_config = _debugger_config
         flag = 0
         if compilation_settings.use_explicit_typing:
             STRONGLY_TYPED = 1 << (int)(
@@ -205,7 +209,7 @@ def _populate_trt_builder_config(
     ) -> trt.IBuilderConfig:
         builder_config = self.builder.create_builder_config()
 
-        if self.compilation_settings.debug:
+        if self._debugger_config and self._debugger_config.engine_builder_monitor:
             builder_config.progress_monitor = TRTBulderMonitor()
 
         if self.compilation_settings.workspace_size != 0:
@@ -216,7 +220,7 @@ def _populate_trt_builder_config(
         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 trt.ProfilingVerbosity.DETAILED
-                if self.compilation_settings.debug
+                if self._debugger_config and self._debugger_config.save_engine_profile
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index adb7039e7e..35b6c26617 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,7 +3,6 @@
 import logging
 from typing import Any, List, Optional, Sequence
 
-import tensorrt as trt
 import torch
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
@@ -60,7 +59,6 @@ def interpret_module_to_result(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
         engine_cache=engine_cache,
diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py
new file mode 100644
index 0000000000..be5bea358b
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py
@@ -0,0 +1,196 @@
+import contextlib
+import functools
+import logging
+import os
+import tempfile
+from logging.config import dictConfig
+from typing import Any, List, Optional
+from unittest import mock
+
+import torch
+from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import (
+    _DEBUG_ENABLED_CLS,
+    _DEBUG_ENABLED_FUNCS,
+)
+from torch_tensorrt.dynamo.lowering import (
+    ATEN_POST_LOWERING_PASSES,
+    ATEN_PRE_LOWERING_PASSES,
+)
+
+_LOGGER = logging.getLogger(__name__)
+GRAPH_LEVEL = 5
+logging.addLevelName(GRAPH_LEVEL, "GRAPHS")
+
+
+class Debugger:
+    def __init__(
+        self,
+        log_level: str = "debug",
+        capture_fx_graph_before: Optional[List[str]] = None,
+        capture_fx_graph_after: Optional[List[str]] = None,
+        save_engine_profile: bool = False,
+        profile_format: str = "perfetto",
+        engine_builder_monitor: bool = True,
+        logging_dir: str = DEBUG_LOGGING_DIR,
+        save_layer_info: bool = False,
+    ):
+        """Initialize a debugger for TensorRT conversion.
+
+        Args:
+            log_level (str): Logging level to use. Valid options are:
+                'debug', 'info', 'warning', 'error', 'internal_errors', 'graphs'.
+                Defaults to 'debug'.
+            capture_fx_graph_before (List[str], optional): List of pass names; the FX graph
+                is visualized before each named lowering pass executes. Defaults to None.
+            capture_fx_graph_after (List[str], optional): List of pass names; the FX graph
+                is visualized after each named lowering pass executes. Defaults to None.
+            save_engine_profile (bool): Whether to save TensorRT engine profiling information.
+                Defaults to False.
+            profile_format (str): Format for profiling data. Choose from 'perfetto', 'trex', 'cudagraph'.
+                To generate an engine graph from the profiling files, set this to 'trex' and use the C++ runtime.
+                To generate a cudagraph visualization, set this to 'cudagraph'.
+                Defaults to 'perfetto'.
+            engine_builder_monitor (bool): Whether to monitor the TensorRT engine building process.
+                Defaults to True.
+            logging_dir (str): Directory to save debug logs and profiles.
+                Defaults to a torch_tensorrt subdirectory of the system temp directory.
+            save_layer_info (bool): Whether to save TensorRT engine layer information.
+                Defaults to False.
+        """
+
+        os.makedirs(logging_dir, exist_ok=True)
+        self.cfg = DebuggerConfig(
+            log_level=log_level,
+            save_engine_profile=save_engine_profile,
+            engine_builder_monitor=engine_builder_monitor,
+            logging_dir=logging_dir,
+            profile_format=profile_format,
+            save_layer_info=save_layer_info,
+        )
+
+        if log_level == "debug":
+            self.log_level = logging.DEBUG
+        elif log_level == "info":
+            self.log_level = logging.INFO
+        elif log_level == "warning":
+            self.log_level = logging.WARNING
+        elif log_level == "error":
+            self.log_level = logging.ERROR
+        elif log_level == "internal_errors":
+            self.log_level = logging.CRITICAL
+        elif log_level == "graphs":
+            self.log_level = GRAPH_LEVEL
+        else:
+            raise ValueError(
+                f"Invalid level: {log_level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
+            )
+
+        self.capture_fx_graph_before = capture_fx_graph_before
+        self.capture_fx_graph_after = capture_fx_graph_after
+
+    def __enter__(self) -> None:
+        self.original_lvl = _LOGGER.getEffectiveLevel()
+        self.rt_level = torch.ops.tensorrt.get_logging_level()
+        dictConfig(self.get_logging_config(self.log_level))
+
+        if self.capture_fx_graph_before or self.capture_fx_graph_after:
+            self.old_pre_passes, self.old_post_passes = (
+                ATEN_PRE_LOWERING_PASSES.passes,
+                ATEN_POST_LOWERING_PASSES.passes,
+            )
+            pre_pass_names = [p.__name__ for p in self.old_pre_passes]
+            post_pass_names = [p.__name__ for p in self.old_post_passes]
+            path = os.path.join(self.cfg.logging_dir, "lowering_passes_visualization")
+            if self.capture_fx_graph_before is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
+                    post_vis_passes, path
+                )
+            if self.capture_fx_graph_after is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path)
+
+        self._context_stack = contextlib.ExitStack()
+
+        for f in _DEBUG_ENABLED_FUNCS:
+            f.__kwdefaults__["_debugger_config"] = self.cfg
+
+        # Patch each debugger-aware class so its __init__ receives the active config
+        for c in _DEBUG_ENABLED_CLS:
+            self._context_stack.enter_context(
+                mock.patch.object(
+                    c,
+                    "__init__",
+                    functools.partialmethod(c.__init__, _debugger_config=self.cfg),
+                )
+            )
+
+    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
+        dictConfig(self.get_logging_config(None))
+        torch.ops.tensorrt.set_logging_level(self.rt_level)
+        if self.capture_fx_graph_before or self.capture_fx_graph_after:
+            ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
+                self.old_pre_passes,
+                self.old_post_passes,
+            )
+        self.debug_file_dir = tempfile.TemporaryDirectory().name
+
+        for f in _DEBUG_ENABLED_FUNCS:
+            f.__kwdefaults__["_debugger_config"] = None
+
+        self._context_stack.close()
+
+    def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
+        level = log_level if log_level is not None else self.original_lvl
+        config: dict[str, Any] = {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "brief": {
+                    "format": "%(asctime)s - %(levelname)s - %(message)s",
+                    "datefmt": "%H:%M:%S",
+                },
+                "standard": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "console": {
+                    "level": level,
+                    "class": "logging.StreamHandler",
+                    "formatter": "brief",
+                },
+            },
+            "loggers": {
+                "": {  # root logger
+                    "handlers": ["console"],
+                    "level": level,
+                    "propagate": True,
+                },
+            },
+            "force": True,
+        }
+        if log_level is not None:
+            config["handlers"]["file"] = {
+                "level": level,
+                "class": "logging.FileHandler",
+                "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
+                "formatter": "standard",
+            }
+            config["loggers"][""]["handlers"].append("file")
+        return config
diff --git a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
new file mode 100644
index 0000000000..27a5025e8b
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
+
+
+@dataclass
+class DebuggerConfig:
+    log_level: str = "debug"
+    save_engine_profile: bool = False
+    engine_builder_monitor: bool = True
+    logging_dir: str = DEBUG_LOGGING_DIR
+    profile_format: str = "perfetto"
+    save_layer_info: bool = False
diff --git a/py/torch_tensorrt/dynamo/debug/_supports_debugger.py b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py
new file mode 100644
index 0000000000..2d9fd2a149
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py
@@ -0,0 +1,17 @@
+from typing import Any, Callable, Type, TypeVar
+
+T = TypeVar("T")
+F = TypeVar("F", bound=Callable[..., Any])
+
+_DEBUG_ENABLED_FUNCS = []
+_DEBUG_ENABLED_CLS = []
+
+
+def fn_supports_debugger(func: F) -> F:
+    _DEBUG_ENABLED_FUNCS.append(func)
+    return func
+
+
+def cls_supports_debugger(cls: Type[T]) -> Type[T]:
+    _DEBUG_ENABLED_CLS.append(cls)
+    return cls
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 2ecc45ecf3..c7fe264c5a 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -17,7 +17,7 @@
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 
-pass_list = [
+post_lowering_pass_list = [
     remove_input_alias_fixing_clones,
     constant_fold,
     repair_input_as_output,
@@ -28,17 +28,19 @@
     remove_num_users_is_0_nodes,
 ]
 
-if not is_tegra_platform():
-    pass_list.append(fuse_distributed_ops)
+pre_lowering_pass_list = [
+    remove_detach,
+]
 
-ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
+if not is_tegra_platform():
+    post_lowering_pass_list.append(fuse_distributed_ops)
 
-ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
-    [
-        remove_detach,
-    ]
+ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
+    post_lowering_pass_list
 )
 
+ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pre_lowering_pass_list)
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
index c793b1e1c9..9c1f9e18d3 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
@@ -1,10 +1,30 @@
-from typing import Any, Callable, List, Optional, Sequence
+import os
+from typing import Any, Callable, List, Optional
 
 import torch
+from torch.fx import passes
 from torch.fx.passes.pass_manager import PassManager
+from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
 
+def _generate_draw_fx_graph_pass(
+    output_path_prefix: str, name: str
+) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]:
+    def draw_fx_graph_pass(
+        gm: torch.fx.GraphModule, settings: CompilationSettings
+    ) -> torch.fx.GraphModule:
+        if not os.path.exists(f"{output_path_prefix}/"):
+            os.makedirs(f"{output_path_prefix}/")
+        path = f"{output_path_prefix}/{name}.svg"
+        g = passes.graph_drawer.FxGraphDrawer(gm, name)
+        with open(path, "wb") as f:
+            f.write(g.get_dot_graph().create_svg())
+        return gm
+
+    return draw_fx_graph_pass
+
+
 class DynamoPassManager(PassManager):  # type: ignore[misc]
     def __init__(
         self,
@@ -15,8 +35,9 @@ def __init__(
                 ]
             ]
         ] = None,
+        constraints: Optional[List[Callable]] = None,
     ):
-        super().__init__(passes)
+        super().__init__(passes, constraints)
 
     @classmethod
     def build_from_passlist(
@@ -35,8 +56,7 @@ def build_from_passlist(
     def add_pass_with_index(
         self,
         lowering_pass: Callable[
-            [torch.fx.GraphModule, CompilationSettings, Sequence[torch.Tensor]],
-            torch.fx.GraphModule,
+            [torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule
         ],
         index: Optional[int] = None,
     ) -> None:
@@ -49,6 +69,65 @@ def add_pass_with_index(
     def remove_pass_with_index(self, index: int) -> None:
         del self.passes[index]
 
+    def insert_debug_pass_before(
+        self, passes: List[str], output_path_prefix: str = DEBUG_LOGGING_DIR
+    ) -> None:
+        """Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
+
+        Args:
+            passes: List of pass names to insert debug passes before
+            output_path_prefix: Prefix to use for generated debug files
+
+        Debug passes generate SVG visualizations of the FX graph at specified points
+        in the pass sequence.
+        """
+        self.check_pass_names_valid(passes)
+        new_pass_list = []
+        for ps in self.passes:
+            if ps.__name__ in passes:
+                new_pass_list.append(
+                    _generate_draw_fx_graph_pass(
+                        output_path_prefix, f"before_{ps.__name__}"
+                    )
+                )
+            new_pass_list.append(ps)
+
+        self.passes = new_pass_list
+        self._validated = False
+
+    def insert_debug_pass_after(
+        self, passes: List[str], output_path_prefix: str = DEBUG_LOGGING_DIR
+    ) -> None:
+        """Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
+
+        Args:
+            passes: List of pass names to insert debug passes after
+            output_path_prefix: Prefix to use for generated debug files
+
+        Debug passes generate SVG visualizations of the FX graph at specified points
+        in the pass sequence.
+        """
+        self.check_pass_names_valid(passes)
+        new_pass_list = []
+        for ps in self.passes:
+            new_pass_list.append(ps)
+            if ps.__name__ in passes:
+                new_pass_list.append(
+                    _generate_draw_fx_graph_pass(
+                        output_path_prefix, f"after_{ps.__name__}"
+                    )
+                )
+
+        self.passes = new_pass_list
+        self._validated = False
+
+    def check_pass_names_valid(self, debug_pass_names: List[str]) -> None:
+        pass_names_str = [p.__name__ for p in self.passes]
+        for name in debug_pass_names:
+            assert (
+                name in pass_names_str
+            ), f"{name} is not a valid pass! Passes: {pass_names_str}"
+
     def __call__(self, gm: Any, settings: CompilationSettings) -> Any:
         self.validate()
         out = gm
diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
index 429de3ffbb..2cb7fe43f5 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -13,14 +13,15 @@
 )
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
 from torch_tensorrt.dynamo._defaults import (
-    DEBUG,
     MIN_BLOCK_SIZE,
     REQUIRE_FULL_COMPILATION,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    ConverterRegistry,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -250,7 +251,6 @@ def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
 
 def partition(
     gm: torch.fx.GraphModule,
-    verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -286,7 +286,6 @@ def partition(
 
     partitioned_graph = partitioner.partition_graph()
 
-    if verbose:
-        supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
+    supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
 
     return partitioned_graph, supported_ops
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index bdca0e1e1d..3279db00cf 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -7,14 +7,15 @@
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
 from torch_tensorrt.dynamo._defaults import (
-    DEBUG,
     MIN_BLOCK_SIZE,
     REQUIRE_FULL_COMPILATION,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    ConverterRegistry,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -200,7 +201,6 @@ def print_support_overview(
 
 def partition(
     gm: torch.fx.GraphModule,
-    verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -229,7 +229,6 @@ def partition(
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
     fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
-    if verbose:
-        supported_ops.print_support_overview(len(partitions))
+    supported_ops.print_support_overview(len(partitions))
 
     return fused_graph, supported_ops
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 685ec6ebef..e499e988a9 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -5,7 +5,6 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import DEBUG
 from torch_tensorrt.dynamo.utils import contains_sym_int, extract_var_range_info
 
 logger = logging.getLogger(__name__)
@@ -169,7 +168,6 @@ def get_submodule_io(
 
 def get_graph_converter_support(
     graph_module: torch.fx.GraphModule,
-    verbose: bool = DEBUG,
     torch_executed_ops: Optional[Set[str]] = None,
 ) -> Tuple[int, int]:
     """Helper function to get converter support overview pre-partitioning
@@ -199,7 +197,6 @@ def get_graph_converter_support(
             number_of_supported_nodes += 1
 
     # Print node support overview prior to partitioning
-    if verbose:
-        op_support.print_support_overview(print_node_support=True)
+    op_support.print_support_overview(print_node_support=True)
 
     return number_of_supported_nodes, total_functional_nodes
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index cd732811b3..94eaa9b333 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -1,5 +1,6 @@
 import inspect
 import logging
+import warnings
 from copy import deepcopy
 from enum import Enum, auto
 from typing import Any, Dict, Iterator, Optional, Union
@@ -85,7 +86,6 @@ def __init__(
             sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
             enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
             immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable.
-            debug (bool): Enable debuggable engine
             capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
             num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
             workspace_size (int): Maximum size of workspace given to TensorRT
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 6415ce11c3..fc76b20141 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -12,6 +12,8 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
@@ -111,6 +113,7 @@ def set_runtime_states(
     )
 
 
+@cls_supports_debugger
 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
     """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
 
@@ -128,6 +131,7 @@ def __init__(
         settings: CompilationSettings = CompilationSettings(),
         weight_name_map: Optional[dict[Any, Any]] = None,
         requires_output_allocator: bool = False,
+        _debugger_config: Optional[DebuggerConfig] = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -157,6 +161,7 @@ def __init__(
 
         """
         self.context: Any
+        self._debugger_config: Optional[DebuggerConfig] = _debugger_config
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
 
@@ -193,7 +198,11 @@ def __init__(
         self.target_device_properties = torch.cuda.get_device_properties(
             self.target_device_id
         )
-        self.profiling_enabled = settings.debug if settings.debug is not None else False
+        self.profiling_enabled = (
+            _debugger_config.save_engine_profile
+            if _debugger_config is not None
+            else False
+        )
         self.settings = settings
         self.engine = None
         self.weight_name_map = weight_name_map
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index c3fe925eee..95f1581881 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -334,7 +334,11 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
         return tuple(outputs)
 
-    def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None:
+    def enable_profiling(
+        self,
+        profiling_results_dir: Optional[str] = None,
+        profile_format: str = "perfetto",
+    ) -> None:
         """Enable the profiler to collect latency information about the execution of the engine
 
         Traces can be visualized using https://ui.perfetto.dev/ or compatible alternatives
@@ -347,7 +351,9 @@ def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None
 
         if profiling_results_dir is not None:
             self.engine.profile_path_prefix = profiling_results_dir
+        assert profile_format in ["trex", "perfetto"]
         self.engine.enable_profiling()
+        self.engine.set_profile_format(profile_format)
 
     def disable_profiling(self) -> None:
         """Disable the profiler"""
diff --git a/pyproject.toml b/pyproject.toml
index c527db1fc6..f786812aae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,6 +96,12 @@ dev = [
     "pyyaml",
 ]
 
+debug = [
+    "pydot >= 4.0.0",
+    "tabulate >= 0.8.10",
+    "graphviz >= 0.20.3"
+]
+
 [project.optional-dependencies]
 torchvision = [
     "torchvision",
diff --git a/tools/debug/engine_visualization/README.md b/tools/debug/engine_visualization/README.md
new file mode 100644
index 0000000000..90547b8ba9
--- /dev/null
+++ b/tools/debug/engine_visualization/README.md
@@ -0,0 +1,11 @@
+## Introduction
+We use the TRT Engine Explorer (TREX) to visualize the engine graph structure. TREX is a diagnostic and profiling tool for TensorRT engine files. It allows you to inspect, benchmark, and debug TensorRT engines with ease.
+
+## Installation
+```bash
+pip install git+https://github.com/NVIDIA/TensorRT.git#subdirectory=tools/experimental/trt-engine-explorer
+sudo apt --yes install graphviz
+```
+
+## Usage
+An example can be found in `draw_engine_graph_example.py`. We use `torch_tensorrt.dynamo.Debugger` to first emit the engine profiling information required by TREX; note that TREX profiles are only produced when compiling with `use_python_runtime=False` (the C++ runtime). The profiles are saved under the `engine_visualization_profile` subdirectory of the logging directory, and we then call `draw_engine` on that same directory to render the engine graph.
\ No newline at end of file
diff --git a/tools/debug/engine_visualization/draw_engine_graph.py b/tools/debug/engine_visualization/draw_engine_graph.py
new file mode 100644
index 0000000000..e2514e04c8
--- /dev/null
+++ b/tools/debug/engine_visualization/draw_engine_graph.py
@@ -0,0 +1,36 @@
+import os
+import shutil
+
+
+def draw_engine(dir_path: str):
+    # Import trex lazily so a missing install produces a helpful message
+    try:
+        import trex
+        import trex.engine_plan
+        import trex.graphing
+    except ImportError:
+        print("trex is required but it is not installed.\n")
+        print("Check README.md for installation instructions.")
+        exit()
+
+    # Profile/layer files are named after the first TRT submodule ("_run_on_acc_0")
+    engine_json_fname = os.path.join(
+        dir_path, "_run_on_acc_0_engine_layer_information.json"
+    )
+    profiling_json_fname = os.path.join(
+        dir_path, "_run_on_acc_0_engine_engine_execution_profile.trace"
+    )
+
+    graphviz_is_installed = shutil.which("dot") is not None
+    if not graphviz_is_installed:
+        print("graphviz is required but it is not installed.\n")
+        print("To install on Ubuntu:")
+        print("sudo apt --yes install graphviz")
+        exit()
+
+    plan = trex.engine_plan.EnginePlan(
+        engine_json_fname, profiling_file=profiling_json_fname
+    )
+    layer_node_formatter = trex.graphing.layer_type_formatter
+    graph = trex.graphing.to_dot(plan, layer_node_formatter)
+    output_format = "png"  # svg or jpg
+
+    trex.graphing.render_dot(graph, engine_json_fname, output_format)
diff --git a/tools/debug/engine_visualization/draw_engine_graph_example.py b/tools/debug/engine_visualization/draw_engine_graph_example.py
new file mode 100644
index 0000000000..e6236d0c59
--- /dev/null
+++ b/tools/debug/engine_visualization/draw_engine_graph_example.py
@@ -0,0 +1,32 @@
+import os
+
+import torch
+import torch_tensorrt
+import torchvision.models as models
+from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
+
+inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
+model = models.resnet18(pretrained=False).eval().to("cuda")
+exp_program = torch.export.export(model, tuple(inputs))
+
+with torch_tensorrt.dynamo.Debugger(
+    "graphs",
+    logging_dir=DEBUG_LOGGING_DIR,
+    capture_fx_graph_after=["constant_fold"],
+    save_engine_profile=True,
+    profile_format="trex",
+    engine_builder_monitor=False,
+):
+    trt_gm = torch_tensorrt.dynamo.compile(
+        exp_program,
+        inputs=inputs,
+        enabled_precisions={torch.float},
+        truncate_double=True,
+        use_python_runtime=False,
+        min_block_size=1,
+    )
+    trt_output = trt_gm(*inputs)
+
+    from draw_engine_graph import draw_engine
+
+    draw_engine(os.path.join(DEBUG_LOGGING_DIR, "engine_visualization_profile"))
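For users hitting the new `DeprecationWarning` above: a minimal migration sketch from the removed `debug=True` flag to the `Debugger` context manager this diff introduces. The toy model and input shapes here are illustrative only, not part of the change.

```python
import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(2, 8).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

# Before (now only warns, the kwarg is otherwise ignored):
#   trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs, debug=True)

# After: wrap the compilation call; debug logging is scoped to the `with` block.
with torch_tensorrt.dynamo.Debugger(log_level="debug"):
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs, min_block_size=1)
```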
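The `capture_fx_graph_before`/`capture_fx_graph_after` options are thin wrappers over the new `insert_debug_pass_before`/`insert_debug_pass_after` methods on `DynamoPassManager`. A sketch of calling that API directly, assuming the default ATEN pass lists shown in `_aten_lowering_pass.py`; pass names are the pass functions' `__name__` values (e.g. `"constant_fold"`), and the output directory is an arbitrary writable path:

```python
from torch_tensorrt.dynamo.lowering import ATEN_POST_LOWERING_PASSES

# Writes before_constant_fold.svg / after_constant_fold.svg into the given
# directory each time the lowering pipeline runs.
ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(["constant_fold"], "/tmp/fx_graphs")
ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(["constant_fold"], "/tmp/fx_graphs")
```

Unlike the `Debugger` context manager, direct insertion is not automatically undone; `Debugger.__exit__` is what restores the original pass lists.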
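Both trace formats emitted by `dump_trace` above are written as a single JSON array, so the engine-execution trace can also be inspected without Perfetto or TREX. A reader sketch, assuming the `_run_on_acc_0` module naming used in `draw_engine_graph.py`:

```python
import json

with open("_run_on_acc_0_engine_engine_execution_profile.trace") as f:
    records = json.load(f)

for r in records:
    if "dur" in r:  # perfetto-format event: ts/dur are in microseconds
        print(f"{r['name']}: {r['dur'] / 1000.0:.3f} ms")
    else:  # trex-format record: per-layer timeMs/averageMs/percentage
        print(f"{r['name']}: {r['timeMs']:.3f} ms ({r['percentage']:.1f}%)")
```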