Skip to content

Commit

Permalink
[STF] Support generation of multiple CUDA graphs from separate threads (
Browse files Browse the repository at this point in the history
#3943)

* Introduce a mutex to protect the underlying CUDA graph of a graph context so that we can generate tasks concurrently

* Protect CUDA graphs against concurrent accesses

* do test results

* remove dead code

* Add a test with graph capture and threads

* Add and use with_locked_graph, also use weak_ptr

* Big code simplification! No need for shared_ptr/weak_ptr: the mutex will outlive its users

* use std::reference_wrapper to have graph_task be moved assignable

* replace complicated API based on lambda function with a simple lock guard

* restore a comment removed by mistake

* capture with the lock taken because we need to ensure the captured stream is not used concurrently

* Fix build

* comment why we put a reference_wrapper

* Add a sanity check to ensure we have finished all tasks

* Add a test which uses multiple threads to generate CUDA graphs

* clang-format

* Rework how we pass frozen logical data, not logical data to threads

* atomic variables are not automatically initialized, so we do set them

* Add missing mutex headers

* improve readability

* Save WIP: a lock is currently needed around the creation of the ctx

* Improve thread safety in dot, and reduce the visibility of methods which need not be public

* reserved::dot::finish needs to be public

* Add missing header

* There is no need to use a mutex in the ctor of a singleton

* Simplify example again

---------

Co-authored-by: Andrei Alexandrescu <[email protected]>
  • Loading branch information
caugonnet and andralex authored Feb 28, 2025
1 parent 82befb0 commit 83ba38c
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ protected:
generate_event_symbols = dot->is_tracing_prereqs();

// Record it in the list of all traced contexts
reserved::dot::instance().per_ctx.push_back(dot);
reserved::dot::instance().track_ctx(dot);
}

virtual ~impl()
Expand Down
283 changes: 148 additions & 135 deletions cudax/include/cuda/experimental/__stf/internal/dot.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -291,15 +291,6 @@ public:
return ctx_symbol;
}

private:
mutable ::std::string ctx_symbol;

mutable ::std::mutex mtx;

::std::vector<int> vertices;

::std::unordered_set<int> discarded_tasks;

public:
// Keep track of existing edges, to make the output possibly look better
IntPairSet existing_edges;
Expand Down Expand Up @@ -338,6 +329,15 @@ public: // XXX protected, friend : dot
// strings of the previous epochs
mutable ::std::vector<::std::ostringstream> prev_oss;
::std::unordered_map<int /* id */, per_task_info> metadata;

private:
mutable ::std::string ctx_symbol;

mutable ::std::mutex mtx;

::std::vector<int> vertices;

::std::unordered_set<int> discarded_tasks;
};

