4 changes: 3 additions & 1 deletion verl/workers/engine_workers.py
@@ -615,7 +615,9 @@ async def update_weights(self):
# main memory can trade sync time to avoid OOM
self.rollout.sleep_level = 1

-        do_lora_base_sync = not self.base_sync_done or self.rollout.sleep_level != 1
+        do_lora_base_sync = (not self.base_sync_done) or (
+            self.rollout.sleep_level != 1 and self.config.rollout.free_cache_engine
+        )

if do_lora_base_sync:
per_tensor_base_params, _ = self.actor.engine.get_per_tensor_param(
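For clarity, here is a minimal standalone sketch of the sync decision this hunk introduces. It is not verl code: the RolloutState dataclass and needs_lora_base_sync helper are illustrative names, but the three flags mirror base_sync_done, sleep_level, and rollout.free_cache_engine from the diff above.

    from dataclasses import dataclass

    # Illustrative sketch only: mirrors the do_lora_base_sync condition above.
    @dataclass
    class RolloutState:
        base_sync_done: bool     # base weights already pushed to the rollout engine once
        sleep_level: int         # 2 => rollout engine drops weights on each sleep cycle
        free_cache_engine: bool  # False => the engine never sleeps, so weights persist

    def needs_lora_base_sync(s: RolloutState) -> bool:
        # First sync is always required; afterwards, resync only when the engine
        # actually frees its weights (sleep_level != 1 and it really goes to sleep).
        return (not s.base_sync_done) or (s.sleep_level != 1 and s.free_cache_engine)

    # With free_cache_engine=False the base weights are synced exactly once,
    # even at sleep_level=2, because the rollout engine never frees them.
    assert needs_lora_base_sync(RolloutState(False, 2, False))
    assert not needs_lora_base_sync(RolloutState(True, 2, False))
    assert needs_lora_base_sync(RolloutState(True, 2, True))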
13 changes: 11 additions & 2 deletions verl/workers/fsdp_workers.py
@@ -705,7 +705,12 @@ async def rollout_mode(self):
# When sleep_level=2, base model weights are destroyed during each sleep cycle.
# separately collect and update LoRA weights and base model weights through their respective interfaces.
# Here: params contains LoRA weights, base_model_params contains base model weights.
-        if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
+        # Only needed if the rollout engine actually sleeps/frees weights (free_cache_engine=True).
+        if (
+            peft_config is not None
+            and getattr(self.rollout, "sleep_level", None) == 2
+            and self.config.rollout.free_cache_engine
+        ):
base_model_params = collect_lora_params(
module=self.actor_module_fsdp,
layered_summon=self.layered_summon,
@@ -736,7 +741,11 @@ async def rollout_mode(self):
await self.rollout.resume(tags=["weights"])
log_gpu_memory_usage("After resume weights", logger=logger)

if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
if (
peft_config is not None
and getattr(self.rollout, "sleep_level", None) == 2
and self.config.rollout.free_cache_engine
):
per_tensor_base_params = (
(name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
for name, param in base_model_params.items()
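For reference, the generator in the second hunk gathers sharded FSDP parameters into full tensors before handing them to the rollout engine. A hedged sketch of that pattern follows; the helper name iter_full_base_params is hypothetical, while DTensor.full_tensor() and the isinstance check are the calls visible in the diff. The import path assumes a recent PyTorch (older releases expose DTensor under torch.distributed._tensor).

    import torch
    from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

    def iter_full_base_params(base_model_params: dict[str, torch.Tensor], device):
        """Lazily yield (name, tensor) pairs, materializing sharded DTensors on `device`."""
        for name, param in base_model_params.items():
            if isinstance(param, DTensor):
                # Move the local shard, then all-gather into a regular full tensor.
                yield name, param.to(device, non_blocking=True).full_tensor()
            else:
                yield name, param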
4 changes: 3 additions & 1 deletion verl/workers/megatron_workers.py
@@ -694,7 +694,9 @@ async def rollout_mode(self):
# main memory can trade sync time to avoid OOM
self.rollout.sleep_level = 1

-        do_lora_base_sync = not self.base_sync_done or self.rollout.sleep_level != 1
+        do_lora_base_sync = (not self.base_sync_done) or (
+            self.rollout.sleep_level != 1 and self.config.rollout.free_cache_engine
+        )

if self.bridge is not None:
if self.vanilla_bridge: