Feature/group offload pinning #12747
Conversation
Force-pushed from 7e50d90 to 3b3813d.
Thanks for your PR. However, it's being worked on in #12721.

Could we resolve conflicts so that it's a bit easier to review? Seems like there's some overlap from #12692.
Force-pushed from 6d96002 to 33d8b52.
Done! Rebased on latest main and resolved conflicts with #12692. Should be much cleaner to review now.
sayakpaul left a comment
Some initial comments.
should_synchronize = (
    not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
What if non_blocking=True?
Even with non_blocking=True, if a previous group onloaded this one on a side stream, we need a sync before the default stream uses the weights, or we risk reading half-copied tensors. I've limited the sync to the record_stream=False case; when record_stream=True, the tensors are tied to the consumer stream, so we can safely skip the sync.
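A minimal sketch of that decision, using hypothetical helper and parameter names rather than the exact identifiers from the PR:

```python
import torch

def maybe_wait_for_onload(transfer_stream: torch.cuda.Stream, should_synchronize: bool, record_stream: bool) -> None:
    # Synchronize only when the weights were onloaded on a side stream and
    # record_stream is not already tying their lifetime to the consumer stream.
    if should_synchronize and not record_stream:
        # Make the consumer (current) stream wait for the transfer stream so the
        # forward pass never reads half-copied weights.
        torch.cuda.current_stream().wait_stream(transfer_stream)
```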
if len(tensors) == 0:
    return True
This means the group is empty. Why would we return True for this?
Agreed, that was misleading. Now an empty group returns False, so a 'pinned' empty group will still onload instead of claiming it's already on device.
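For reference, a sketch of the corrected check under that reasoning; the function and variable names here are illustrative, not necessarily the ones used in the PR:

```python
from typing import List
import torch

def group_already_on_device(tensors: List[torch.Tensor], onload_device: torch.device) -> bool:
    # An empty group carries no weights, so never report it as "already on device";
    # returning False lets a pinned-but-empty group fall through to the normal onload path.
    if len(tensors) == 0:
        return False
    return all(t.device == onload_device for t in tensors)
```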
normalized_pin_groups = pin_groups
if isinstance(pin_groups, str):
    normalized_pin_groups = pin_groups.lower()
    if normalized_pin_groups not in {"first_last", "all"}:
        raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
elif pin_groups is not None and not callable(pin_groups):
    raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")

pin_groups = normalized_pin_groups
(nit): would prefer to have a small utility function: _normalize_pin_groups().
Added a _normalize_pin_groups() helper.
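Sketched below is roughly what that helper could look like, following the inline validation shown above; the exact signature in the PR may differ:

```python
from typing import Callable, Optional, Union

def _normalize_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
    # None and callables pass through untouched; strings are lower-cased and validated.
    if pin_groups is None or callable(pin_groups):
        return pin_groups
    if isinstance(pin_groups, str):
        normalized = pin_groups.lower()
        if normalized in {"first_last", "all"}:
            return normalized
    raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
```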
for name, submodule in module.named_children():
    if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
        # Check if this is an explicitly defined block module
        if name in block_modules:
Suggested change:
-        if name in block_modules:
+        if block_modules and name in block_modules:
Updated
block_modules: Optional[List[str]] = None
exclude_kwargs: Optional[List[str]] = None
module_prefix: Optional[str] = ""
pin_groups: Optional[Union[str, Callable]] = None
This seems like a breaking change. Could you please elaborate?
Thanks for flagging, I've fixed it to use a default value. Sorry for the oversight.
  tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
  if self.record_stream:
-     tensor.data.record_stream(default_stream)
+     tensor.data.record_stream(self._torch_accelerator_module.current_stream())
Could you elaborate on this change?
#12721 explains why it's the way it is.
Fixed to the correct behavior from #12721: record the tensor on the consumer/default stream (captured before entering the transfer stream) so its lifetime is tied to the forward stream.
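A hedged sketch of that pattern; `transfer_stream` and the tensor names here are illustrative, not the exact identifiers used in the PR:

```python
import torch

def onload_on_side_stream(tensor: torch.Tensor, source: torch.Tensor, transfer_stream: torch.cuda.Stream) -> None:
    # Capture the consumer (default/forward) stream *before* switching to the transfer stream.
    consumer_stream = torch.cuda.current_stream()
    with torch.cuda.stream(transfer_stream):
        tensor.data = source.to("cuda", non_blocking=True)
        # Tie the copied tensor's lifetime to the stream that will actually use it,
        # so the caching allocator doesn't reuse the memory while the forward pass reads it.
        tensor.data.record_stream(consumer_stream)
```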
Thank you for the initial comments! We are working on the solutions right now.
What does this PR do?
Fixes #11966
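For reviewers, a hedged usage sketch of the option this PR adds; the `pin_groups` argument and its accepted values follow the validation code discussed above, the model id is a placeholder, and the exact API surface may differ:

```python
import torch
from diffusers import AutoModel

# Placeholder model id; any diffusers model that supports group offloading works similarly.
model = AutoModel.from_pretrained("some/repo", subfolder="transformer", torch_dtype=torch.bfloat16)
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    pin_groups="first_last",  # keep the first and last groups resident on the onload device
)
```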
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul