From 65e999088259b07c3a8cbfad4b9785d0e9478e3c Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 20 Oct 2025 13:57:42 -0700 Subject: [PATCH 1/4] Fix lint issues: remove unused import --- src/forge/actors/generator.py | 108 ++++++++++++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 6 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 6c2efd5e6..273a849eb 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -286,6 +286,45 @@ def split_keys(keys): return state_dict + async def _fetch_weights_dcp( + self, + version: int, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from DCP checkpoint and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("generator_perf/_fetch_weights_dcp") + t.start() + + # Get the DCP handle from torchstore + dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) + dcp_handle = await ts.get(dcp_whole_state_dict_key) + hf_param_names = dcp_handle.param_names + checkpoint_id = dcp_handle.checkpoint_id + + n_fetchers = self.weight_fetchers.size() + + def split_keys(keys): + return [keys[i::n_fetchers] for i in range(n_fetchers)] + + futures = [] + for i, names in enumerate(split_keys(hf_param_names)): + fut = self.weight_fetchers.slice(procs=i).fetch_dcp.call_one( + checkpoint_id=checkpoint_id, param_names=names + ) + futures.append(fut) + + sub_state_dicts = [await fut for fut in futures] + + state_dict = {} + for sd in sub_state_dicts: + state_dict.update(sd) + + # Clean up the DCP handle after fetching to shared memory + dcp_handle.drop() + + t.stop() + + return state_dict + @endpoint async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: """Generate a response for the given prompt @@ -439,12 +478,25 @@ async def update_weights(self, version: int) -> None: >>> await trainer.push_weights() >>> generator.update_weights(version) """ - # TODO: enable shared memory prefetch for DCP-based weight sync - if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync: - logger.info(f"[Generator] Fetching weights for v{version} to shared memory") - fetch_fut = asyncio.create_task(self._fetch_weights(version)) - else: - fetch_fut = None + # Prefetch weights to shared memory if enabled + fetch_fut = None + if self.prefetch_weights_to_shm: + # Check if DCP is being used for this version + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) + use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys + + if use_dcp_for_weight_sync: + logger.info( + f"[Generator] Fetching weights for v{version} from DCP to shared memory" + ) + fetch_fut = asyncio.create_task(self._fetch_weights_dcp(version)) + else: + logger.info( + f"[Generator] Fetching weights for v{version} to shared memory" + ) + fetch_fut = asyncio.create_task(self._fetch_weights(version)) # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -747,3 +799,47 @@ async def fetch( sd[name] = handle del param # Explicitly free the tensor after copying to shared memory return sd + + @endpoint + async def fetch_dcp( + self, + *, + checkpoint_id: str, + param_names: list[str], + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from DCP checkpoint and load them into shared memory.""" + from forge.actors._torchstore_utils import DcpHandle + + sd = {} + # Create a minimal DCP handle for loading tensors + # We only need checkpoint_id and metadata, which load_tensor_from_dcp will fetch + for name in param_names: + # Load tensor from DCP using the checkpoint_id + # Note: load_tensor_from_dcp handles reading from the DCP checkpoint directory + from torch.distributed.checkpoint.metadata import load_metadata + + # Load metadata if not already loaded (first time) + if ( + not hasattr(self, "_dcp_metadata") + or getattr(self, "_dcp_checkpoint_id", None) != checkpoint_id + ): + self._dcp_metadata = load_metadata(checkpoint_id) + self._dcp_checkpoint_id = checkpoint_id + + # Create a DCP handle with the loaded metadata + dcp_handle = DcpHandle( + checkpoint_id=checkpoint_id, + metadata=self._dcp_metadata, + param_names=param_names, + ) + + # Load the tensor from DCP + param = load_tensor_from_dcp(dcp_handle, name) + + # Use context manager to ensure cleanup after getting handle + with SharedTensor(tensor=param) as shared_tensor: + handle = shared_tensor.get_handle() + sd[name] = handle + del param # Explicitly free the tensor after copying to shared memory + + return sd From 28220c72e5ab269f01678218cf1d7d16f506a63b Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 20 Oct 2025 14:02:04 -0700 Subject: [PATCH 2/4] Apply ufmt formatting --- src/forge/actors/generator.py | 134 ++++++++++++++++++---------------- 1 file changed, 72 insertions(+), 62 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 273a849eb..e12eda57c 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -253,16 +253,27 @@ async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]): for handle in state_dict.values(): handle.drop() - async def _fetch_weights( + async def _fetch_weights_parallel( self, - version: int, + param_names: list[str], + *, + version: int | None = None, + checkpoint_id: str | None = None, + tracer_name: str, ) -> dict[str, SharedTensorHandle]: - """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" - t = Tracer("generator_perf/_fetch_weights") + """Fetch weights in parallel using multiple fetcher processes. + + Args: + param_names: List of parameter names to fetch + version: Version number for individual tensor loading (mutually exclusive with checkpoint_id) + checkpoint_id: DCP checkpoint ID for DCP loading (mutually exclusive with version) + tracer_name: Name for the performance tracer + + Returns: + Dictionary mapping parameter names to SharedTensorHandles + """ + t = Tracer(tracer_name) t.start() - prefix = get_param_prefix(version) - matching_keys = await ts.keys(prefix) - hf_param_names = [extract_param_name(key) for key in matching_keys] n_fetchers = self.weight_fetchers.size() @@ -270,10 +281,15 @@ def split_keys(keys): return [keys[i::n_fetchers] for i in range(n_fetchers)] futures = [] - for i, names in enumerate(split_keys(hf_param_names)): - fut = self.weight_fetchers.slice(procs=i).fetch.call_one( - version=version, param_names=names - ) + for i, names in enumerate(split_keys(param_names)): + if checkpoint_id is not None: + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( + checkpoint_id=checkpoint_id, param_names=names + ) + else: + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( + version=version, param_names=names + ) futures.append(fut) sub_state_dicts = [await fut for fut in futures] @@ -286,43 +302,41 @@ def split_keys(keys): return state_dict + async def _fetch_weights( + self, + version: int, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + + return await self._fetch_weights_parallel( + param_names=hf_param_names, + version=version, + tracer_name="generator_perf/_fetch_weights", + ) + async def _fetch_weights_dcp( self, version: int, ) -> dict[str, SharedTensorHandle]: """Fetch weights from DCP checkpoint and return a dict of {name: SharedTensorHandle}.""" - t = Tracer("generator_perf/_fetch_weights_dcp") - t.start() - # Get the DCP handle from torchstore dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) dcp_handle = await ts.get(dcp_whole_state_dict_key) hf_param_names = dcp_handle.param_names checkpoint_id = dcp_handle.checkpoint_id - n_fetchers = self.weight_fetchers.size() - - def split_keys(keys): - return [keys[i::n_fetchers] for i in range(n_fetchers)] - - futures = [] - for i, names in enumerate(split_keys(hf_param_names)): - fut = self.weight_fetchers.slice(procs=i).fetch_dcp.call_one( - checkpoint_id=checkpoint_id, param_names=names - ) - futures.append(fut) - - sub_state_dicts = [await fut for fut in futures] - - state_dict = {} - for sd in sub_state_dicts: - state_dict.update(sd) + state_dict = await self._fetch_weights_parallel( + param_names=hf_param_names, + checkpoint_id=checkpoint_id, + tracer_name="generator_perf/_fetch_weights_dcp", + ) # Clean up the DCP handle after fetching to shared memory dcp_handle.drop() - t.stop() - return state_dict @endpoint @@ -785,40 +799,29 @@ class _WeightFetcher(ForgeActor): async def fetch( self, *, - version: int, + version: int | None = None, + checkpoint_id: str | None = None, param_names: list[str], ) -> dict[str, SharedTensorHandle]: - """Fetch weights from torchstore and load them into shared memory.""" - sd = {} - for name in param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - # Use context manager to ensure cleanup after getting handle - with SharedTensor(tensor=param) as shared_tensor: - handle = shared_tensor.get_handle() - sd[name] = handle - del param # Explicitly free the tensor after copying to shared memory - return sd + """Fetch weights and load them into shared memory. + + Args: + version: Version number for individual tensor loading (mutually exclusive with checkpoint_id) + checkpoint_id: DCP checkpoint ID for DCP loading (mutually exclusive with version) + param_names: List of parameter names to fetch + + Returns: + Dictionary mapping parameter names to SharedTensorHandles + """ + from torch.distributed.checkpoint.metadata import load_metadata - @endpoint - async def fetch_dcp( - self, - *, - checkpoint_id: str, - param_names: list[str], - ) -> dict[str, SharedTensorHandle]: - """Fetch weights from DCP checkpoint and load them into shared memory.""" from forge.actors._torchstore_utils import DcpHandle sd = {} - # Create a minimal DCP handle for loading tensors - # We only need checkpoint_id and metadata, which load_tensor_from_dcp will fetch - for name in param_names: - # Load tensor from DCP using the checkpoint_id - # Note: load_tensor_from_dcp handles reading from the DCP checkpoint directory - from torch.distributed.checkpoint.metadata import load_metadata - # Load metadata if not already loaded (first time) + # Setup for DCP loading if checkpoint_id is provided + if checkpoint_id is not None: + # Load metadata if not already cached for this checkpoint if ( not hasattr(self, "_dcp_metadata") or getattr(self, "_dcp_checkpoint_id", None) != checkpoint_id @@ -833,8 +836,15 @@ async def fetch_dcp( param_names=param_names, ) - # Load the tensor from DCP - param = load_tensor_from_dcp(dcp_handle, name) + # Fetch each parameter + for name in param_names: + if checkpoint_id is not None: + # Load tensor from DCP checkpoint + param = load_tensor_from_dcp(dcp_handle, name) + else: + # Load tensor from torchstore + param_key = get_param_key(version, name) + param = await ts.get(param_key) # Use context manager to ensure cleanup after getting handle with SharedTensor(tensor=param) as shared_tensor: From 75d23346e425181410fea4aaaa132e703d262635 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 20 Oct 2025 14:03:23 -0700 Subject: [PATCH 3/4] Fix: Don't drop DCP handle after fetching to shared memory The DCP handle should not be dropped immediately after fetching weights to shared memory. Dropping it will delete the checkpoint files on disk, which we need to keep for potential recovery if something goes wrong. The checkpoint cleanup should happen later when we're certain we don't need the checkpoint for recovery. --- src/forge/actors/generator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index e12eda57c..ecd70aadc 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -328,17 +328,12 @@ async def _fetch_weights_dcp( hf_param_names = dcp_handle.param_names checkpoint_id = dcp_handle.checkpoint_id - state_dict = await self._fetch_weights_parallel( + return await self._fetch_weights_parallel( param_names=hf_param_names, checkpoint_id=checkpoint_id, tracer_name="generator_perf/_fetch_weights_dcp", ) - # Clean up the DCP handle after fetching to shared memory - dcp_handle.drop() - - return state_dict - @endpoint async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: """Generate a response for the given prompt From b728ed94ae41021162a894fb2bbe762ff7ac847f Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 20 Oct 2025 14:10:20 -0700 Subject: [PATCH 4/4] Refactor: Rename checkpoint_id to dcp_key and use ts.get for DCP handle - Renamed 'checkpoint_id' parameter to 'dcp_key' for clarity - The parameter is actually the torchstore key (e.g., 'policy_ver_X.dcp_whole_state_dict') - Not the actual checkpoint_id from the DCP handle itself - Each fetcher now calls ts.get(dcp_key) to retrieve the DCP handle - This gives access to both metadata and the actual checkpoint path - More efficient than manually loading metadata in each fetcher - Removed redundant metadata loading and DcpHandle construction code --- src/forge/actors/generator.py | 48 ++++++++++++----------------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index ecd70aadc..164933082 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -258,15 +258,15 @@ async def _fetch_weights_parallel( param_names: list[str], *, version: int | None = None, - checkpoint_id: str | None = None, + dcp_key: str | None = None, tracer_name: str, ) -> dict[str, SharedTensorHandle]: """Fetch weights in parallel using multiple fetcher processes. Args: param_names: List of parameter names to fetch - version: Version number for individual tensor loading (mutually exclusive with checkpoint_id) - checkpoint_id: DCP checkpoint ID for DCP loading (mutually exclusive with version) + version: Version number for individual tensor loading (mutually exclusive with dcp_key) + dcp_key: Torchstore key for DCP handle (mutually exclusive with version) tracer_name: Name for the performance tracer Returns: @@ -282,9 +282,9 @@ def split_keys(keys): futures = [] for i, names in enumerate(split_keys(param_names)): - if checkpoint_id is not None: + if dcp_key is not None: fut = self.weight_fetchers.slice(procs=i).fetch.call_one( - checkpoint_id=checkpoint_id, param_names=names + dcp_key=dcp_key, param_names=names ) else: fut = self.weight_fetchers.slice(procs=i).fetch.call_one( @@ -322,15 +322,15 @@ async def _fetch_weights_dcp( version: int, ) -> dict[str, SharedTensorHandle]: """Fetch weights from DCP checkpoint and return a dict of {name: SharedTensorHandle}.""" - # Get the DCP handle from torchstore + # Get the DCP handle from torchstore to access param names dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) dcp_handle = await ts.get(dcp_whole_state_dict_key) hf_param_names = dcp_handle.param_names - checkpoint_id = dcp_handle.checkpoint_id + # Pass the DCP torchstore key so each fetcher can get the handle return await self._fetch_weights_parallel( param_names=hf_param_names, - checkpoint_id=checkpoint_id, + dcp_key=dcp_whole_state_dict_key, tracer_name="generator_perf/_fetch_weights_dcp", ) @@ -795,45 +795,29 @@ async def fetch( self, *, version: int | None = None, - checkpoint_id: str | None = None, + dcp_key: str | None = None, param_names: list[str], ) -> dict[str, SharedTensorHandle]: """Fetch weights and load them into shared memory. Args: - version: Version number for individual tensor loading (mutually exclusive with checkpoint_id) - checkpoint_id: DCP checkpoint ID for DCP loading (mutually exclusive with version) + version: Version number for individual tensor loading (mutually exclusive with dcp_key) + dcp_key: Torchstore key for DCP handle (mutually exclusive with version) param_names: List of parameter names to fetch Returns: Dictionary mapping parameter names to SharedTensorHandles """ - from torch.distributed.checkpoint.metadata import load_metadata - - from forge.actors._torchstore_utils import DcpHandle - sd = {} - # Setup for DCP loading if checkpoint_id is provided - if checkpoint_id is not None: - # Load metadata if not already cached for this checkpoint - if ( - not hasattr(self, "_dcp_metadata") - or getattr(self, "_dcp_checkpoint_id", None) != checkpoint_id - ): - self._dcp_metadata = load_metadata(checkpoint_id) - self._dcp_checkpoint_id = checkpoint_id - - # Create a DCP handle with the loaded metadata - dcp_handle = DcpHandle( - checkpoint_id=checkpoint_id, - metadata=self._dcp_metadata, - param_names=param_names, - ) + # Setup for DCP loading if dcp_key is provided + if dcp_key is not None: + # Get the DCP handle from torchstore - this gives us the metadata and checkpoint path + dcp_handle = await ts.get(dcp_key) # Fetch each parameter for name in param_names: - if checkpoint_id is not None: + if dcp_key is not None: # Load tensor from DCP checkpoint param = load_tensor_from_dcp(dcp_handle, name) else: