
Conversation

@Aki-07

@Aki-07 Aki-07 commented Nov 29, 2025

What does this PR do?

Fixes #11966

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul

@Aki-07 Aki-07 force-pushed the feature/group-offload-pinning branch from 7e50d90 to 3b3813d Compare November 29, 2025 14:03
@sayakpaul
Member

Thanks for your PR. However, it's being worked on in #12721.

@sayakpaul
Member

sayakpaul commented Dec 9, 2025

Could we resolve conflicts so that it's a bit easier to review? Seems like there's some overlap from #12692.

@Aki-07 Aki-07 force-pushed the feature/group-offload-pinning branch from 6d96002 to 33d8b52 Compare December 10, 2025 06:06
@Aki-07
Author

Aki-07 commented Dec 10, 2025

Done! Rebased on latest main and resolved conflicts with #12692. Should be much cleaner to review now.

@sayakpaul sayakpaul left a comment (Member)

Some initial comments.

Comment on lines 310 to 312
should_synchronize = (
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
Member

What if non_blocking=True?

Author

Even with non_blocking=True, if a previous group onloaded this one on a side stream, we need a sync before the default stream uses the weights, or we risk reading half-copied tensors. I've limited the sync to the record_stream=False case; when record_stream=True, the tensors are tied to the consumer stream, so we can safely skip the sync.
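
For reference, a minimal sketch of the limited sync described above (attribute names follow the quoted snippet; the exact diff may differ):

```python
# Only wait on the transfer stream when this group's weights were copied on a
# side stream by a previous group AND their lifetime is not already tracked via
# record_stream. With record_stream=True, CUDA tracks the tensor's lifetime on
# the consumer stream, so the explicit sync can be skipped.
should_synchronize = (
    not self.group.onload_self
    and self.group.stream is not None
    and not should_onload_next_group
    and not self.group.record_stream
)
if should_synchronize:
    self.group.stream.synchronize()
```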

Comment on lines 363 to 364
if len(tensors) == 0:
return True
Member

This means the group is empty. Why would we return True for this?

Author

Agreed, that was misleading. Now an empty group returns False so a ‘pinned’ empty group will still onload instead of claiming it’s already on device.
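
Roughly, the check now behaves like this (hypothetical standalone helper for illustration):

```python
from typing import Iterable

import torch


def _group_is_on_device(tensors: Iterable[torch.Tensor], onload_device: torch.device) -> bool:
    # Hypothetical helper mirroring the fix: an empty group is NOT treated as
    # already on device, so a pinned-but-empty group still goes through onload.
    tensors = list(tensors)
    if len(tensors) == 0:
        return False  # previously True, which skipped onloading for empty pinned groups
    return all(t.device == onload_device for t in tensors)
```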

Comment on lines 643 to 651
normalized_pin_groups = pin_groups
if isinstance(pin_groups, str):
normalized_pin_groups = pin_groups.lower()
if normalized_pin_groups not in {"first_last", "all"}:
raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
elif pin_groups is not None and not callable(pin_groups):
raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")

pin_groups = normalized_pin_groups
Member

(nit): would prefer to have a small utility function: _normalize_pin_groups().

Author

Added a _normalize_pin_groups() helper.
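
Sketch of what the helper looks like (signature is illustrative; it centralizes the validation from the inline block above):

```python
from typing import Callable, Optional, Union


def _normalize_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
    # Accept None or a callable as-is; lower-case and validate string values.
    if pin_groups is None or callable(pin_groups):
        return pin_groups
    if isinstance(pin_groups, str):
        normalized = pin_groups.lower()
        if normalized in {"first_last", "all"}:
            return normalized
    raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
```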

for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Check if this is an explicitly defined block module
if name in block_modules:
Member

Suggested change
if name in block_modules:
if block_modules and name in block_modules:

Author

Updated

block_modules: Optional[List[str]] = None
exclude_kwargs: Optional[List[str]] = None
module_prefix: Optional[str] = ""
pin_groups: Optional[Union[str, Callable]] = None
Member

This seems like a breaking change. Could you please elaborate?

Author

Thanks for flagging; I've fixed it to have a default. Sorry for the oversight.
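
In other words, the new field now ships with a default so existing call sites keep working (the config class name below is illustrative, not the exact class in the PR):

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Union


@dataclass
class GroupOffloadingConfig:  # illustrative name only
    block_modules: Optional[List[str]] = None
    exclude_kwargs: Optional[List[str]] = None
    module_prefix: Optional[str] = ""
    # Defaulting the new option keeps previously-valid configurations valid.
    pin_groups: Optional[Union[str, Callable]] = None
```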

tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor.data.record_stream(default_stream)
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
Member

Could you elaborate on this change?
#12721 explains why it's the way it is.

Author

Fixed to match the behavior from #12721: record the tensor on the consumer/default stream (captured before entering the transfer stream) so its lifetime is tied to the forward stream.
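
The pattern, as a standalone sketch (function name and signature are illustrative):

```python
import torch


def _onload_tensor(source_tensor, onload_device, transfer_stream, non_blocking=True, record_stream=True):
    # Capture the consumer (default) stream BEFORE entering the transfer stream,
    # then record the copied tensor on it so CUDA keeps the memory alive until
    # the forward pass running on that stream has consumed it.
    consumer_stream = torch.cuda.current_stream()
    with torch.cuda.stream(transfer_stream):
        data = source_tensor.to(onload_device, non_blocking=non_blocking)
        if record_stream:
            data.record_stream(consumer_stream)
    return data
```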

@bconstantine

Thank you for the initial comments! We are working on addressing them now.



Development

Successfully merging this pull request may close these issues.

How about forcing the first and last block on device when groupoffloading is used?
