diff --git a/rlmeta/agents/dqn/apex_dqn_agent.py b/rlmeta/agents/dqn/apex_dqn_agent.py index b2bb7f8..5b86a7b 100644 --- a/rlmeta/agents/dqn/apex_dqn_agent.py +++ b/rlmeta/agents/dqn/apex_dqn_agent.py @@ -182,16 +182,20 @@ def make_replay(self) -> Optional[List[NestedTensor]]: replay = [] append = replay.append - for i in range(0, trajectory_len - self.multi_step): + # The last entry in the trajectory is an observation + done + for i in range(trajectory_len-1): cur = self.trajectory[i] - nxt = self.trajectory[i + self.multi_step] obs = cur["obs"] act = cur["action"] - next_obs = nxt["obs"] - done = nxt["done"] reward = 0.0 for j in range(self.multi_step): + if i + j == trajectory_len: + break reward += (self.gamma**j) * self.trajectory[i + j]["reward"] + # This call works because we never operate on the last observation + if i + j + 1 < trajectory_len: + next_obs = self.trajectory[i+j+1]["obs"] + done = self.trajectory[i+j+1]["done"] append({ "obs": obs, "action": act,