diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py
index 7b8c15196..cc3f880b3 100644
--- a/functorch/_src/make_functional.py
+++ b/functorch/_src/make_functional.py
@@ -8,7 +8,6 @@ import torch.nn as nn
 from torch import Tensor
 from typing import List, Tuple
-from .named_members_polyfill import _named_parameters, _named_buffers
 import copy
 
 # Utilities to make nn.Module "functional"
@@ -56,66 +55,12 @@ def raise_parameter_tying_error():
         "https://github.com/pytorch/functorch/issues/446")
 
 
-def create_names_map(named_params, tied_named_params):
-    """
-    named_params is a dictionary of tensors: {'A': A, 'B': B}
-    tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
-    with potentially tied (or 'duplicated') tensors
-
-    This function creates a mapping from the names in named_params to the
-    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
-    """
-    named_params = {k: v for k, v in named_params}
-    tied_named_params = {k: v for k, v in tied_named_params}
-
-    tensors_dict_keys = set(named_params.keys())
-    tied_tensors_dict_keys = set(tied_named_params.keys())
-    assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
-
-    tensor_to_mapping = {}
-    for key, tensor in named_params.items():
-        tensor_to_mapping[tensor] = (key, [])
-    for key, tensor in tied_named_params.items():
-        assert tensor in tensor_to_mapping
-        tensor_to_mapping[tensor][1].append(key.split('.'))
-    result = {key: value for key, value in tensor_to_mapping.values()}
-    return result
-
-
-def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
-    all_named_members = tuple(_named_members(mod, remove_duplicate=False))
-    named_members = tuple(named_members())
-    names_map = create_names_map(named_members, all_named_members)
-
-    # Remove all the members in the model
-    memo = {}
-    for name, p in all_named_members:
-        if p not in memo:
-            memo[p] = subclass(torch.empty_like(p, device='meta'))
-        replacement = memo[p]
-        _set_nested_attr(mod, name.split("."), replacement)
-
-    if len(named_members) == 0:
-        names, params = (), ()
-    else:
-        names, params = zip(*named_members)
-    return params, names, names_map
-
-
-def extract_weights(mod: nn.Module):
-    """
-    This function removes all the Parameters from the model and
-    return them as a tuple as well as their original attribute names.
-    The weights must be re-loaded with `load_weights` before the model
-    can be used again.
-    Note that this function modifies the model in place and after this
-    call, mod.parameters() will be empty.
-    """
-    return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)
-
-
-def extract_buffers(mod: nn.Module):
-    return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x)
+def extract_weights(model):
+    for module_name, m in model.named_modules():
+        for param_name, p in list(m.named_parameters(recurse=False)):
+            delattr(m, param_name)
+            setattr(m, param_name, None)
+            yield (module_name, m, param_name, p)
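+
+
+# A minimal sketch of how the `extract_weights` generator is consumed
+# (`model` stands for any nn.Module and is a placeholder):
+#
+#     entries = list(extract_weights(model))   # strips parameters in place
+#     module_names, modules, names, params = zip(*entries)
+#
+# After this, `model.parameters()` is empty; the tensors in `params` must be
+# swapped back (see `_swap_state`) before the model can be called again.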
- """ - return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter) - - -def extract_buffers(mod: nn.Module): - return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x) +def extract_weights(model): + for module_name, m in model.named_modules(): + for param_name, p in list(m.named_parameters(recurse=False)): + delattr(m, param_name) + setattr(m, param_name, None) + yield (module_name, m, param_name, p) def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None: @@ -131,15 +76,21 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a _set_nested_attr(mod, name.split("."), p) -def _swap_state(mod: nn.Module, names_map: List[str], elems): - result = [] - for (_, attr_names), elem in zip(names_map.items(), elems): - for i, attr_name in enumerate(attr_names): - if i == 0: - result.append(_get_nested_attr(mod, attr_name)) - _del_nested_attr(mod, attr_name) - _set_nested_attr(mod, attr_name, elem) - return result +def _swap_state(param_modules, param_names, params): + old_params = [] + for module, param_name, param in zip(param_modules, param_names, params): + old_params.append(getattr(module, param_name)) + delattr(module, param_name) + setattr(module, param_name, param) + return old_params + + +def extract_buffers(model): + for module_name, m in model.named_modules(): + for buffer_name, b in list(m.named_buffers(recurse=False)): + delattr(m, buffer_name) + setattr(m, buffer_name, None) + yield (module_name, m, buffer_name, b) def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None: @@ -194,7 +145,7 @@ def make_functional_deprecated_v1(model: nn.Module): if len(buffers) > 0: raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use ' 'make_functional_with_buffers_deprecated_v1(model) instead.') - weights, descriptors, _ = extract_weights(model) + weights, descriptors = extract_weights(model) def fun(weights, data): mutable_model = copy.deepcopy(model) @@ -241,79 +192,175 @@ def fun(weights, buffers, data): return weights, buffers, fun, weight_descriptors, buf_descriptors +def make_split_names(lst): + return [name.split('.') for name in lst] + + class FunctionalModuleWithBuffers(nn.Module): """ This is the callable object returned by :func:`make_functional_with_buffers`. 
""" - def __init__(self, stateless_model, param_names, buffer_names, - param_names_map, buffer_names_map): + def __init__( + self, + stateless_model, + param_module_names, + param_modules, + param_names, + buffer_module_names, + buffer_modules, + buffer_names + ): super(FunctionalModuleWithBuffers, self).__init__() self.stateless_model = stateless_model + self.param_module_names = param_module_names + self.param_modules = param_modules self.param_names = param_names + self.buffer_module_names = buffer_module_names + self.buffer_modules = buffer_modules self.buffer_names = buffer_names - self.all_names_map = dict(param_names_map) - self.all_names_map.update(buffer_names_map) - @staticmethod - def _create_from(model, disable_autograd_tracking=False): + def _create_from(model): # TODO: We don't need to copy the model to create a stateless copy model_copy = copy.deepcopy(model) - params, param_names, param_names_map = extract_weights(model_copy) - buffers, buffer_names, buffer_names_map = extract_buffers(model_copy) + param_container = list(extract_weights(model_copy)) + if len(param_container): + param_module_names, param_modules, param_names, params = zip(*param_container) + else: + param_module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple() if disable_autograd_tracking: for param in params: param.requires_grad_(False) + + buffer_container = list(extract_buffers(model_copy)) + if len(buffer_container): + buffer_module_names, buffer_modules, buffer_names, buffers = zip(*buffer_container) + else: + buffer_module_names, buffer_modules, buffer_names, buffers = tuple(), tuple(), tuple(), tuple() return ( - FunctionalModuleWithBuffers(model_copy, param_names, buffer_names, - param_names_map, buffer_names_map), + FunctionalModuleWithBuffers( + model_copy, + param_module_names, + param_modules, + param_names, + buffer_module_names, + buffer_modules, + buffer_names), params, buffers, ) def forward(self, params, buffers, *args, **kwargs): - # Temporarily load the state back onto self.stateless_model - old_state = _swap_state( - self.stateless_model, - self.all_names_map, - list(params) + list(buffers)) + old_states = _swap_state( + self.param_modules + self.buffer_modules, + self.param_names + self.buffer_names, + list(params) + list(buffers) + ) + old_params = old_states[:len(self.param_modules)] + old_buffers = old_states[len(self.param_modules):] + try: return self.stateless_model(*args, **kwargs) finally: # Remove the loaded state on self.stateless_model - _swap_state(self.stateless_model, self.all_names_map, old_state) + for module, param_name, param in zip(self.param_modules, self.param_names, old_params): + old_params.append(getattr(module, param_name)) + setattr(module, param_name, param) + for module, buffer_name, buffer in zip(self.buffer_modules, self.buffer_names, old_buffers): + old_buffers.append(getattr(module, buffer_name)) + setattr(module, buffer_name, buffer) + + def __getstate__(self): + state = self.__dict__.copy() + state["param_modules"] = None + state["buffer_modules"] = None + return state + + def __setstate__(self, state): + state["param_modules"] = [] + state["buffer_modules"] = [] + out = super().__setstate__(state) + for module_name in self.param_module_names: + found = False + for other_name, module in self.stateless_model.named_modules(): + if other_name == module_name: + found = True + state["param_modules"].append(module) + break + if not found: + raise RuntimeError(f"module not found: {module_name}") + for module_name in 
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["param_modules"] = None
+        state["buffer_modules"] = None
+        return state
+
+    def __setstate__(self, state):
+        state["param_modules"] = []
+        state["buffer_modules"] = []
+        out = super().__setstate__(state)
+        # nn.Module.__setstate__ copies `state` into self.__dict__, so
+        # appending to these lists also fills in self.param_modules and
+        # self.buffer_modules.
+        for module_name in self.param_module_names:
+            found = False
+            for other_name, module in self.stateless_model.named_modules():
+                if other_name == module_name:
+                    found = True
+                    state["param_modules"].append(module)
+                    break
+            if not found:
+                raise RuntimeError(f"module not found: {module_name}")
+        for module_name in self.buffer_module_names:
+            found = False
+            for other_name, module in self.stateless_model.named_modules():
+                if other_name == module_name:
+                    found = True
+                    state["buffer_modules"].append(module)
+                    break
+            if not found:
+                raise RuntimeError(f"module not found: {module_name}")
+        return out
 
 
 class FunctionalModule(nn.Module):
     """
     This is the callable object returned by :func:`make_functional`.
     """
-    def __init__(self, stateless_model, param_names, names_map):
+    def __init__(self, stateless_model, param_module_names, modules, param_names):
         super(FunctionalModule, self).__init__()
         self.stateless_model = stateless_model
+        self.param_modules = modules
         self.param_names = param_names
-        self.names_map = names_map
+        self.param_module_names = param_module_names
 
     @staticmethod
     def _create_from(model, disable_autograd_tracking=False):
         # TODO: We don't need to copy the model to create a stateless copy
         model_copy = copy.deepcopy(model)
-        params, param_names, names_map = extract_weights(model_copy)
+        param_container = list(extract_weights(model_copy))
+        if len(param_container):
+            module_names, param_modules, param_names, params = zip(*param_container)
+        else:
+            module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple()
         if disable_autograd_tracking:
             for param in params:
                 param.requires_grad_(False)
-        return FunctionalModule(model_copy, param_names, names_map), params
+
+        return (
+            FunctionalModule(model_copy, module_names, param_modules, param_names),
+            params,
+        )
 
     def forward(self, params, *args, **kwargs):
-        # Temporarily load the state back onto self.stateless_model
-        old_state = _swap_state(self.stateless_model, self.names_map, params)
+        # Temporarily install `params` onto self.stateless_model.
+        old_params = _swap_state(self.param_modules, self.param_names, params)
+
         try:
             return self.stateless_model(*args, **kwargs)
         finally:
             # Remove the loaded state on self.stateless_model
-            _swap_state(self.stateless_model, self.names_map, old_state)
+            for module, param_name, param in zip(self.param_modules, self.param_names, old_params):
+                setattr(module, param_name, param)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["param_modules"] = None
+        return state
+
+    def __setstate__(self, state):
+        state["param_modules"] = []
+        out = super().__setstate__(state)
+        for module_name in self.param_module_names:
+            found = False
+            for other_name, module in self.stateless_model.named_modules():
+                if other_name == module_name:
+                    found = True
+                    state["param_modules"].append(module)
+                    break
+            if not found:
+                raise RuntimeError(f"module not found: {module_name}")
+        return out
 
 
 def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
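+    # Minimal usage sketch (`MyModel` and `x` are placeholders):
+    #
+    #     model = MyModel()
+    #     func, params = make_functional(model)
+    #     out = func(params, x)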