Skip to content

Commit e9cd781

Browse files
Ilia Cherniavskii authored and facebook-github-bot committed
Back out "Revert D13043261: [caffe2] Task graph and task future abstractions in executor"
Summary: Pull Request resolved: pytorch#15030 Reviewed By: bddppq Differential Revision: D13408998 fbshipit-source-id: 9eb675e09fbc4829eab34df7aa660a0590816feb
1 parent 83f32ee commit e9cd781

9 files changed

+824
-1
lines changed

caffe2/core/net_async_task.cc

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#include "caffe2/core/net_async_task.h"
2+
3+
#include "caffe2/core/net_async_task_graph.h"
4+
5+
namespace caffe2 {
6+
7+
// Creates a task that runs the given chain of operators.
//
// The chain must be non-empty and every operator must report the same device
// as the first one; both conditions are enforced. The device option of the
// first operator becomes the device option of the task. The task starts in a
// freshly reset state (see Reset()).
AsyncTask::AsyncTask(const std::vector<OperatorBase*>& ops) : ops_(ops) {
  CAFFE_ENFORCE(!ops_.empty());

  // The first op defines the device of the whole chain.
  device_option_ = ops_.front()->device_option();
  for (OperatorBase* op : ops_) {
    CAFFE_ENFORCE(IsSameDevice(device_option_, op->device_option()));
  }

  Reset();
}
15+
16+
// Reports a failure of this task's operator chain.
//
// Builds the error text from err_str (followed by the type of the failing
// operator when op is non-null), logs it, finishes the event of the chain's
// last operator with that text and completes the task's future with the same
// text.
//
//   op             - operator that was being processed when the failure
//                    happened; may be null
//   err_str        - description of the failure
//   save_exception - when true the event is finished through
//                    SetFinishedWithException, otherwise through SetFinished
void AsyncTask::handleChainError(
    OperatorBase* op,
    const char* err_str,
    bool save_exception) {
  std::string err_msg = err_str;
  if (op) {
    // ", op " already ends with the separating space, so the fallback text
    // must not start with another one.
    err_msg += ", op " + (op->has_debug_def() ? op->type() : "unknown");
  }
  LOG(ERROR) << err_msg;

  // save error message and exception in chain's Event
  auto last_op = ops_.back();
  if (save_exception) {
    last_op->event().SetFinishedWithException(err_msg.c_str());
  } else {
    last_op->event().SetFinished(err_msg.c_str());
  }

  // set future as completed with an error
  // TODO: exceptions in future
  future_.SetCompleted(err_msg.c_str());
}
38+
39+
bool AsyncTask::Run(const ExecutionOptions& options) {
40+
// TODO: insert CUDA's async stream waits; tracing and counters
41+
OperatorBase* op = nullptr;
42+
try {
43+
for (auto op_idx = 0; op_idx < ops_.size(); ++op_idx) {
44+
op = ops_[op_idx];
45+
int stream_id = 0; // TODO: thread local stream id
46+
if (!op->RunAsync(stream_id)) {
47+
handleChainError(op, "Failed to execute an op");
48+
return false;
49+
}
50+
}
51+
52+
if (options.finish_chain_) {
53+
op = ops_.back();
54+
op->Finish();
55+
}
56+
57+
// set the future as successfully completed or, in case of async CPU,
58+
// use op's callback
59+
if (IsCPUDeviceType(device_option_.device_type()) &&
60+
ops_.back()->HasAsyncPart()) {
61+
auto& event = ops_.back()->event();
62+
event.SetCallback([this, &event]() {
63+
CAFFE_ENFORCE(event.IsFinished());
64+
if (event.Query() == EventStatus::EVENT_SUCCESS) {
65+
future_.SetCompleted();
66+
} else {
67+
// TODO: support for exceptions
68+
future_.SetCompleted(event.ErrorMessage().c_str());
69+
}
70+
});
71+
} else {
72+
future_.SetCompleted();
73+
}
74+
} catch (const std::exception& e) {
75+
handleChainError(op, e.what(), /* save_exception */ true);
76+
return false;
77+
} catch (...) {
78+
handleChainError(
79+
op,
80+
"Failed to execute task: unknown error",
81+
/* save_exception */ true);
82+
return false;
83+
}
84+
85+
return true;
86+
}
87+
88+
void AsyncTask::Reset() {
89+
for (auto& op : ops_) {
90+
op->ResetEvent();
91+
}
92+
future_.ResetState();
93+
}
94+
95+
// Returns a copy of the device option shared by all operators of the task.
DeviceOption AsyncTask::GetDeviceOption() const {
  return device_option_;
}
98+
99+
// Gives mutable access to the future that tracks this task's completion.
AsyncTaskFuture& AsyncTask::GetFuture() {
  return future_;
}
102+
103+
// Gives read-only access to the future that tracks this task's completion.
const AsyncTaskFuture& AsyncTask::GetFuture() const {
  return future_;
}
106+
107+
}; // namespace caffe2

