Skip to content

Commit

Permalink
Move stacking from agent to env (#294)
Browse files (browse the repository at this point in the history)
* Move stacking from Agent to Env

* Add docstring

* Simplify preprocessor
  • Loading branch information
mot0 authored and mthrok committed Mar 21, 2017
1 parent b9ad897 commit a69773f
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 230 deletions.
27 changes: 7 additions & 20 deletions luchador/agent/dqn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -167,8 +167,6 @@ def __init__(
self._saver = None
self._ql = None
self._eg = None
self._stack_buffer = None
self._previous_stack = None
self._summary_writer = None
self._summary_values = {
'errors': [],
Expand Down Expand Up @@ -202,9 +200,7 @@ def _init_network(self, n_actions):

###########################################################################
# Methods for `reset`
def reset(self, initial_observation):
self._stack_buffer = [initial_observation[0]]
self._previous_stack = None
def reset(self, _):
self._ready = False

###########################################################################
Expand All @@ -227,26 +223,17 @@ def _predict_q(self):
# Methods for `learn`
def learn(self, state0, action, reward, state1, terminal, info=None):
self._n_obs += 1
self._record(action, reward, state1, terminal)
self._record(state0, action, reward, state1, terminal)
self._train()

def _record(self, action, reward, state1, terminal):
def _record(self, state0, action, reward, state1, terminal):
"""Stack states and push them to recorder, then sort memory"""
self._stack_buffer.append(state1[0])
self._recorder.push(1, {
'state0': state0, 'action': action, 'reward': reward,
'state1': state1, 'terminal': terminal})
self._ready = True

cfg = self.args['record_config']
if len(self._stack_buffer) == cfg['stack'] + 1:
if self._previous_stack is None:
self._previous_stack = np.array(self._stack_buffer[:-1])
state0_ = self._previous_stack
state1_ = np.array(self._stack_buffer[1:])
self._recorder.push(1, {
'state0': state0_, 'action': action, 'reward': reward,
'state1': state1_, 'terminal': terminal})
self._stack_buffer = self._stack_buffer[1:]
self._previous_stack = state1_
self._ready = True

sort_freq = cfg['sort_frequency']
if sort_freq > 0 and self._n_obs % sort_freq == 0:
_LG.info('Sorting Memory')
Expand Down
Loading

0 comments on commit a69773f

Please sign in to comment.