class dot : public reserved::meyers_singleton<dot>
Expand Down Expand Up @@ -487,6 +487,8 @@ public:
protected:
dot()
{
::std::lock_guard<::std::mutex> lock(mtx);

const char* filename = getenv("CUDASTF_DOT_FILE");
if (!filename)
{
Expand All @@ -509,6 +511,137 @@ protected:
}

public:
// True while DOT tracing is active: a target filename was read from
// CUDASTF_DOT_FILE and finish() has not yet cleared it.
bool is_tracing() const
{
  return dot_filename.size() != 0;
}

bool is_tracing_prereqs()
{
return tracing_prereqs;
}

// True when task timing was requested; timing data is later used to color
// nodes relative to the average task duration.
bool is_timing() const
{
  const bool timing_requested = enable_timing;
  return timing_requested;
}

// Register one per-context DOT recorder with the singleton so that finish()
// can later emit it. Guarded by mtx because contexts may be created
// concurrently from several threads.
void track_ctx(::std::shared_ptr<per_ctx_dot> pc)
{
  const ::std::lock_guard<::std::mutex> guard(mtx);
  per_ctx.emplace_back(mv(pc));
}

// Finalize DOT generation: flush every tracked context, then write the graph
// file and (optionally) a statistics file.
// This should not need to be called explicitly, unless we are doing some automatic tests for example
void finish()
{
  // NOTE(review): single_threaded_section is defined elsewhere — presumably it
  // asserts/enforces that finish() is not raced by other users of mtx; confirm
  // its exact semantics at its definition.
  single_threaded_section guard(mtx);

  // An empty filename means tracing was never enabled, or finish() already ran
  // (the filename is cleared at the end), which makes this call idempotent.
  if (dot_filename.empty())
  {
    return;
  }

  // Let each tracked context flush its own per-epoch buffers first.
  for (const auto& pc : per_ctx)
  {
    pc->finish();
  }

  collapse_sections();

  // Now we have executed all tasks, so we can compute the average execution
  // times, and update the colors appropriately if needed.
  update_colors_with_timing();

  ::std::ofstream outFile(dot_filename);
  if (outFile.is_open())
  {
    outFile << "digraph {\n";
    size_t ctx_cnt = 0;
    // Only wrap contexts in clusters when there is more than one of them.
    bool display_clusters = (per_ctx.size() > 1);
    /*
     * For every context, we write the description of the DAG per
     * epoch. Then we write the edges after removing redundant ones.
     */
    for (const auto& pc : per_ctx)
    {
      // If the context has a parent, it will be printed by this parent itself
      if (!pc->parent)
      {
        print_one_context(outFile, ctx_cnt, display_clusters, pc);
      }
    }

    // Transitive-reduction pass, unless the user asked to keep every edge.
    if (!getenv("CUDASTF_DOT_KEEP_REDUNDANT"))
    {
      remove_redundant_edges(existing_edges);
    }

    compute_critical_path(outFile);

    /* Edges do not have to belong to the cluster (Vertices do) */
    for (const auto& [from, to] : existing_edges)
    {
      outFile << "\"NODE_" << from << "\" -> \"NODE_" << to << "\"\n";
    }

    // Update node properties such as labels and colors now that we have all information
    vertex_count = 0;
    for (const auto& pc : per_ctx)
    {
      for (const auto& p : pc->metadata)
      {
        outFile << "\"NODE_" << p.first << "\" [style=\"filled\" fillcolor=\"" << p.second.color << "\" label=\""
                << p.second.label << "\"]\n";
        vertex_count++;
      }
    }

    edge_count = existing_edges.size();

    // Emit the counts as DOT comments so they survive in the output file.
    outFile << "// Edge count : " << edge_count << "\n";
    outFile << "// Vertex count : " << vertex_count << "\n";

    outFile << "}\n";

    outFile.close();
  }
  else
  {
    ::std::cerr << "Unable to open file: " << dot_filename << ::std::endl;
  }

  // Optionally dump aggregate statistics (edge/vertex counts, total work T1,
  // critical path) as one CSV line when CUDASTF_DOT_STATS_FILE is set.
  const char* stats_filename_str = getenv("CUDASTF_DOT_STATS_FILE");
  if (stats_filename_str)
  {
    ::std::string stats_filename = stats_filename_str;
    ::std::ofstream statsFile(stats_filename);
    if (statsFile.is_open())
    {
      statsFile << "#nedges,nvertices,total_work,critical_path\n";

      // to display an optional value or NA
      auto formatOptional = [](const ::std::optional<float>& opt) -> ::std::string {
        return opt ? ::std::to_string(*opt) : "NA";
      };

      statsFile << edge_count << "," << vertex_count << "," << formatOptional(total_work) << ","
                << formatOptional(critical_path) << "\n";

      statsFile.close();
    }
    else
    {
      ::std::cerr << "Unable to open file: " << stats_filename << ::std::endl;
    }
  }

  // Clear the filename so a second call is a no-op and is_tracing() now
  // reports false.
  dot_filename.clear();
}

private:
void
print_one_context(::std::ofstream& outFile, size_t& ctx_cnt, bool display_clusters, ::std::shared_ptr<per_ctx_dot> pc)
{
Expand Down Expand Up @@ -782,130 +915,6 @@ public:
}
}

