Skip to content

Commit 41aefa3

Browse files
committed
feat(bootstrap): add PALS/PBS environment detection support
Add support for automatically detecting rank and world size when running in PALS (Parallel Application Launching System) environments, commonly used with PBS workload managers. Key changes: * Implement `query_pals_ranksize` to resolve rank/size from PALS/PBS env vars. * Add logic to calculate world size using `PBS_NODEFILE` and `PALS_LOCAL_SIZE` when `PALS_SIZE` is unavailable. * Add `count_file_lines` utility for parsing nodefiles with empty line handling. * Integrate PALS detection into `query_ranksize` bootstrap logic. * Update test helpers to support PALS environments. * Refactor rank/size validation to ensure non-negative rank and positive size.
1 parent 2f7f66b commit 41aefa3

File tree

3 files changed

+116
-6
lines changed

3 files changed

+116
-6
lines changed

comms/torchcomms/TorchCommUtils.cpp

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include "comms/torchcomms/TorchCommUtils.hpp"
44
#include <algorithm>
5+
#include <fstream>
6+
#include <iterator>
57
#include <sstream>
68
#include <stdexcept>
79
#include <string>
@@ -70,6 +72,91 @@ T env_to_value(const std::string& env_key, const T& default_value) {
7072
}
7173
}
7274

