<!--
# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions

… instance in the
[model configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#instance-groups)
to ensure that the model instance and the tensors used for inference are
assigned to the same GPU device as the one on which the model was traced.

# PyTorch 2.0 Backend \[Experimental\]

> [!WARNING]
> *This feature is subject to change and removal.*

Starting from 24.01, PyTorch models can be served directly via the
[Python runtime](src/model.py). By default, Triton uses the
[LibTorch runtime](#pytorch-libtorch-backend) for PyTorch models. To use the
Python runtime instead, provide the following
[runtime setting](https://github.com/triton-inference-server/backend/blob/main/README.md#backend-shared-library)
in the model configuration:

```
runtime: "model.py"
```
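
For reference, a minimal `config.pbtxt` using this runtime might look like the
sketch below. The model name, tensor names, data types, and shapes are
illustrative assumptions and must match your own model; only the `backend` and
`runtime` fields shown here come from this feature.

```
name: "my_pytorch_model"
backend: "pytorch"
runtime: "model.py"
max_batch_size: 8

input [
  {
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ 4 ]
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ 4 ]
  }
]
```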

## Dependencies

### Python backend dependency

This feature depends on the
[Python backend](https://github.com/triton-inference-server/python_backend);
see
[Python-based Backends](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md)
for more details.

### PyTorch dependency

This feature takes advantage of the
[`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
optimization, so make sure the
[PyTorch 2.0+ pip package](https://pypi.org/project/torch) is available in the
same Python environment.
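
As a quick, illustrative sanity check (not part of the backend itself), the
following snippet can be run in that environment to confirm a
`torch.compile`-capable build is installed:

```python
# Illustrative check: the Python runtime relies on torch.compile,
# which is available starting with PyTorch 2.0.
import torch

print(torch.__version__)  # expect 2.x or newer
assert hasattr(torch, "compile"), "PyTorch 2.0+ with torch.compile is required"
```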

Alternatively, a [Python Execution Environment](#using-custom-python-execution-environments)
with the PyTorch dependency may be used. It can be created with the
[provided script](tools/gen_pb_exec_env.sh). The resulting
`pb_exec_env_model.py.tar.gz` file should be placed in the same
[backend shared library](https://github.com/triton-inference-server/backend/blob/main/README.md#backend-shared-library)
directory as the [Python runtime](src/model.py).

## Model Layout

### PyTorch 2.0 models

The model repository should look like:

```
model_repository/
`-- model_directory
    |-- 1
    |   |-- model.py
    |   `-- [model.pt]
    `-- config.pbtxt
```

The `model.py` file contains the class definition of the PyTorch model. The
class should extend
[`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module).
The `model.pt` file may optionally be provided; it contains the saved
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference)
of the model.
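
As an illustration only, a `model.py` for this layout might look like the
sketch below. The class name `AddSubNet` and its tensor signature are
hypothetical assumptions, not names required by the backend; the requirement
stated above is simply that the model is defined as a class extending
`torch.nn.Module`.

```python
# model.py -- illustrative sketch of a PyTorch 2.0 model definition.
import torch
import torch.nn as nn


class AddSubNet(nn.Module):
    """Toy module returning the element-wise sum and difference of two inputs."""

    def forward(self, input0: torch.Tensor, input1: torch.Tensor):
        return input0 + input1, input0 - input1
```

If the model carries trained weights, a matching `model.pt` can be produced
with, for example, `torch.save(model.state_dict(), "model.pt")`.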

### TorchScript models

The model repository should look like:

```
model_repository/
`-- model_directory
    |-- 1
    |   `-- model.pt
    `-- config.pbtxt
```

The `model.pt` is the TorchScript model file, typically produced by tracing or
scripting the model with PyTorch; a sketch follows below.
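
For reference, a TorchScript `model.pt` can be exported with PyTorch's own
tracing or scripting APIs; a minimal sketch, where the module and input shapes
are illustrative stand-ins:

```python
# Illustrative export of a TorchScript model file for the layout above.
import torch
import torch.nn as nn

model = nn.Linear(4, 4).eval()  # stand-in for your torch.nn.Module

# Trace with example inputs of the expected shape (or use torch.jit.script(model)).
traced = torch.jit.trace(model, torch.rand(1, 4))

# The version directory (e.g. model_directory/1/) must already exist.
traced.save("model_repository/model_directory/1/model.pt")
```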

## Customization

The following PyTorch settings may be customized by setting parameters in
`config.pbtxt`.

[`torch.set_num_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads)
- Key: NUM_THREADS
- Value: The number of threads used for intraop parallelism on CPU.

[`torch.set_num_interop_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_interop_threads.html#torch.set_num_interop_threads)
- Key: NUM_INTEROP_THREADS
- Value: The number of threads used for interop parallelism (e.g. in the JIT
interpreter) on CPU.

[`torch.compile()` parameters](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
- Key: TORCH_COMPILE_OPTIONAL_PARAMETERS
- Value: Any of the following parameters, encoded as a JSON object.
  - fullgraph (*bool*): If True, require the entire model to be captured in a
  single graph; compilation errors out if graph breaks are required.
  - dynamic (*bool*): Use dynamic shape tracing.
  - backend (*str*): The backend to be used.
  - mode (*str*): Can be either "default", "reduce-overhead" or "max-autotune".
  - options (*dict*): A dictionary of options to pass to the backend.
  - disable (*bool*): Turn `torch.compile()` into a no-op for testing.

For example:
```
parameters: {
  key: "NUM_THREADS"
  value: { string_value: "4" }
}
parameters: {
  key: "TORCH_COMPILE_OPTIONAL_PARAMETERS"
  value: { string_value: "{\"disable\": true}" }
}
```
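
Conceptually, these parameters map onto the corresponding PyTorch calls made
inside the runtime. A rough, illustrative sketch of what the example above
configures (the stand-in module below is not part of the backend):

```python
# Illustrative only: roughly what the example parameters above correspond to.
import torch
import torch.nn as nn

torch.set_num_threads(4)  # NUM_THREADS: "4"

model = nn.Linear(4, 4)   # stand-in for the served torch.nn.Module
# TORCH_COMPILE_OPTIONAL_PARAMETERS: {"disable": true}
compiled_model = torch.compile(model, disable=True)  # turns compilation into a no-op
```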

## Limitations

The following are a few known limitations of this feature:
- Python functions optimizable by `torch.compile` may not be served directly in
the `model.py` file; they need to be enclosed by a class extending
[`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module),
as sketched after this list.
- Model weights cannot be shared across multiple instances on the same GPU
device.
- When using `KIND_MODEL` as the model instance kind, the default device of the
first parameter of the model is used.
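
To illustrate the first limitation, a plain function can be wrapped in a
`torch.nn.Module` subclass so it can be served from `model.py`; the function
and class names below are hypothetical:

```python
# Illustrative workaround: wrap a plain Python function in a torch.nn.Module
# subclass so it can be served from model.py.
import torch
import torch.nn as nn


def scale_and_shift(x: torch.Tensor) -> torch.Tensor:
    # A free function like this cannot be served directly.
    return x * 2.0 + 1.0


class ScaleAndShiftModel(nn.Module):
    # Wrapping the function in an nn.Module makes it servable.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return scale_and_shift(x)
```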