Skip to content

Commit 9386b31

Browse files
committed
On destroy - acquire gil
1 parent a1d1657 commit 9386b31

2 files changed

Lines changed: 58 additions & 46 deletions

File tree

tvm-python/PyGlobal.cpp

Lines changed: 55 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,69 +2,78 @@
22

33
#include <memory>
44
#include <mutex>
5+
#include "third-party/pybind11/include/pybind11/pybind11.h"
56

67
#include "PyGlobal.h"
78

89
std::mutex scheduler_init_mutex;
910
bool scheduler_running = false;
1011
static std::unique_ptr<td::actor::Scheduler> thread_local_scheduler;
1112
static std::unique_ptr<td::thread> scheduler_thread;
13+
namespace py = pybind11;
1214

1315
namespace pyglobal {
14-
class Runner : public td::actor::Actor {
15-
public:
16-
explicit Runner(std::function<void()> f) : f_(std::move(f)) {
17-
}
16+
class Runner : public td::actor::Actor {
17+
public:
18+
explicit Runner(std::function<void()> f)
19+
: f_(std::move(f)) {
20+
}
1821

19-
void start_up() override {
20-
f_();
21-
stop();
22-
}
22+
~Runner() override {
23+
py::gil_scoped_acquire gil;
24+
f_ = {};
25+
}
2326

24-
private:
25-
std::function<void()> f_;
26-
};
2727

28+
void start_up() override {
29+
f_();
30+
stop();
31+
}
2832

29-
void init_thread_scheduler(int thread_count) {
30-
std::lock_guard<std::mutex> lock(scheduler_init_mutex);
31-
if (!thread_local_scheduler) {
32-
thread_local_scheduler = std::unique_ptr<td::actor::Scheduler>(new td::actor::Scheduler({thread_count}));
33-
scheduler_running = true;
33+
private:
34+
std::function<void()> f_;
35+
};
3436

35-
scheduler_thread = std::make_unique<td::thread>([&] {
36-
thread_local_scheduler->run();
37-
});
38-
}
39-
}
4037

41-
td::actor::Scheduler *get_thread_scheduler() {
42-
if (!scheduler_running) {
43-
init_thread_scheduler(6);
44-
}
45-
return thread_local_scheduler.get();
46-
}
38+
void init_thread_scheduler(int thread_count) {
39+
std::lock_guard<std::mutex> lock(scheduler_init_mutex);
40+
if (!thread_local_scheduler) {
41+
thread_local_scheduler = std::unique_ptr<td::actor::Scheduler>(new td::actor::Scheduler({thread_count}));
42+
scheduler_running = true;
4743

48-
void stop_scheduler_thread() {
49-
if (scheduler_running) {
50-
std::lock_guard<std::mutex> lock(scheduler_init_mutex);
51-
if (thread_local_scheduler) {
52-
thread_local_scheduler->run_in_context_external([] {
53-
td::actor::SchedulerContext::get()->stop();
54-
});
55-
}
56-
if (scheduler_thread) {
57-
scheduler_thread->join();
58-
scheduler_thread.reset();
59-
}
60-
thread_local_scheduler.reset();
61-
scheduler_running = false;
62-
}
63-
}
44+
scheduler_thread = std::make_unique<td::thread>([&] {
45+
thread_local_scheduler->run();
46+
});
47+
}
48+
}
6449

65-
void execute_async(std::function<void()> f) {
66-
get_thread_scheduler()->run_in_context_external([&] {
67-
td::actor::create_actor<pyglobal::Runner>("executeasync", std::move(f)).release();
50+
td::actor::Scheduler* get_thread_scheduler() {
51+
if (!scheduler_running) {
52+
init_thread_scheduler(6);
53+
}
54+
return thread_local_scheduler.get();
55+
}
56+
57+
void stop_scheduler_thread() {
58+
if (scheduler_running) {
59+
std::lock_guard<std::mutex> lock(scheduler_init_mutex);
60+
if (thread_local_scheduler) {
61+
thread_local_scheduler->run_in_context_external([] {
62+
td::actor::SchedulerContext::get()->stop();
6863
});
6964
}
65+
if (scheduler_thread) {
66+
scheduler_thread->join();
67+
scheduler_thread.reset();
68+
}
69+
thread_local_scheduler.reset();
70+
scheduler_running = false;
71+
}
72+
}
73+
74+
void execute_async(std::function<void()> f) {
75+
get_thread_scheduler()->run_in_context_external([&] {
76+
td::actor::create_actor<pyglobal::Runner>("executeasync", std::move(f)).release();
77+
});
7078
}
79+
}

tvm-python/PyTVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ void PyTVM::set_c7(PyStackEntry x) {
3535

3636
void PyTVM::log(const std::string &log_string, int level) {
3737
if (log_level >= level && level == LOG_INFO) {
38+
py::gil_scoped_acquire gil;
3839
py::print("INFO: " + log_string);
3940
} else if (log_level >= level && level == LOG_DEBUG) {
41+
py::gil_scoped_acquire gil;
4042
py::print("DEBUG: " + log_string);
4143
}
4244
}
@@ -129,6 +131,7 @@ class PythonLogger : public td::LogInterface {
129131
}
130132

131133
if (!muted) {
134+
py::gil_scoped_acquire gil;
132135
py::print(slice.str());
133136
}
134137
}

0 commit comments

Comments
 (0)