diff --git a/comms/torchcomms/TorchCommUtils.cpp b/comms/torchcomms/TorchCommUtils.cpp index bb82fbfb..545b6c74 100644 --- a/comms/torchcomms/TorchCommUtils.cpp +++ b/comms/torchcomms/TorchCommUtils.cpp @@ -2,6 +2,8 @@ #include "comms/torchcomms/TorchCommUtils.hpp" #include +#include +#include #include #include #include @@ -70,6 +72,107 @@ T env_to_value(const std::string& env_key, const T& default_value) { } } +int count_file_lines(const std::string& filepath, bool ignore_empty_lines) { + std::ifstream filestream(filepath); + if (!filestream.is_open()) { + throw std::runtime_error("Failed to open file for reading: " + filepath); + } + + int line_count = 0; + if (ignore_empty_lines) { + std::string line; + while (std::getline(filestream, line)) { + if (!line.empty()) { + ++line_count; + } + } + } else { + line_count = std::count( + std::istreambuf_iterator(filestream), + std::istreambuf_iterator(), + '\n'); + + // Clear the stream state so we can use tellg/seekg + filestream.clear(); + + // Check if the file is empty + const bool is_empty_file = + filestream.tellg() == std::ifstream::pos_type(0) && + filestream.peek() == std::ifstream::traits_type::eof(); + + // If the file does not end with a newline, we need to add one to the count + if (!is_empty_file) { + filestream.seekg(-1, std::ios::end); + char last_char; + filestream.get(last_char); + if (last_char != '\n') { + ++line_count; + } + } + } + + if (filestream.bad()) { + throw std::runtime_error("Error while reading file: " + filepath); + } + + return line_count; +} + +std::pair query_pals_ranksize() { + // Try to get rank and size directly from PALS environment variables first + auto rank = env_to_value("PALS_RANKID", -1); + // PALS_SIZE is currently not available but may be included in future versions + auto comm_size = env_to_value("PALS_SIZE", -1); + + // If rank & size are both found from PALS return them + if (rank > -1 && comm_size > 0) { + TC_LOG(INFO) << "Found rank and size from PALS environment (rank=" << rank + << ", size=" << comm_size << ")."; + return {rank, comm_size}; + } + + // PALS currently supports only the PBS workload manager. + // We calculate the size using the PBS_NODEFILE (which lists one node per + // line) to get number of nodes and PALS_LOCAL_SIZE to get the number of ranks + // per node: + // size = (number of nodes) * (ranks per node) + // which assumes all nodes have the same number of ranks. + const auto num_ranks_per_node = env_to_value("PALS_LOCAL_SIZE", -1); + const auto nodefile_path = env_to_value("PBS_NODEFILE", ""); + + if (comm_size == -1 && !nodefile_path.empty() && num_ranks_per_node > 0) { + try { + const auto num_nodes = count_file_lines(nodefile_path); + if (num_nodes > 0) { + comm_size = num_nodes * num_ranks_per_node; + } + } catch (const std::exception& e) { + TC_LOG(ERROR) << "Failed to determine size from PALS/PBS environment. " + << "Could not count lines in PBS_NODEFILE (" + << nodefile_path << "): " << e.what(); + comm_size = -1; + } + } + + // Only log warnings if we have partial information, indicating a possible + // broken PALS/PBS environment. If neither rank nor size are found, we are + // likely not in a PALS environment. + if (rank > -1 && comm_size == -1) { + TC_LOG(WARNING) + << "Found rank from PALS environment but unable to determine size. " + << "Please set PALS_SIZE or ensure PBS_NODEFILE and PALS_LOCAL_SIZE are available."; + } else if (rank == -1 && comm_size > 0) { + TC_LOG(WARNING) + << "Found size from PALS/PBS environment but unable to determine rank. " + << "Please set PALS_RANKID."; + } else if (rank > -1 && comm_size > 0) { + TC_LOG(INFO) << "Found rank and size from PALS/PBS environment (rank=" + << rank << ", size=" << comm_size << ")."; + } + + return {rank, comm_size}; +} + // Explicit instantiations for common types template bool env_to_value(const std::string&, const bool&); template int env_to_value(const std::string&, const int&); @@ -84,6 +187,7 @@ std::pair query_ranksize() { const std::string kRanksizeQueryMethodAuto = "auto"; const std::string kRanksizeQueryMethodTorchrun = "torchrun"; const std::string kRanksizeQueryMethodMPI = "mpi"; + const std::string kRanksizeQueryMethodPALS = "pals"; const std::string& kRanksizeQueryMethodDefault = kRanksizeQueryMethodAuto; // Get the ranksize query method from environment variable @@ -110,7 +214,7 @@ std::pair query_ranksize() { // Read from TORCHCOMM_RANK and TORCHCOMM_SIZE environment variables rank = env_to_value("TORCHCOMM_RANK", -1); comm_size = env_to_value("TORCHCOMM_SIZE", -1); - if (rank != -1 && comm_size != -1) { + if (rank > -1 && comm_size > 0) { break; } @@ -120,14 +224,23 @@ std::pair query_ranksize() { // See if we are in an OpenMPI environment rank = env_to_value("OMPI_COMM_WORLD_RANK", -1); comm_size = env_to_value("OMPI_COMM_WORLD_SIZE", -1); - if (rank != -1 && comm_size != -1) { + if (rank > -1 && comm_size > 0) { break; } // See if we are in an MPICH environment rank = env_to_value("PMI_RANK", -1); comm_size = env_to_value("PMI_SIZE", -1); - if (rank != -1 && comm_size != -1) { + if (rank > -1 && comm_size > 0) { + break; + } + } + + // See if we are in a PALS environment + if (ranksize_query_method == kRanksizeQueryMethodAuto || + ranksize_query_method == kRanksizeQueryMethodPALS) { + std::tie(rank, comm_size) = query_pals_ranksize(); + if (rank > -1 && comm_size > 0) { break; } } @@ -137,17 +250,17 @@ std::pair query_ranksize() { ranksize_query_method == kRanksizeQueryMethodTorchrun) { rank = env_to_value("RANK", -1); comm_size = env_to_value("WORLD_SIZE", -1); - if (rank != -1 && comm_size != -1) { + if (rank > -1 && comm_size > 0) { break; } } } while (0); - if (rank == -1 || comm_size == -1) { + if (rank < 0 || comm_size < 1) { throw std::runtime_error( "Unable to determine rank and size from environment variables. " "Please set TORCHCOMM_RANK and TORCHCOMM_SIZE, or ensure you are " - "running in a supported environment (Torchrun or MPI)."); + "running in a supported environment (Torchrun, MPI, PALS)."); } return std::make_pair(rank, comm_size); diff --git a/comms/torchcomms/TorchCommUtils.hpp b/comms/torchcomms/TorchCommUtils.hpp index f626141a..c88da332 100644 --- a/comms/torchcomms/TorchCommUtils.hpp +++ b/comms/torchcomms/TorchCommUtils.hpp @@ -13,6 +13,13 @@ bool string_to_bool(const std::string& str); template T env_to_value(const std::string& env_key, const T& default_value); +// Counts the number of lines in a file +int count_file_lines( + const std::string& filepath, + bool ignore_empty_lines = true); + +std::pair query_pals_ranksize(); + // Query rank and size based on TORCHCOMM_BOOTSTRAP_RANKSIZE_QUERY_METHOD std::pair query_ranksize(); diff --git a/comms/torchcomms/tests/integration/cpp/TorchCommTestHelpers.cpp b/comms/torchcomms/tests/integration/cpp/TorchCommTestHelpers.cpp index c2dcb1d3..7f8ec6aa 100644 --- a/comms/torchcomms/tests/integration/cpp/TorchCommTestHelpers.cpp +++ b/comms/torchcomms/tests/integration/cpp/TorchCommTestHelpers.cpp @@ -5,6 +5,7 @@ #include "comms/torchcomms/StoreManager.hpp" #include "comms/torchcomms/TorchCommLogging.hpp" +#include "comms/torchcomms/TorchCommUtils.hpp" using namespace torch::comms; @@ -111,6 +112,11 @@ std::tuple getRankAndSize() { return {std::stoi(torchrun_rank), std::stoi(torchrun_size)}; } + const auto [rank, size] = query_pals_ranksize(); + if (rank > -1 && size > 0) { + return {rank, size}; + } + throw std::runtime_error( "Could not determine rank or world size from environment variables."); }