caffe2/core/net_async_task.h

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef CAFFE2_NET_ASYNC_TASK_H
2+
#define CAFFE2_NET_ASYNC_TASK_H
3+
4+
#include "caffe2/core/net_async_base.h"
5+
#include "caffe2/core/net_async_task_future.h"
6+
#include "caffe2/core/operator.h"
7+
8+
#include <vector>
9+
10+
namespace caffe2 {
11+
12+
// AsyncTask represents an asynchronous execution of a chain of ops.
13+
class AsyncTask {
14+
public:
15+
AsyncTask(const std::vector<OperatorBase*>& ops);
16+
17+
bool Run(const ExecutionOptions& options);
18+
19+
void Reset();
20+
21+
DeviceOption GetDeviceOption() const;
22+
23+
AsyncTaskFuture& GetFuture();
24+
const AsyncTaskFuture& GetFuture() const;
25+
26+
private:
27+
void handleChainError(
28+
OperatorBase* op,
29+
const char* err_msg,
30+
bool save_exception = false);
31+
32+
std::vector<OperatorBase*> ops_;
33+
DeviceOption device_option_;
34+
AsyncTaskFuture future_;
35+
};
36+
37+
} // namespace caffe2
38+
39+
#endif // CAFFE2_NET_ASYNC_TASK_H

caffe2/core/net_async_task_future.cc

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#include "caffe2/core/net_async_task_future.h"
2+
3+
#include "c10/util/Logging.h"
4+
#include "caffe2/core/common.h"
5+
6+
namespace caffe2 {
7+
8+
// Creates a standalone future that is neither completed nor failed.
AsyncTaskFuture::AsyncTaskFuture()
    : completed_(false),
      failed_(false) {
}
9+
10+
// Creates a future that completes once all of the given parent futures have
// completed. At least one parent is required (enforced).
//
// With a single parent, the parent's outcome is forwarded as is. With several
// parents, a ParentCounter tracks how many are still pending; the error
// messages of failed parents are joined with ", " and the future is completed
// by the callback of the last parent to finish.
AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
    : completed_(false), failed_(false) {
  if (futures.size() <= 1) {
    CAFFE_ENFORCE_EQ(futures.size(), 1);
    futures.back()->SetCallback([this](const AsyncTaskFuture* parent) {
      if (parent->IsFailed()) {
        SetCompleted(parent->ErrorMessage().c_str());
      } else {
        SetCompleted();
      }
    });
    return;
  }

  parent_counter_ = caffe2::make_unique<ParentCounter>(futures.size());
  for (AsyncTaskFuture* parent_future : futures) {
    parent_future->SetCallback([this](const AsyncTaskFuture* parent) {
      if (parent->IsFailed()) {
        // Error text is shared between parents, so guard its update.
        std::lock_guard<std::mutex> guard(parent_counter_->err_mutex);
        if (!parent_counter_->parent_failed) {
          parent_counter_->parent_failed = true;
          parent_counter_->err_msg = parent->ErrorMessage();
        } else {
          parent_counter_->err_msg += ", " + parent->ErrorMessage();
        }
      }

      const int remaining = --parent_counter_->parent_count;
      if (remaining != 0) {
        return;
      }

      // thread safe to use parent_counter here
      const bool any_failed = parent_counter_->parent_failed;
      SetCompleted(any_failed ? parent_counter_->err_msg.c_str() : nullptr);
    });
  }
}
48+
49+
bool AsyncTaskFuture::IsCompleted() const {
50+
return completed_;
51+
}
52+
53+
bool AsyncTaskFuture::IsFailed() const {
54+
return failed_;
55+
}
56+
57+
std::string AsyncTaskFuture::ErrorMessage() const {
58+
return err_msg_;
59+
}
60+
61+
void AsyncTaskFuture::Wait() const {
62+
std::unique_lock<std::mutex> lock(mutex_);
63+
while (!completed_) {
64+
cv_completed_.wait(lock);
65+
}
66+
}
67+
68+
void AsyncTaskFuture::SetCallback(
69+
std::function<void(const AsyncTaskFuture*)> callback) {
70+
std::unique_lock<std::mutex> lock(mutex_);
71+
72+
callbacks_.push_back(callback);
73+
if (completed_) {
74+
callback(this);
75+
}
76+
}
77+
78+
void AsyncTaskFuture::SetCompleted(const char* err_msg) {
79+
std::unique_lock<std::mutex> lock(mutex_);
80+
81+
CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");
82+
completed_ = true;
83+
84+
if (err_msg) {
85+
failed_ = true;
86+
err_msg_ = err_msg;
87+
}
88+
89+
for (auto& callback : callbacks_) {
90+
callback(this);
91+
}
92+
93+
cv_completed_.notify_all();
94+
}
95+
96+
// ResetState is called on a completed future,
97+
// does not reset callbacks to keep task graph structure
98+
void AsyncTaskFuture::ResetState() {
99+
std::unique_lock<std::mutex> lock(mutex_);
100+
if (parent_counter_) {
101+
parent_counter_->Reset();
102+
}
103+
completed_ = false;
104+
failed_ = false;
105+
err_msg_ = "";
106+
}
107+
108+
// Nothing to release explicitly; members clean up after themselves.
AsyncTaskFuture::~AsyncTaskFuture() = default;
109+
110+
} // namespace caffe2

