Skip to content

Commit

Permalink
Fixes #1 by not modifying the param groups of the optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
fedorsc committed Jan 18, 2021
1 parent 8080ddd commit 4d60ce6
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 11 deletions.
6 changes: 3 additions & 3 deletions doc/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ To look into the context and policy of the last time step you can do:
.. code-block:: python
>>> print(context) # doctest: +ELLIPSIS
tensor([[[8.3..., 8.6...]]], requires_grad=True)
tensor([[[7.8..., 9.1...]]], requires_grad=True)
>>> print(policy) # doctest: +ELLIPSIS
tensor([[[-2..., -2..., 5..., 5...]],
tensor([[[ 6..., -7..., -6..., 7...]],
...
[[-2..., -1..., 4..., 5...]]], grad_fn=<CloneBackward>)
... [[ 4..., -7..., -6..., 7...]]], grad_fn=<CloneBackward>)
18 changes: 10 additions & 8 deletions reprise/context_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def __init__(
self._model_inputs = []
self._observations = []

self._model_state = initial_model_state
for s in self._opt_accessor(self._model_state):
s.requires_grad_()

assert (len(self._opt_accessor(self._model_state)) ==
len(self._optimizer.param_groups[1]['params']))

def predict(self, state):
"""
Predict from the past.
Expand Down Expand Up @@ -149,13 +156,6 @@ def infer_contexts(self, model_input, observation):
self._observations.append(observation)
self._observations = self._observations[-self._inference_length:]

for _ in self._opt_accessor(self._model_state):
self._optimizer.param_groups[1]['params'].pop(-1)

for o in self._opt_accessor(self._model_state):
o.requires_grad_()
self._optimizer.param_groups[1]['params'].append(o)

# Perform context inference cycles
for _ in range(self._inference_cycles):
self._optimizer.zero_grad()
Expand All @@ -177,6 +177,8 @@ def infer_contexts(self, model_input, observation):
# the final output and state to be returned
with torch.no_grad():
outputs, states = self.predict(self._model_state)
self._model_state = states[0]
for i in range(len(self._model_state)):
for j in range(len(self._model_state[i])):
self._model_state[i][j].data = states[0][i][j].data

return self._context, outputs, states
Binary file modified tests/references/test_reprise_actions.npy
Binary file not shown.
Binary file modified tests/references/test_reprise_contexts.npy
Binary file not shown.

0 comments on commit 4d60ce6

Please sign in to comment.