2828
2929#include < cstdint>
3030#include < exception>
31+ #include < mutex>
3132
3233#include " libtorch_utils.h"
3334#include " triton/backend/backend_common.h"
6667// PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
6768//
6869
70+ namespace {
71+ std::once_flag pytorch_interop_threads_flag;
72+ std::once_flag pytorch_intraop_threads_flag;
73+ } // namespace
74+
6975namespace triton { namespace backend { namespace pytorch {
7076
7177//
@@ -509,11 +515,15 @@ ModelState::ParseParameters()
509515 }
510516 } else {
511517 if (intra_op_thread_count > 0 ) {
512- at::set_num_threads (intra_op_thread_count);
518+ // at::set_num_threads() does not throw if called more than once, but
519+ // issues warnings. std::call_once() is useful to limit these.
520+ std::call_once (pytorch_intraop_threads_flag, [intra_op_thread_count]() {
521+ at::set_num_threads (intra_op_thread_count);
522+ });
513523 LOG_MESSAGE (
514524 TRITONSERVER_LOG_INFO,
515525 (std::string (" Intra op thread count is set to " ) +
516- std::to_string (intra_op_thread_count ) + " for model instance '" +
526+ std::to_string (at::get_num_threads () ) + " for model instance '" +
517527 Name () + " '" )
518528 .c_str ());
519529 }
@@ -533,12 +543,22 @@ ModelState::ParseParameters()
533543 }
534544 } else {
535545 if (inter_op_thread_count > 0 ) {
536- at::set_num_interop_threads (inter_op_thread_count);
546+ // at::set_num_interop_threads() throws if called more than once.
547+ // std::call_once() should prevent this, but try/catch is additionally
548+ // used for safety.
549+ std::call_once (pytorch_interop_threads_flag, [inter_op_thread_count]() {
550+ try {
551+ at::set_num_interop_threads (inter_op_thread_count);
552+ }
553+ catch (const c10::Error& e) {
554+ // do nothing
555+ }
556+ });
537557 LOG_MESSAGE (
538558 TRITONSERVER_LOG_INFO,
539559 (std::string (" Inter op thread count is set to " ) +
540- std::to_string (inter_op_thread_count) + " for model instance ' " +
541- Name () + " '" )
560+ std::to_string (at::get_num_interop_threads ()) +
561+ " for model instance ' " + Name () + " '" )
542562 .c_str ());
543563 }
544564 }
0 commit comments