From 86cd00cad207bb841c7baeb78a0b0cd912e56230 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Jul 2022 09:56:47 +0200 Subject: [PATCH 1/5] init --- functorch/_src/make_functional.py | 151 ++++++++++++++++++++++++------ 1 file changed, 121 insertions(+), 30 deletions(-) diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py index 6075466aa..e4c5c1321 100644 --- a/functorch/_src/make_functional.py +++ b/functorch/_src/make_functional.py @@ -207,69 +207,160 @@ def make_split_names(lst): return [name.split('.') for name in lst] -class FunctionalModuleWithBuffers(nn.Module): +class FunctionalModule(nn.Module): """ This is the callable object returned by :func:`make_functional_with_buffers`. """ - def __init__(self, stateless_model, param_names, buffer_names): - super(FunctionalModuleWithBuffers, self).__init__() + def __init__(self, stateless_model, module_names, modules, param_names): + super(FunctionalModule, self).__init__() self.stateless_model = stateless_model + self.param_modules = modules self.param_names = param_names - self.buffer_names = buffer_names - self.split_names = make_split_names(param_names + buffer_names) + self.module_names = module_names @staticmethod - def _create_from(model): + def _create_from(model, disable_autograd_tracking=False): # TODO: We don't need to copy the model to create a stateless copy model_copy = copy.deepcopy(model) - params, param_names = extract_weights(model_copy) - buffers, buffer_names = extract_buffers(model_copy) + param_container = list(extract_weights(model_copy)) + if len(param_container): + module_names, param_modules, param_names, params = zip(*param_container) + else: + module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple() + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return ( - FunctionalModuleWithBuffers(model_copy, param_names, buffer_names), + FunctionalModule(model_copy, module_names, param_modules, param_names), params, - buffers, ) - def forward(self, params, buffers, *args, **kwargs): - # Temporarily load the state back onto self.stateless_model - old_state = _swap_state( - self.stateless_model, - self.split_names, - list(params) + list(buffers)) + def forward(self, params, *args, **kwargs): + old_params = [] + for module, param_name, param in zip(self.param_modules, self.param_names, params): + old_params.append(getattr(module, param_name)) + setattr(module, param_name, param) + try: return self.stateless_model(*args, **kwargs) finally: # Remove the loaded state on self.stateless_model - _swap_state(self.stateless_model, self.split_names, old_state) - + for module, param_name, param in zip(self.param_modules, self.param_names, old_params): + old_params.append(getattr(module, param_name)) + setattr(module, param_name, param) + + def __getstate__(self): + state = self.__dict__.copy() + state["param_modules"] = None + return state + + def __setstate__(self, state): + state["param_modules"] = [] + for module_name in self.module_names: + found = False + for other_name, module in self.named_modules(): + if other_name == module_name: + state["param_modules"].append(module) + break + if not found: + raise RuntimeError("module not found") + return super().__setstate__(state) -class FunctionalModule(nn.Module): +class FunctionalModuleWithBuffers(nn.Module): """ - This is the callable object returned by :func:`make_functional`. + This is the callable object returned by :func:`make_functional_with_buffers`. """ - def __init__(self, stateless_model, param_names): - super(FunctionalModule, self).__init__() + def __init__(self, stateless_model, param_module_names, param_modules, param_names, buffer_module_names, buffer_modules, buffer_names): + super(FunctionalModuleWithBuffers, self).__init__() self.stateless_model = stateless_model + self.param_module_names = param_module_names + self.param_modules = param_modules self.param_names = param_names - self.split_names = make_split_names(param_names) + self.buffer_module_names = buffer_module_names + self.buffer_modules = buffer_modules + self.buffer_names = buffer_names @staticmethod - def _create_from(model): + def _create_from(model, disable_autograd_tracking=False): # TODO: We don't need to copy the model to create a stateless copy model_copy = copy.deepcopy(model) - params, param_names = extract_weights(model_copy) - return FunctionalModule(model_copy, param_names), params + param_container = list(extract_weights(model_copy)) + if len(param_container): + param_module_names, param_modules, param_names, params = zip(*param_container) + else: + param_module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple() + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + + buffer_container = list(extract_buffers(model_copy)) + if len(buffer_container): + buffer_module_names, buffer_modules, buffer_names, buffers = zip(*buffer_container) + else: + buffer_module_names, buffer_modules, buffer_names, buffers = tuple(), tuple(), tuple(), tuple() + return ( + FunctionalModuleWithBuffers( + model_copy, + param_module_names, + param_modules, + param_names, + buffer_module_names, + buffer_modules, + buffer_names), + params, + buffers, + ) + + def forward(self, params, buffers, *args, **kwargs): + old_params = [] + for module, param_name, param in zip(self.param_modules, self.param_names, params): + old_params.append(getattr(module, param_name)) + setattr(module, param_name, param) + old_buffers = [] + for module, buffer_name, buffer in zip(self.buffer_modules, self.buffer_names, buffers): + old_buffers.append(getattr(module, buffer_name)) + setattr(module, buffer_name, buffer) - def forward(self, params, *args, **kwargs): - # Temporarily load the state back onto self.stateless_model - old_state = _swap_state(self.stateless_model, self.split_names, params) try: return self.stateless_model(*args, **kwargs) finally: # Remove the loaded state on self.stateless_model - _swap_state(self.stateless_model, self.split_names, old_state) + for module, param_name, param in zip(self.param_modules, self.param_names, old_params): + old_params.append(getattr(module, param_name)) + setattr(module, param_name, param) + for module, buffer_name, buffer in zip(self.buffer_modules, self.buffer_names, old_buffers): + old_buffers.append(getattr(module, buffer_name)) + setattr(module, buffer_name, buffer) + + def __getstate__(self): + state = self.__dict__.copy() + state["param_modules"] = None + state["buffer_modules"] = None + return state + + def __setstate__(self, state): + state["param_modules"] = [] + state["buffer_modules"] = [] + for module_name in self.module_names: + found = False + for other_name, module in self.named_modules(): + if other_name == module_name: + state["param_modules"].append(module) + break + if not found: + raise RuntimeError("module not found") + for module_name in self.module_names: + found = False + for other_name, module in self.named_modules(): + if other_name == module_name: + state["param_modules"].append(module) + break + if not found: + raise RuntimeError("module not found") + return super().__setstate__(state) def make_functional(model: nn.Module): From 58e38c873a07ee17d334d248d7037011415d2f78 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Jul 2022 10:20:14 +0200 Subject: [PATCH 2/5] init --- functorch/_src/make_functional.py | 76 ++++++++++++------------------- 1 file changed, 29 insertions(+), 47 deletions(-) diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py index e4c5c1321..6d8bb45e1 100644 --- a/functorch/_src/make_functional.py +++ b/functorch/_src/make_functional.py @@ -46,27 +46,11 @@ def _get_nested_attr(obj: nn.Module, names: List[str]) -> None: _get_nested_attr(getattr(obj, names[0]), names[1:]) -def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]: - """ - This function removes all the Parameters from the model and - return them as a tuple as well as their original attribute names. - The weights must be re-loaded with `load_weights` before the model - can be used again. - Note that this function modifies the model in place and after this - call, mod.parameters() will be empty. - """ - orig_params = tuple(mod.parameters()) - # Remove all the parameters in the model - names = [] - for name, p in list(mod.named_parameters()): - replacement = nn.Parameter(torch.empty_like(p, device='meta')) - _set_nested_attr(mod, name.split("."), replacement) - names.append(name) - - # Make params regular Tensors instead of nn.Parameter - params = tuple(p for p in orig_params) - return params, names - +def extract_weights(model): + for module_name, m in model.named_modules(): + for param_name, p in m.named_parameters(recurse=False): + setattr(m, param_name, None) + yield (module_name, m, param_name, p) def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None: """ @@ -90,18 +74,11 @@ def _swap_state(mod: nn.Module, split_names: List[str], elems): return result -def extract_buffers(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]: - orig_params = tuple(mod.buffers()) - # Remove all the parameters in the model - names = [] - for name, p in list(mod.named_buffers()): - replacement = torch.empty_like(p, device='meta') - _set_nested_attr(mod, name.split("."), replacement) - names.append(name) - - # Make params regular Tensors instead of nn.Parameter - params = tuple(p for p in orig_params) - return params, names +def extract_buffers(model): + for module_name, m in model.named_modules(): + for buffer_name, b in m.named_buffers(recurse=False): + setattr(m, buffer_name, None) + yield (module_name, m, buffer_name, b) def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None: @@ -212,12 +189,12 @@ class FunctionalModule(nn.Module): This is the callable object returned by :func:`make_functional_with_buffers`. """ - def __init__(self, stateless_model, module_names, modules, param_names): + def __init__(self, stateless_model, param_module_names, modules, param_names): super(FunctionalModule, self).__init__() self.stateless_model = stateless_model self.param_modules = modules self.param_names = param_names - self.module_names = module_names + self.param_module_names = param_module_names @staticmethod def _create_from(model, disable_autograd_tracking=False): @@ -258,15 +235,17 @@ def __getstate__(self): def __setstate__(self, state): state["param_modules"] = [] - for module_name in self.module_names: + out = super().__setstate__(state) + for module_name in self.param_module_names: found = False - for other_name, module in self.named_modules(): + for other_name, module in self.stateless_model.named_modules(): if other_name == module_name: + found = True state["param_modules"].append(module) break if not found: - raise RuntimeError("module not found") - return super().__setstate__(state) + raise RuntimeError(f"module not found: {module_name}") + return out class FunctionalModuleWithBuffers(nn.Module): """ @@ -344,23 +323,26 @@ def __getstate__(self): def __setstate__(self, state): state["param_modules"] = [] state["buffer_modules"] = [] - for module_name in self.module_names: + out = super().__setstate__(state) + for module_name in self.param_module_names: found = False - for other_name, module in self.named_modules(): + for other_name, module in self.stateless_model.named_modules(): if other_name == module_name: + found = True state["param_modules"].append(module) break if not found: - raise RuntimeError("module not found") - for module_name in self.module_names: + raise RuntimeError(f"module not found: {module_name}") + for module_name in self.buffer_module_names: found = False - for other_name, module in self.named_modules(): + for other_name, module in self.stateless_model.named_modules(): if other_name == module_name: - state["param_modules"].append(module) + found = True + state["buffer_modules"].append(module) break if not found: - raise RuntimeError("module not found") - return super().__setstate__(state) + raise RuntimeError(f"module not found: {module_name}") + return out def make_functional(model: nn.Module): From a939817f08e11cd23c0910f12c870585169ead6f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Jul 2022 20:49:07 +0100 Subject: [PATCH 3/5] amend --- functorch/_src/make_functional.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py index 6d8bb45e1..3e1b868f9 100644 --- a/functorch/_src/make_functional.py +++ b/functorch/_src/make_functional.py @@ -48,7 +48,8 @@ def _get_nested_attr(obj: nn.Module, names: List[str]) -> None: def extract_weights(model): for module_name, m in model.named_modules(): - for param_name, p in m.named_parameters(recurse=False): + for param_name, p in list(m.named_parameters(recurse=False)): + delattr(m, param_name) setattr(m, param_name, None) yield (module_name, m, param_name, p) @@ -76,7 +77,8 @@ def _swap_state(mod: nn.Module, split_names: List[str], elems): def extract_buffers(model): for module_name, m in model.named_modules(): - for buffer_name, b in m.named_buffers(recurse=False): + for buffer_name, b in list(m.named_buffers(recurse=False)): + delattr(m, buffer_name) setattr(m, buffer_name, None) yield (module_name, m, buffer_name, b) @@ -218,6 +220,7 @@ def forward(self, params, *args, **kwargs): old_params = [] for module, param_name, param in zip(self.param_modules, self.param_names, params): old_params.append(getattr(module, param_name)) + delattr(module, param_name) setattr(module, param_name, param) try: @@ -297,10 +300,12 @@ def forward(self, params, buffers, *args, **kwargs): old_params = [] for module, param_name, param in zip(self.param_modules, self.param_names, params): old_params.append(getattr(module, param_name)) + delattr(module, param_name) setattr(module, param_name, param) old_buffers = [] for module, buffer_name, buffer in zip(self.buffer_modules, self.buffer_names, buffers): old_buffers.append(getattr(module, buffer_name)) + delattr(module, buffer_name) setattr(module, buffer_name, buffer) try: From 76d14c7c73daea6e4d6ee33ff19de7e98815a58d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Jul 2022 21:04:40 +0100 Subject: [PATCH 4/5] amend --- functorch/_src/make_functional.py | 58 +++++++++++++++++++------------ 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py index 3e1b868f9..420fb44b7 100644 --- a/functorch/_src/make_functional.py +++ b/functorch/_src/make_functional.py @@ -53,6 +53,7 @@ def extract_weights(model): setattr(m, param_name, None) yield (module_name, m, param_name, p) + def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None: """ Reload a set of weights so that `mod` can be used again to perform a forward pass. @@ -66,13 +67,21 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a _set_nested_attr(mod, name.split("."), p) -def _swap_state(mod: nn.Module, split_names: List[str], elems): - result = [] - for split_name, elem in zip(split_names, elems): - result.append(_get_nested_attr(mod, split_name)) - _del_nested_attr(mod, split_name) - _set_nested_attr(mod, split_name, elem) - return result +# def _swap_state(mod: nn.Module, split_names: List[str], elems): +# result = [] +# for split_name, elem in zip(split_names, elems): +# result.append(_get_nested_attr(mod, split_name)) +# _del_nested_attr(mod, split_name) +# _set_nested_attr(mod, split_name, elem) +# return result + +def _swap_state(param_modules, param_names, params): + old_params = [] + for module, param_name, param in zip(param_modules, param_names, params): + old_params.append(getattr(module, param_name)) + delattr(module, param_name) + setattr(module, param_name, param) + return old_params def extract_buffers(model): @@ -217,11 +226,7 @@ def _create_from(model, disable_autograd_tracking=False): ) def forward(self, params, *args, **kwargs): - old_params = [] - for module, param_name, param in zip(self.param_modules, self.param_names, params): - old_params.append(getattr(module, param_name)) - delattr(module, param_name) - setattr(module, param_name, param) + old_params = _swap_state(self.param_modules, self.param_names, params) try: return self.stateless_model(*args, **kwargs) @@ -250,12 +255,22 @@ def __setstate__(self, state): raise RuntimeError(f"module not found: {module_name}") return out + class FunctionalModuleWithBuffers(nn.Module): """ This is the callable object returned by :func:`make_functional_with_buffers`. """ - def __init__(self, stateless_model, param_module_names, param_modules, param_names, buffer_module_names, buffer_modules, buffer_names): + def __init__( + self, + stateless_model, + param_module_names, + param_modules, + param_names, + buffer_module_names, + buffer_modules, + buffer_names + ): super(FunctionalModuleWithBuffers, self).__init__() self.stateless_model = stateless_model self.param_module_names = param_module_names @@ -297,16 +312,13 @@ def _create_from(model, disable_autograd_tracking=False): ) def forward(self, params, buffers, *args, **kwargs): - old_params = [] - for module, param_name, param in zip(self.param_modules, self.param_names, params): - old_params.append(getattr(module, param_name)) - delattr(module, param_name) - setattr(module, param_name, param) - old_buffers = [] - for module, buffer_name, buffer in zip(self.buffer_modules, self.buffer_names, buffers): - old_buffers.append(getattr(module, buffer_name)) - delattr(module, buffer_name) - setattr(module, buffer_name, buffer) + old_states = _swap_state( + self.param_modules + self.buffer_modules, + self.param_names + self.buffer_names, + list(params) + list(buffers) + ) + old_params = old_states[:len(self.param_modules)] + old_buffers = old_states[len(self.param_modules):] try: return self.stateless_model(*args, **kwargs) From f9787f9133c0b332fe410f263bb61bcd9a755308 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 24 Jul 2022 21:28:25 +0100 Subject: [PATCH 5/5] amend --- functorch/_src/make_functional.py | 122 +++++++++++++++--------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py index 90114e924..cc3f880b3 100644 --- a/functorch/_src/make_functional.py +++ b/functorch/_src/make_functional.py @@ -196,67 +196,6 @@ def make_split_names(lst): return [name.split('.') for name in lst] -class FunctionalModule(nn.Module): - """ - This is the callable object returned by :func:`make_functional_with_buffers`. - """ - - def __init__(self, stateless_model, param_module_names, modules, param_names): - super(FunctionalModule, self).__init__() - self.stateless_model = stateless_model - self.param_modules = modules - self.param_names = param_names - self.param_module_names = param_module_names - - @staticmethod - def _create_from(model): - # TODO: We don't need to copy the model to create a stateless copy - model_copy = copy.deepcopy(model) - param_container = list(extract_weights(model_copy)) - if len(param_container): - module_names, param_modules, param_names, params = zip(*param_container) - else: - module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple() - if disable_autograd_tracking: - for param in params: - param.requires_grad_(False) - - return ( - FunctionalModule(model_copy, module_names, param_modules, param_names), - params, - ) - - def forward(self, params, *args, **kwargs): - old_params = _swap_state(self.param_modules, self.param_names, params) - - try: - return self.stateless_model(*args, **kwargs) - finally: - # Remove the loaded state on self.stateless_model - for module, param_name, param in zip(self.param_modules, self.param_names, old_params): - old_params.append(getattr(module, param_name)) - setattr(module, param_name, param) - - def __getstate__(self): - state = self.__dict__.copy() - state["param_modules"] = None - return state - - def __setstate__(self, state): - state["param_modules"] = [] - out = super().__setstate__(state) - for module_name in self.param_module_names: - found = False - for other_name, module in self.stateless_model.named_modules(): - if other_name == module_name: - found = True - state["param_modules"].append(module) - break - if not found: - raise RuntimeError(f"module not found: {module_name}") - return out - - class FunctionalModuleWithBuffers(nn.Module): """ This is the callable object returned by :func:`make_functional_with_buffers`. @@ -363,6 +302,67 @@ def __setstate__(self, state): return out +class FunctionalModule(nn.Module): + """ + This is the callable object returned by :func:`make_functional_with_buffers`. + """ + + def __init__(self, stateless_model, param_module_names, modules, param_names): + super(FunctionalModule, self).__init__() + self.stateless_model = stateless_model + self.param_modules = modules + self.param_names = param_names + self.param_module_names = param_module_names + + @staticmethod + def _create_from(model): + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + param_container = list(extract_weights(model_copy)) + if len(param_container): + module_names, param_modules, param_names, params = zip(*param_container) + else: + module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple() + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + + return ( + FunctionalModule(model_copy, module_names, param_modules, param_names), + params, + ) + + def forward(self, params, *args, **kwargs): + old_params = _swap_state(self.param_modules, self.param_names, params) + + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + for module, param_name, param in zip(self.param_modules, self.param_names, old_params): + old_params.append(getattr(module, param_name)) + setattr(module, param_name, param) + + def __getstate__(self): + state = self.__dict__.copy() + state["param_modules"] = None + return state + + def __setstate__(self, state): + state["param_modules"] = [] + out = super().__setstate__(state) + for module_name in self.param_module_names: + found = False + for other_name, module in self.stateless_model.named_modules(): + if other_name == module_name: + found = True + state["param_modules"].append(module) + break + if not found: + raise RuntimeError(f"module not found: {module_name}") + return out + + def make_functional(model: nn.Module, disable_autograd_tracking: bool = False): """make_functional(model, disable_autograd_tracking=False) -> func, params