// This should not need to be called explicitly, unless we are doing some automatic tests for example
void finish()
{
single_threaded_section guard(mtx);

if (dot_filename.empty())
{
return;
}

for (const auto& pc : per_ctx)
{
pc->finish();
}

collapse_sections();

// Now we have executed all tasks, so we can compute the average execution
// times, and update the colors appropriately if needed.
update_colors_with_timing();

::std::ofstream outFile(dot_filename);
if (outFile.is_open())
{
outFile << "digraph {\n";
size_t ctx_cnt = 0;
bool display_clusters = (per_ctx.size() > 1);
/*
* For every context, we write the description of the DAG per
* epoch. Then we write the edges after removing redundant ones.
*/
for (const auto& pc : per_ctx)
{
// If the context has a parent, it will be printed by this parent itself
if (!pc->parent)
{
print_one_context(outFile, ctx_cnt, display_clusters, pc);
}
}

if (!getenv("CUDASTF_DOT_KEEP_REDUNDANT"))
{
remove_redundant_edges(existing_edges);
}

compute_critical_path(outFile);

/* Edges do not have to belong to the cluster (Vertices do) */
for (const auto& [from, to] : existing_edges)
{
outFile << "\"NODE_" << from << "\" -> \"NODE_" << to << "\"\n";
}

// Update node properties such as labels and colors now that we have all information
vertex_count = 0;
for (const auto& pc : per_ctx)
{
for (const auto& p : pc->metadata)
{
outFile << "\"NODE_" << p.first << "\" [style=\"filled\" fillcolor=\"" << p.second.color << "\" label=\""
<< p.second.label << "\"]\n";
vertex_count++;
}
}

edge_count = existing_edges.size();

outFile << "// Edge count : " << edge_count << "\n";
outFile << "// Vertex count : " << vertex_count << "\n";

outFile << "}\n";

outFile.close();
}
else
{
::std::cerr << "Unable to open file: " << dot_filename << ::std::endl;
}

const char* stats_filename_str = getenv("CUDASTF_DOT_STATS_FILE");
if (stats_filename_str)
{
::std::string stats_filename = stats_filename_str;
::std::ofstream statsFile(stats_filename);
if (statsFile.is_open())
{
statsFile << "#nedges,nvertices,total_work,critical_path\n";

// to display an optional value or NA
auto formatOptional = [](const ::std::optional<float>& opt) -> ::std::string {
return opt ? ::std::to_string(*opt) : "NA";
};

statsFile << edge_count << "," << vertex_count << "," << formatOptional(total_work) << ","
<< formatOptional(critical_path) << "\n";

statsFile.close();
}
else
{
::std::cerr << "Unable to open file: " << stats_filename << ::std::endl;
}
}

dot_filename.clear();
}

bool is_tracing() const
{
return !dot_filename.empty();
}

bool is_tracing_prereqs()
{
return tracing_prereqs;
}

bool is_timing() const
{
return enable_timing;
}

::std::vector<::std::shared_ptr<per_ctx_dot>> per_ctx;

private:
// Function to get a color based on task duration relative to the average
::std::string get_color_for_duration(double duration, double avg_duration)
Expand Down Expand Up @@ -1141,8 +1150,6 @@ private:

::std::unordered_map<int, ::std::vector<int>> predecessors;

mutable ::std::mutex mtx;

::std::string dot_filename;

// Map to get dot sections from their ID
Expand All @@ -1153,6 +1160,12 @@ private:
::std::optional<float> total_work; // T1
size_t edge_count;
size_t vertex_count;

private:
mutable ::std::mutex mtx;

// A vector that keeps track of all per context stored data
::std::vector<::std::shared_ptr<per_ctx_dot>> per_ctx;
};

inline int per_ctx_dot::get_current_section_id()
Expand Down
1 change: 1 addition & 0 deletions cudax/test/stf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ set(stf_test_sources
interface/data_from_device_async.cu
interface/move_operator.cu
local_stf/legacy_to_stf.cu
local_stf/threads_multiple_graphs.cu
places/managed.cu
places/managed_from_user.cu
places/non_current_device.cu
Expand Down
Loading

0 comments on commit 83ba38c

Please sign in to comment.