caffe2/core/net_async_task_future.h

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#ifndef CAFFE2_NET_ASYNC_TASK_FUTURE_H
2+
#define CAFFE2_NET_ASYNC_TASK_FUTURE_H
3+
4+
#include <atomic>
5+
#include <condition_variable>
6+
#include <functional>
7+
#include <memory>
8+
#include <mutex>
9+
#include <string>
10+
#include <vector>
11+
12+
namespace caffe2 {
13+
14+
// AsyncTaskFuture holds the execution state of an AsyncTask. The state can be
// polled with IsCompleted/IsFailed or waited on with Wait. Callbacks added
// through SetCallback are invoked when the future gets completed.
class AsyncTaskFuture {
 public:
  // Creates a future that is not completed.
  AsyncTaskFuture();
  // Creates a future completed when all given futures are completed
  explicit AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures);
  ~AsyncTaskFuture();

  // Futures cannot be copied.
  AsyncTaskFuture(const AsyncTaskFuture&) = delete;
  AsyncTaskFuture& operator=(const AsyncTaskFuture&) = delete;

  // State queries.
  bool IsCompleted() const;
  bool IsFailed() const;
  std::string ErrorMessage() const;

  // Blocks until the future is completed.
  void Wait() const;

  // Adds a callback that receives this future upon completion.
  void SetCallback(std::function<void(const AsyncTaskFuture*)> callback);

  // Completes the future; a non-null err_msg marks it as failed.
  void SetCompleted(const char* err_msg = nullptr);

  // Clears the completion state.
  void ResetState();

 private:
  // Bookkeeping for a future that depends on several parent futures: how many
  // parents are still pending and the error text accumulated from them.
  struct ParentCounter {
    explicit ParentCounter(int init_parent_count)
        : init_parent_count_(init_parent_count),
          parent_count(init_parent_count),
          parent_failed(false) {}

    // Restores the initial pending count and clears the failure state.
    void Reset() {
      std::lock_guard<std::mutex> guard(err_mutex);
      parent_count = init_parent_count_;
      parent_failed = false;
      err_msg.clear();
    }

    const int init_parent_count_; // pending count restored by Reset()
    std::atomic<int> parent_count; // parents that have not completed yet
    std::mutex err_mutex; // guards updates of err_msg
    std::atomic<bool> parent_failed;
    std::string err_msg;
  };

  mutable std::mutex mutex_;
  mutable std::condition_variable cv_completed_;
  std::atomic<bool> completed_;
  std::atomic<bool> failed_;
  std::string err_msg_;
  std::vector<std::function<void(const AsyncTaskFuture*)>> callbacks_;

  // Null unless the future was created over more than one parent future.
  std::unique_ptr<ParentCounter> parent_counter_;
};
73+
74+
} // namespace caffe2
75+
76+
#endif // CAFFE2_NET_ASYNC_TASK_FUTURE_H

0 commit comments

Comments
 (0)