Commit b855365

Protect thread setting call (#159)
* use call_once to prevent repeated thread count setting
* update docs for thread parameters
* always emit a message with number of threads
* apply formatting
* readme formatting
1 parent abafeb6 commit b855365

File tree

2 files changed: +36 additions, -6 deletions

README.md

Lines changed: 11 additions & 1 deletion
````diff
@@ -200,7 +200,7 @@ key: "ENABLE_CACHE_CLEANING"
 * `INTER_OP_THREAD_COUNT`:
 
   PyTorch allows using multiple CPU threads during TorchScript model inference.
-  One or more inference threads execute a models forward pass on the given
+  One or more inference threads execute a model's forward pass on the given
   inputs. Each inference thread invokes a JIT interpreter that executes the ops
   of a model inline, one by one. This parameter sets the size of this thread
   pool. The default value of this setting is the number of cpu cores. Please refer
@@ -218,6 +218,11 @@ key: "INTER_OP_THREAD_COUNT"
 }
 ```
 
+> [!NOTE]
+> This parameter is set globally for the PyTorch backend.
+> The value from the first model config file that specifies this parameter will be used.
+> Subsequent values from other model config files, if different, will be ignored.
+
 * `INTRA_OP_THREAD_COUNT`:
 
   In addition to the inter-op parallelism, PyTorch can also utilize multiple threads
@@ -238,6 +243,11 @@ key: "INTRA_OP_THREAD_COUNT"
 }
 ```
 
+> [!NOTE]
+> This parameter is set globally for the PyTorch backend.
+> The value from the first model config file that specifies this parameter will be used.
+> Subsequent values from other model config files, if different, will be ignored.
+
 * Additional Optimizations: Three additional boolean parameters are available to disable
   certain Torch optimizations that can sometimes cause latency regressions in models with
   complex execution modes and dynamic shapes. If not specified, all are enabled by default.
````

src/libtorch.cc

Lines changed: 25 additions & 5 deletions
```diff
@@ -28,6 +28,7 @@
 
 #include <cstdint>
 #include <exception>
+#include <mutex>
 
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
@@ -66,6 +67,11 @@
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
 
+namespace {
+std::once_flag pytorch_interop_threads_flag;
+std::once_flag pytorch_intraop_threads_flag;
+}  // namespace
+
 namespace triton { namespace backend { namespace pytorch {
 
 //
@@ -509,11 +515,15 @@ ModelState::ParseParameters()
       }
     } else {
       if (intra_op_thread_count > 0) {
-        at::set_num_threads(intra_op_thread_count);
+        // at::set_num_threads() does not throw if called more than once, but
+        // issues warnings. std::call_once() is useful to limit these.
+        std::call_once(pytorch_intraop_threads_flag, [intra_op_thread_count]() {
+          at::set_num_threads(intra_op_thread_count);
+        });
         LOG_MESSAGE(
             TRITONSERVER_LOG_INFO,
             (std::string("Intra op thread count is set to ") +
-             std::to_string(intra_op_thread_count) + " for model instance '" +
+             std::to_string(at::get_num_threads()) + " for model instance '" +
              Name() + "'")
                 .c_str());
       }
@@ -533,12 +543,22 @@ ModelState::ParseParameters()
       }
     } else {
       if (inter_op_thread_count > 0) {
-        at::set_num_interop_threads(inter_op_thread_count);
+        // at::set_num_interop_threads() throws if called more than once.
+        // std::call_once() should prevent this, but try/catch is additionally
+        // used for safety.
+        std::call_once(pytorch_interop_threads_flag, [inter_op_thread_count]() {
+          try {
+            at::set_num_interop_threads(inter_op_thread_count);
+          }
+          catch (const c10::Error& e) {
+            // do nothing
+          }
+        });
         LOG_MESSAGE(
             TRITONSERVER_LOG_INFO,
             (std::string("Inter op thread count is set to ") +
-             std::to_string(inter_op_thread_count) + " for model instance '" +
-             Name() + "'")
+             std::to_string(at::get_num_interop_threads()) +
+             " for model instance '" + Name() + "'")
                 .c_str());
       }
     }
```
