Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions verl/single_controller/ray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,23 @@ def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[Placement


@ray.remote
def get_master_addr_port() -> tuple[str, str]:
def get_master_addr_port(master_port_range: Optional[list[int]] = None) -> tuple[str, str]:
addr = ray.util.get_node_ip_address().strip("[]")
with socket.socket() as sock:
sock.bind(("", 0))
port = sock.getsockname()[1]

if master_port_range is None:
with socket.socket() as s:
s.bind(("", 0))
port = s.getsockname()[1]
else:
port = master_port_range[0]
while port < master_port_range[1]:
try:
with socket.socket() as s:
s.bind(("", port))
break
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
Comment on lines +97 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If all ports within the specified master_port_range are in use, the loop in the current implementation exits without ever breaking, and `return addr, str(port)` is then executed with port equal to master_port_range[1]. This port is outside the specified range [start, end) and is not guaranteed to be free, which can lead to unexpected behavior or failures. The function should instead raise an exception if no free port can be found within the given range.

Suggested change
else:
port = master_port_range[0]
while port < master_port_range[1]:
try:
with socket.socket() as s:
s.bind(("", port))
break
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
else:
port = master_port_range[0]
while port < master_port_range[1]:
try:
with socket.socket() as s:
s.bind(('', port))
break
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
else:
raise RuntimeError(f"Could not find a free port in range {master_port_range}")

return addr, str(port)


Expand Down Expand Up @@ -495,14 +507,15 @@ def _init_with_detached_workers(self, worker_names, worker_handles):
self._workers = workers
self._world_size = len(workers)

def _get_master_addr_port(self, pg, bundle_index=0):
def _get_master_addr_port(self, pg, bundle_index=0, master_port_range=None):
"""Get master addr and port for this worker group"""
if self._master_addr is None and self._master_port is None:
self._master_addr, self._master_port = ray.get(
get_master_addr_port.options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_bundle_index=bundle_index
),
master_port_range=master_port_range,
).remote()
)
elif self._master_addr is not None and self._master_port is not None:
Expand All @@ -513,7 +526,9 @@ def _get_master_addr_port(self, pg, bundle_index=0):
"or neither should be provided to use Ray's default assignment."
)

def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached, worker_env=None):
def _init_with_resource_pool(
self, resource_pool, ray_cls_with_init, bin_pack, detached, worker_env=None, master_port_range=None
):
"""Initialize the worker group by creating new workers from a resource pool.

Args:
Expand All @@ -523,7 +538,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d
detached: Whether workers should be detached
"""
self.resource_pool = resource_pool

self.master_port_range = master_port_range
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

While master_port_range is correctly stored as an instance attribute here, it is not used later in this method. The call to self._get_master_addr_port at line 555 needs to be updated to pass self.master_port_range to make this feature functional. This is one of several places where the new parameter needs to be plumbed through.

strategy = "PACK"
if bin_pack:
strategy = "STRICT_PACK"
Expand Down
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ trainer:
# mode: "auto", "enable", or "disable"
use_legacy_worker_impl: auto

# Port range within which Ray searches for a free master port
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this; we can't set it up in the global config.

master_port_range: null

global_profiler:
_target_: verl.utils.profiler.ProfilerConfig
tool: null # choose between nsys, npu, torch, torch_memory
Expand Down
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ trainer:
# mode: "auto", "enable", or "disable"
use_legacy_worker_impl: auto

# Port range within which Ray searches for a free master port
master_port_range: null

# profiler configs
global_profiler:

Expand Down
1 change: 1 addition & 0 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def init_workers(self):
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
)
wg_kwargs["device_name"] = self.device_name
wg_kwargs["master_port_range"] = self.config.trainer.get("master_port_range", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The master_port_range is being added to wg_kwargs, but the RayWorkerGroup constructor does not handle this argument. It will be ignored, and the port range configuration will have no effect. You need to update RayWorkerGroup.__init__ to accept master_port_range from **kwargs and then pass it down to _init_with_resource_pool.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to pass the config to the trainer; just provide the API so that it is configurable based on the actual use case.


for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
Expand Down