75+
int count_file_lines(const std::string& filepath, bool ignore_empty_lines) {
76+
std::ifstream filestream(filepath);
77+
if (!filestream.is_open()) {
78+
throw std::runtime_error("Failed to open file for reading: " + filepath);
79+
}
80+
81+
int line_count = 0;
82+
if (ignore_empty_lines) {
83+
std::string line;
84+
while (std::getline(filestream, line)) {
85+
if (!line.empty()) {
86+
++line_count;
87+
}
88+
}
89+
} else {
90+
line_count = std::count(
91+
std::istreambuf_iterator<char>(filestream),
92+
std::istreambuf_iterator<char>(),
93+
'\n');
94+
}
95+
96+
if (filestream.bad()) {
97+
throw std::runtime_error("Error while reading file: " + filepath);
98+
} else if (!filestream.eof()) {
99+
throw std::runtime_error("File read did not reach EOF: " + filepath);
100+
}
101+
102+
return line_count;
103+
}
104+
105+
std::pair<int, int> query_pals_ranksize() {
106+
// Try to get rank and size directly from PALS environment variables first
107+
auto rank = env_to_value<int>("PALS_RANKID", -1);
108+
// PALS_SIZE is currently not available but may be included in future versions
109+
auto comm_size = env_to_value<int>("PALS_SIZE", -1);
110+
111+
// If rank & size are both found from PALS return them
112+
if (rank > -1 && comm_size > 0) {
113+
TC_LOG(INFO) << "Found rank and size from PALS environment (rank=" << rank
114+
<< ", size=" << comm_size << ").";
115+
return {rank, comm_size};
116+
}
117+
118+
// PALS currently supports only the PBS workload manager.
119+
// We calculate the size using the PBS_NODEFILE (which lists one node per
120+
// line) to get number of nodes and PALS_LOCAL_SIZE to get the number of ranks
121+
// per node:
122+
// size = (number of nodes) * (ranks per node)
123+
// which assumes all nodes have the same number of ranks.
124+
const auto num_ranks_per_node = env_to_value<int>("PALS_LOCAL_SIZE", -1);
125+
const auto nodefile_path = env_to_value<std::string>("PBS_NODEFILE", "");
126+
127+
if (comm_size == -1 && !nodefile_path.empty() && num_ranks_per_node > 0) {
128+
try {
129+
const auto num_nodes = count_file_lines(nodefile_path);
130+
if (num_nodes > 0) {
131+
comm_size = num_nodes * num_ranks_per_node;
132+
}
133+
} catch (const std::exception& e) {
134+
TC_LOG(ERROR) << "Failed to determine size from PALS/PBS environment. "
135+
<< "Could not count lines in PBS_NODEFILE ("
136+
<< nodefile_path << "): " << e.what();
137+
comm_size = -1;
138+
}
139+
}
140+
141+
// Only log warnings if we have partial information, indicating a possible
142+
// broken PALS/PBS environment. If neither rank nor size are found, we are
143+
// likely not in a PALS environment.
144+
if (rank > -1 && comm_size == -1) {
145+
TC_LOG(WARNING)
146+
<< "Found rank from PALS environment but unable to determine size. "
147+
<< "Please set PALS_SIZE or ensure PBS_NODEFILE and PALS_LOCAL_SIZE are available.";
148+
} else if (rank == -1 && comm_size > 0) {
149+
TC_LOG(WARNING)
150+
<< "Found size from PALS/PBS environment but unable to determine rank. "
151+
<< "Please set PALS_RANKID.";
152+
} else {
153+
TC_LOG(INFO) << "Found rank and size from PALS/PBS environment (rank="
154+
<< rank << ", size=" << comm_size << ").";
155+
}
156+
157+
return {rank, comm_size};
158+
}
159+
73160
// Explicit instantiations for common types
74161
template bool env_to_value<bool>(const std::string&, const bool&);
75162
template int env_to_value<int>(const std::string&, const int&);
@@ -84,6 +171,7 @@ std::pair<int, int> query_ranksize() {
84171
const std::string kRanksizeQueryMethodAuto = "auto";
85172
const std::string kRanksizeQueryMethodTorchrun = "torchrun";
86173
const std::string kRanksizeQueryMethodMPI = "mpi";
174+
const std::string kRanksizeQueryMethodPALS = "pals";
87175
const std::string& kRanksizeQueryMethodDefault = kRanksizeQueryMethodAuto;
88176

89177
// Get the ranksize query method from environment variable
@@ -110,7 +198,7 @@ std::pair<int, int> query_ranksize() {
110198
// Read from TORCHCOMM_RANK and TORCHCOMM_SIZE environment variables
111199
rank = env_to_value<int>("TORCHCOMM_RANK", -1);
112200
comm_size = env_to_value<int>("TORCHCOMM_SIZE", -1);
113-
if (rank != -1 && comm_size != -1) {
201+
if (rank > -1 && comm_size > 0) {
114202
break;
115203
}
116204

@@ -120,14 +208,23 @@ std::pair<int, int> query_ranksize() {
120208
// See if we are in an OpenMPI environment
121209
rank = env_to_value<int>("OMPI_COMM_WORLD_RANK", -1);
122210
comm_size = env_to_value<int>("OMPI_COMM_WORLD_SIZE", -1);
123-
if (rank != -1 && comm_size != -1) {
211+
if (rank > -1 && comm_size > 0) {
124212
break;
125213
}
126214

127215
// See if we are in an MPICH environment
128216
rank = env_to_value<int>("PMI_RANK", -1);
129217
comm_size = env_to_value<int>("PMI_SIZE", -1);
130-
if (rank != -1 && comm_size != -1) {
218+
if (rank > -1 && comm_size > 0) {
219+
break;
220+
}
221+
}
222+
223+
// See if we are in a PALS environment
224+
if (ranksize_query_method == kRanksizeQueryMethodAuto ||
225+
ranksize_query_method == kRanksizeQueryMethodPALS) {
226+
std::tie(rank, comm_size) = query_pals_ranksize();
227+
if (rank > -1 && comm_size > 0) {
131228
break;
132229
}
133230
}
@@ -137,17 +234,17 @@ std::pair<int, int> query_ranksize() {
137234
ranksize_query_method == kRanksizeQueryMethodTorchrun) {
138235
rank = env_to_value<int>("RANK", -1);
139236
comm_size = env_to_value<int>("WORLD_SIZE", -1);
140-
if (rank != -1 && comm_size != -1) {
237+
if (rank > -1 && comm_size > 0) {
141238
break;
142239
}
143240
}
144241
} while (0);
145242

146-
if (rank == -1 || comm_size == -1) {
243+
if (rank < 0 || comm_size < 1) {
147244
throw std::runtime_error(
148245
"Unable to determine rank and size from environment variables. "
149246
"Please set TORCHCOMM_RANK and TORCHCOMM_SIZE, or ensure you are "
150-
"running in a supported environment (Torchrun or MPI).");
247+
"running in a supported environment (Torchrun, MPI, PALS).");
151248
}
152249

153250
return std::make_pair(rank, comm_size);

comms/torchcomms/TorchCommUtils.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ bool string_to_bool(const std::string& str);
1313
template <typename T>
1414
T env_to_value(const std::string& env_key, const T& default_value);
1515

16+
// Counts the number of lines in a file
17+
int count_file_lines(
18+
const std::string& filepath,
19+
bool ignore_empty_lines = true);
20+
21+
std::pair<int, int> query_pals_ranksize();
22+
1623
// Query rank and size based on TORCHCOMM_BOOTSTRAP_RANKSIZE_QUERY_METHOD
1724
std::pair<int, int> query_ranksize();
1825

comms/torchcomms/tests/integration/cpp/TorchCommTestHelpers.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "comms/torchcomms/StoreManager.hpp"
77
#include "comms/torchcomms/TorchCommLogging.hpp"
8+
#include "comms/torchcomms/TorchCommUtils.hpp"
89

910
using namespace torch::comms;
1011

@@ -111,6 +112,11 @@ std::tuple<int, int> getRankAndSize() {
111112
return {std::stoi(torchrun_rank), std::stoi(torchrun_size)};
112113
}
113114

115+
const auto [rank, size] = query_pals_ranksize();
116+
if (rank > -1 && size > 0) {
117+
return {rank, size};
118+
}
119+
114120
throw std::runtime_error(
115121
"Could not determine rank or world size from environment variables.");
116122
}

0 commit comments

Comments
 (0)