22
33#include " comms/torchcomms/TorchCommUtils.hpp"
44#include < algorithm>
5+ #include < fstream>
6+ #include < iterator>
57#include < sstream>
68#include < stdexcept>
79#include < string>
@@ -70,6 +72,91 @@ T env_to_value(const std::string& env_key, const T& default_value) {
7072 }
7173}
7274
75+ int count_file_lines (const std::string& filepath, bool ignore_empty_lines) {
76+ std::ifstream filestream (filepath);
77+ if (!filestream.is_open ()) {
78+ throw std::runtime_error (" Failed to open file for reading: " + filepath);
79+ }
80+
81+ int line_count = 0 ;
82+ if (ignore_empty_lines) {
83+ std::string line;
84+ while (std::getline (filestream, line)) {
85+ if (!line.empty ()) {
86+ ++line_count;
87+ }
88+ }
89+ } else {
90+ line_count = std::count (
91+ std::istreambuf_iterator<char >(filestream),
92+ std::istreambuf_iterator<char >(),
93+ ' \n ' );
94+ }
95+
96+ if (filestream.bad ()) {
97+ throw std::runtime_error (" Error while reading file: " + filepath);
98+ } else if (!filestream.eof ()) {
99+ throw std::runtime_error (" File read did not reach EOF: " + filepath);
100+ }
101+
102+ return line_count;
103+ }
104+
105+ std::pair<int , int > query_pals_ranksize () {
106+ // Try to get rank and size directly from PALS environment variables first
107+ auto rank = env_to_value<int >(" PALS_RANKID" , -1 );
108+ // PALS_SIZE is currently not available but may be included in future versions
109+ auto comm_size = env_to_value<int >(" PALS_SIZE" , -1 );
110+
111+ // If rank & size are both found from PALS return them
112+ if (rank > -1 && comm_size > 0 ) {
113+ TC_LOG (INFO) << " Found rank and size from PALS environment (rank=" << rank
114+ << " , size=" << comm_size << " )." ;
115+ return {rank, comm_size};
116+ }
117+
118+ // PALS currently supports only the PBS workload manager.
119+ // We calculate the size using the PBS_NODEFILE (which lists one node per
120+ // line) to get number of nodes and PALS_LOCAL_SIZE to get the number of ranks
121+ // per node:
122+ // size = (number of nodes) * (ranks per node)
123+ // which assumes all nodes have the same number of ranks.
124+ const auto num_ranks_per_node = env_to_value<int >(" PALS_LOCAL_SIZE" , -1 );
125+ const auto nodefile_path = env_to_value<std::string>(" PBS_NODEFILE" , " " );
126+
127+ if (comm_size == -1 && !nodefile_path.empty () && num_ranks_per_node > 0 ) {
128+ try {
129+ const auto num_nodes = count_file_lines (nodefile_path);
130+ if (num_nodes > 0 ) {
131+ comm_size = num_nodes * num_ranks_per_node;
132+ }
133+ } catch (const std::exception& e) {
134+ TC_LOG (ERROR) << " Failed to determine size from PALS/PBS environment. "
135+ << " Could not count lines in PBS_NODEFILE ("
136+ << nodefile_path << " ): " << e.what ();
137+ comm_size = -1 ;
138+ }
139+ }
140+
141+ // Only log warnings if we have partial information, indicating a possible
142+ // broken PALS/PBS environment. If neither rank nor size are found, we are
143+ // likely not in a PALS environment.
144+ if (rank > -1 && comm_size == -1 ) {
145+ TC_LOG (WARNING)
146+ << " Found rank from PALS environment but unable to determine size. "
147+ << " Please set PALS_SIZE or ensure PBS_NODEFILE and PALS_LOCAL_SIZE are available." ;
148+ } else if (rank == -1 && comm_size > 0 ) {
149+ TC_LOG (WARNING)
150+ << " Found size from PALS/PBS environment but unable to determine rank. "
151+ << " Please set PALS_RANKID." ;
152+ } else {
153+ TC_LOG (INFO) << " Found rank and size from PALS/PBS environment (rank="
154+ << rank << " , size=" << comm_size << " )." ;
155+ }
156+
157+ return {rank, comm_size};
158+ }
159+
73160// Explicit instantiations for common types
74161template bool env_to_value<bool >(const std::string&, const bool &);
75162template int env_to_value<int >(const std::string&, const int &);
@@ -84,6 +171,7 @@ std::pair<int, int> query_ranksize() {
84171 const std::string kRanksizeQueryMethodAuto = " auto" ;
85172 const std::string kRanksizeQueryMethodTorchrun = " torchrun" ;
86173 const std::string kRanksizeQueryMethodMPI = " mpi" ;
174+ const std::string kRanksizeQueryMethodPALS = " pals" ;
87175 const std::string& kRanksizeQueryMethodDefault = kRanksizeQueryMethodAuto ;
88176
89177 // Get the ranksize query method from environment variable
@@ -110,7 +198,7 @@ std::pair<int, int> query_ranksize() {
110198 // Read from TORCHCOMM_RANK and TORCHCOMM_SIZE environment variables
111199 rank = env_to_value<int >(" TORCHCOMM_RANK" , -1 );
112200 comm_size = env_to_value<int >(" TORCHCOMM_SIZE" , -1 );
113- if (rank != -1 && comm_size != - 1 ) {
201+ if (rank > -1 && comm_size > 0 ) {
114202 break ;
115203 }
116204
@@ -120,14 +208,23 @@ std::pair<int, int> query_ranksize() {
120208 // See if we are in an OpenMPI environment
121209 rank = env_to_value<int >(" OMPI_COMM_WORLD_RANK" , -1 );
122210 comm_size = env_to_value<int >(" OMPI_COMM_WORLD_SIZE" , -1 );
123- if (rank != -1 && comm_size != - 1 ) {
211+ if (rank > -1 && comm_size > 0 ) {
124212 break ;
125213 }
126214
127215 // See if we are in an MPICH environment
128216 rank = env_to_value<int >(" PMI_RANK" , -1 );
129217 comm_size = env_to_value<int >(" PMI_SIZE" , -1 );
130- if (rank != -1 && comm_size != -1 ) {
218+ if (rank > -1 && comm_size > 0 ) {
219+ break ;
220+ }
221+ }
222+
223+ // See if we are in a PALS environment
224+ if (ranksize_query_method == kRanksizeQueryMethodAuto ||
225+ ranksize_query_method == kRanksizeQueryMethodPALS ) {
226+ std::tie (rank, comm_size) = query_pals_ranksize ();
227+ if (rank > -1 && comm_size > 0 ) {
131228 break ;
132229 }
133230 }
@@ -137,17 +234,17 @@ std::pair<int, int> query_ranksize() {
137234 ranksize_query_method == kRanksizeQueryMethodTorchrun ) {
138235 rank = env_to_value<int >(" RANK" , -1 );
139236 comm_size = env_to_value<int >(" WORLD_SIZE" , -1 );
140- if (rank != -1 && comm_size != - 1 ) {
237+ if (rank > -1 && comm_size > 0 ) {
141238 break ;
142239 }
143240 }
144241 } while (0 );
145242
146- if (rank == - 1 || comm_size == - 1 ) {
243+ if (rank < 0 || comm_size < 1 ) {
147244 throw std::runtime_error (
148245 " Unable to determine rank and size from environment variables. "
149246 " Please set TORCHCOMM_RANK and TORCHCOMM_SIZE, or ensure you are "
150- " running in a supported environment (Torchrun or MPI)." );
247+ " running in a supported environment (Torchrun, MPI, PALS )." );
151248 }
152249
153250 return std::make_pair (rank, comm_size);
0 commit comments