[Transforms] Enable shared memory and introduce permutations #284
base: transform_apply_support
Conversation
I think this makes sense 👍
```diff
@@ -78,7 +79,7 @@ def load_transforms(model: Module, model_name_or_path: str):
     state_dict = {}
     for weight_name, safe_path in weight_mappings.items():
-        if "transform" in weight_name:
+        if "transform" in weight_name or "_perm_" in weight_name:
```
It seems like we have to do some sort of name matching, but I'm wondering whether a name collision down the road will cause this to run when we don't want it. Could we come up with a more unique name or something to prevent false positives?
We have this problem with any parameter we introduce (e.g. weight_scale, weight_g_idx, etc.), but yeah, we can work on making them more unique.
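As an illustration of the false-positive risk, a minimal sketch of stricter matching: compare only the last underscore-separated token of the parameter's leaf name against an explicit allow-list, rather than substring containment. The `TRANSFORM_SUFFIXES` convention and `is_transform_weight` helper here are hypothetical, not part of this PR.

```python
# Hypothetical allow-list of dedicated parameter-name suffixes.
TRANSFORM_SUFFIXES = ("transform", "perm")


def is_transform_weight(weight_name: str) -> bool:
    # "layers.0.q_proj.weight_transform" -> leaf "weight_transform"
    leaf = weight_name.rsplit(".", 1)[-1]
    # Match the final suffix token exactly, not as a substring.
    return leaf.rsplit("_", 1)[-1] in TRANSFORM_SUFFIXES


# Substring matching ("transform" in name) would wrongly fire on a module
# literally named "transformer"; exact suffix matching does not.
assert is_transform_weight("layers.0.q_proj.weight_transform")
assert not is_transform_weight("layers.0.transformer.weight")
```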
```python
):
    if module is None:
        self.transform.data.copy_(data)
        if self.permutation is not None and permutation_data is not None:
            self.permutation.data.copy_(permutation_data)
```
Do we need update_parameter_data here too in case of offloading?
This is for the case where the parameter isn't registered to a module for whatever reason; update_parameter_data handles module params. I'm not sure this case is totally necessary, but yeah, we would have to add offloading/onloading around it if we decide to keep it.
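A generic sketch of what a device-safe in-place copy could look like, assuming only plain PyTorch (this is not the project's `update_parameter_data` helper; `copy_into_param` is a hypothetical name): move the incoming data to the parameter's current device and dtype before `copy_`, so an offloaded parameter still receives the update correctly.

```python
import torch


def copy_into_param(param: torch.nn.Parameter, data: torch.Tensor) -> None:
    # Hedged sketch: align device and dtype with wherever the parameter
    # currently lives (it may have been offloaded) before the in-place copy.
    with torch.no_grad():
        param.data.copy_(data.to(device=param.device, dtype=param.dtype))


p = torch.nn.Parameter(torch.zeros(2, 2), requires_grad=False)
# dtype mismatch (float64 source) is resolved inside the helper.
copy_into_param(p, torch.ones(2, 2, dtype=torch.float64))
```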
```diff
@@ -129,7 +135,7 @@ def _matmul_hadU(X, transpose=False):
         input = hadK.view(1, K, K).to(input) @ input

     # normalize
```
This comment might need to go too if we are not normalizing?
```python
__all__ = ["apply_matrix_transform", "SingletonMatrixRegistry"]


class SingletonMatrixRegistry:
```
So that all matrices live in a single global key-value store, right?
Yeah, we can expand this, but it seems like there will be a lot of repetition across decoder layers, for example.
I think if this grows too big in scope, we may have to consider other data stores to handle it.
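A minimal sketch of the idea being discussed, assuming a size-keyed cache (the `get_or_create` method and its factory argument are illustrative, not this PR's actual API): one process-wide registry so that every decoder layer needing a same-sized transform shares one tensor object instead of materializing its own copy.

```python
import torch


class SingletonMatrixRegistry:
    """Hedged sketch: a process-wide cache of transform matrices by size."""

    _instance = None

    def __new__(cls):
        # Classic singleton: every construction returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._matrices = {}
        return cls._instance

    def get_or_create(self, size, factory):
        # Build the matrix once per size; later callers share the tensor.
        if size not in self._matrices:
            self._matrices[size] = factory(size)
        return self._matrices[size]


reg_a, reg_b = SingletonMatrixRegistry(), SingletonMatrixRegistry()
m1 = reg_a.get_or_create(4, lambda n: torch.eye(n))
m2 = reg_b.get_or_create(4, lambda n: torch.eye(n))
assert reg_a is reg_b and m1 is m2  # same registry, same tensor object
```

The trade-off raised above applies directly: the cache grows with the number of distinct sizes, so a larger scope might call for an LRU bound or an external store.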
Is a new singleton class necessary? Since each matrix Transform requires its own registry, shouldn't this be implemented on the class itself?
```python
class Hadamard(Transforms):
    registry: Dict[int, torch.Tensor] = {}

    def __new__(cls, size, empty, transform_name, *args, **kwargs):
        if empty:
            matrix = ...
        else:
            matrix = cls.registry.setdefault(
                size, torch.Tensor(deterministic_hadamard_matrix(size=size))
            )
        return super().__new__(transform=matrix, transform_name=transform_name)
```
```diff
-        self.transform = torch.nn.Buffer(transform.to(dtype).to(device))
+        self.transform = torch.nn.Parameter(transform, requires_grad=False)
         self.transform_name = transform_name
         self.permutation = (
```
For randomized hadamards, splits up the math so that the underlying hadamard can be cached and the randomness is introduced as a separate permutation matrix

Since permutations are specific to RandomHadamards, shouldn't we be implementing this logic on the RandomHadamard class, not the general Transforms class?
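A sketch of the subclassing suggestion, with stand-in class shapes (the constructors here are assumptions, not the project's real signatures): the base `Transforms` holds only the cached, shareable transform, and `RandomHadamard` alone attaches the per-layer permutation parameter.

```python
import torch


class Transforms:
    """Minimal stand-in for the base class (assumed constructor shape)."""

    def __init__(self, transform: torch.Tensor, transform_name: str):
        self.transform = torch.nn.Parameter(transform, requires_grad=False)
        self.transform_name = transform_name
        self.permutation = None  # base class carries no permutation


class RandomHadamard(Transforms):
    """Hedged sketch: randomness lives only on this subclass."""

    def __init__(self, transform, transform_name, size):
        super().__init__(transform, transform_name)
        # The shared Hadamard stays cacheable; the randomness is this
        # per-instance permutation matrix applied alongside it.
        perm = torch.eye(size)[torch.randperm(size)]
        self.permutation = torch.nn.Parameter(perm, requires_grad=False)
```

This keeps the permutation-aware update logic out of the general `Transforms` path, matching the concern above.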
Summary
1. Makes a series of updates to the registry, supporting matrix caching of the transform parameter
2. Uses shared memory so that layers with identical transforms use the same underlying transform data
3. Moves update/register functionality inside the registry; introduces a permutation parameter
4. Renames "global" to "shared"