Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/param reset #328

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
52 changes: 37 additions & 15 deletions d3rlpy/algos/qlearning/torch/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABCMeta, abstractmethod
from typing import Sequence
from typing import List, Optional, Sequence
import torch.nn as nn

from ... import QLearningAlgoBase, QLearningAlgoImplBase
from ....constants import IMPL_NOT_INITIALIZED_ERROR
Expand All @@ -15,27 +16,45 @@ def __call__(self, algo: QLearningAlgoBase, epoch: int, total_step: int):


class ParameterReset(QLearningCallback):
def __init__(self, replay_ratio: int, encoder_reset: Sequence[bool],
             output_reset: bool, algo: Optional[QLearningAlgoBase] = None) -> None:
    """Callback that periodically resets Q-function parameters.

    Args:
        replay_ratio: the reset is triggered every ``replay_ratio`` epochs.
        encoder_reset: one flag per encoder layer; ``True`` means that
            layer is reset when the callback fires.
        output_reset: whether the final output (``_fc``) layer is reset.
        algo: if given, the flags are validated against this algorithm's
            Q-functions immediately; otherwise validation is deferred to
            the first ``__call__``.
    """
    self._replay_ratio = replay_ratio
    self._encoder_reset = encoder_reset
    self._output_reset = output_reset
    # Becomes True once the flags have been validated against the actual
    # Q-function structure (see _check_layer_resets).
    self._check = False
    if algo is not None:
        self._check_layer_resets(algo=algo)


def _get_layers(self, q_func: nn.ModuleList) -> List[nn.Module]:
    """Collect the resettable layers of a single Q-function.

    Returns the encoder layers (in order) followed by the final fc layer,
    so the result lines up with ``[*encoder_reset, output_reset]``.

    NOTE(review): relies on the private submodule names
    ``"_encoder._layers"`` and ``"_fc"`` of d3rlpy's Q-function modules —
    confirm against the current Q-function implementation.
    """
    # named_modules() yields (qualified_name, module) pairs.
    all_modules = dict(q_func.named_modules())
    return [*all_modules["_encoder._layers"], all_modules["_fc"]]

def _check_layer_resets(self, algo: QLearningAlgoBase) -> None:
    """Validate the reset flags against the algorithm's Q-functions.

    Raises:
        ValueError: if the number of flags does not match the number of
            layers, or if any flagged layer lacks ``reset_parameters``.
    """
    assert algo._impl is not None, IMPL_NOT_INITIALIZED_ERROR
    assert isinstance(algo._impl, QLearningAlgoImplBase)

    all_valid_layers = []
    for q_func in algo._impl.q_function:
        q_func_layers = self._get_layers(q_func)
        # Encoder flags cover every layer except the final fc layer.
        if len(self._encoder_reset) + 1 != len(q_func_layers):
            raise ValueError(
                f"q_function layers: {q_func_layers}; "
                f"specified encoder layers: {self._encoder_reset}"
            )
        # Check the output layer too when it is flagged for reset —
        # __call__ resets [*encoder_reset, output_reset], so validation
        # must cover the same set of layers.
        reset_flags = [*self._encoder_reset, self._output_reset]
        all_valid_layers.append(all(
            hasattr(layer, "reset_parameters")
            for flag, layer in zip(reset_flags, q_func_layers)
            if flag
        ))
    self._check = all(all_valid_layers)
    if not self._check:
        raise ValueError(
            "Some layers do not contain resettable parameters"
        )
def __call__(self, algo: QLearningAlgoBase, epoch: int, total_step: int) -> None:
    """Reset the flagged layers every ``replay_ratio`` epochs."""
    # NOTE(review): validate lazily when no algo was supplied at
    # construction time (guard reconstructed from the diff — the original
    # context lines around this call are collapsed in the source view).
    if not self._check:
        self._check_layer_resets(algo=algo)
    assert isinstance(algo._impl, QLearningAlgoImplBase)
    if epoch % self._replay_ratio == 0:
        # One flag per encoder layer, plus one for the output (_fc) layer.
        reset_flags = [*self._encoder_reset, self._output_reset]
        for q_func in algo._impl.q_function:
            for flag, layer in zip(reset_flags, self._get_layers(q_func)):
                if flag:
                    layer.reset_parameters()