-from typing import List
+from collections import deque
+
+from beartype import beartype
+
+from jaxtyping import Float, Bool, Int, Int64, jaxtyped

 import torch


 class ExperienceBuffer:
-    def __init__(self, num_workers: int, device: torch.device) -> None:
+    def __init__(self, num_workers: int, device: torch.device, size: int | None = None) -> None:
         self.num_workers = num_workers

-        self.states: List[torch.Tensor] = []
-        self.actions: List[torch.Tensor] = []
-        self.prev_actions: List[torch.Tensor] = [
-            torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=device)
-        ]
-        self.values: List[torch.Tensor] = []
-        self.rewards: List[torch.Tensor] = []
-        self.dones: List[torch.Tensor] = []
-        self.log_probs: List[torch.Tensor] = []
+        self.states: deque[Float[torch.Tensor, "worker channels height width"]] = deque(maxlen=size)
+        self.actions: deque[Int64[torch.Tensor, "worker"]] = deque(maxlen=size)
+        # One extra slot for the action taken before the first buffered step; unbounded if size is None.
+        self.prev_actions: deque[Int64[torch.Tensor, "worker"]] = deque(maxlen=size + 1 if size is not None else None)
+        self.prev_actions.append(torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=device))
+        self.values: deque[Float[torch.Tensor, "worker"]] = deque(maxlen=size)
+        self.rewards: deque[Float[torch.Tensor, "worker"]] = deque(maxlen=size)
+        self.dones: deque[Int64[torch.Tensor, "worker"]] = deque(maxlen=size)
+        self.log_probs: deque[Float[torch.Tensor, "worker"]] = deque(maxlen=size)

         self.device = device

-    def buffer_states(self, states: torch.Tensor) -> None:
-        assert states.dim() == 4
-        assert states.size(0) == self.num_workers
+    def buffer_states(self, states: Float[torch.Tensor, "worker channels height width"]) -> None:
         self.states.append(states.to(torch.float32))

-    def buffer_actions(self, actions: torch.Tensor) -> None:
-        assert actions.dim() == 1
-        assert actions.size(0) == self.num_workers
+    def buffer_actions(self, actions: Int[torch.Tensor, "worker"]) -> None:
         actions_cast = actions.to(torch.int64)
         self.prev_actions.append(actions_cast)
         self.actions.append(actions_cast)

-    def buffer_values(self, values: torch.Tensor) -> None:
-        assert values.dim() == 1
-        assert values.size(0) == self.num_workers
+    def buffer_values(self, values: Float[torch.Tensor, "worker"]) -> None:
         self.values.append(values.to(torch.float32))

-    def buffer_log_probs(self, log_probs: torch.Tensor) -> None:
-        assert log_probs.dim() == 1
-        assert log_probs.size(0) == self.num_workers
+    def buffer_log_probs(self, log_probs: Float[torch.Tensor, "worker"]) -> None:
         self.log_probs.append(log_probs.to(torch.float32))

-    def buffer_rewards(self, rewards: torch.Tensor) -> None:
-        assert rewards.dim() == 1
-        assert rewards.size(0) == self.num_workers
+    def buffer_rewards(self, rewards: Float[torch.Tensor, "worker"]) -> None:
         self.rewards.append(rewards.to(torch.float32))

-    def buffer_dones(self, dones: torch.Tensor) -> None:
-        assert dones.dim() == 1
-        assert dones.size(0) == self.num_workers
+    def buffer_dones(self, dones: Bool[torch.Tensor, "worker"]) -> None:
         dones_cast = dones.to(torch.int64)
-        self.prev_actions.append(torch.multiply(self.prev_actions.pop(-1), 1 - dones_cast))
+        # Zero out the stored previous action for workers whose episode just ended.
+        self.prev_actions.append(torch.multiply(self.prev_actions.pop(), 1 - dones_cast))
         self.dones.append(dones_cast)

     def reset(self, forget_prev_action: bool = False) -> None:
-        self.states = []
+        self.states.clear()
         if forget_prev_action:
-            self.prev_actions = [torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=self.device)]
+            self.prev_actions.clear()
+            self.prev_actions.append(torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=self.device))
         else:
-            self.prev_actions = self.prev_actions[-1:]
-        self.dones = []
-        self.actions = []
-        self.values = []
-        self.rewards = []
-        self.log_probs = []
-
-    def get_last_states(self) -> torch.Tensor:
+            final_action = self.prev_actions[-1]
+            self.prev_actions.clear()
+            self.prev_actions.append(final_action)
+        self.dones.clear()
+        self.actions.clear()
+        self.values.clear()
+        self.rewards.clear()
+        self.log_probs.clear()
+
+    def get_last_states(self) -> Float[torch.Tensor, "worker channels height width"]:
         return self.states[-1]

-    def get_last_actions(self) -> torch.Tensor:
+    def get_last_actions(self) -> Int64[torch.Tensor, "worker"]:
         return self.prev_actions[-1]

-    def get_state_buffer(self) -> torch.Tensor:
-        return torch.stack(self.states)
+    def get_state_buffer(self) -> Float[torch.Tensor, "buffer worker channels height width"]:
+        return torch.stack(list(self.states))

-    def get_action_buffer(self) -> torch.Tensor:
-        return torch.stack(self.actions)
+    def get_action_buffer(self) -> Int64[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.actions))

-    def get_prev_action_buffer(self) -> torch.Tensor:
-        return torch.stack(self.prev_actions[:-1])
+    def get_prev_action_buffer(self) -> Int64[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.prev_actions)[:-1])

-    def get_value_buffer(self) -> torch.Tensor:
-        return torch.stack(self.values)
+    def get_value_buffer(self) -> Float[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.values))

-    def get_log_prob_buffer(self) -> torch.Tensor:
-        return torch.stack(self.log_probs)
+    def get_log_prob_buffer(self) -> Float[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.log_probs))

-    def get_reward_buffer(self) -> torch.Tensor:
-        return torch.stack(self.rewards)
+    def get_reward_buffer(self) -> Float[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.rewards))

-    def get_dones_buffer(self) -> torch.Tensor:
-        return torch.stack(self.dones)
+    def get_dones_buffer(self) -> Int64[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.dones))
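As a quick sanity check of the deque-backed API, here is a minimal usage sketch. The dummy tensors stand in for real observations and policy outputs; the channel count, image size, and action count are arbitrary assumptions, not values from this repository.

import torch

num_workers, rollout_len = 4, 16
device = torch.device("cpu")
buffer = ExperienceBuffer(num_workers=num_workers, device=device, size=rollout_len)

for _ in range(rollout_len):
    # Dummy data in place of real observations and policy outputs (assumed shapes).
    states = torch.zeros(num_workers, 4, 84, 84)      # (worker, channels, height, width)
    actions = torch.randint(0, 6, (num_workers,))     # one discrete action per worker
    buffer.buffer_states(states)
    buffer.buffer_actions(actions)
    buffer.buffer_values(torch.zeros(num_workers))
    buffer.buffer_log_probs(torch.zeros(num_workers))
    buffer.buffer_rewards(torch.zeros(num_workers))
    buffer.buffer_dones(torch.zeros(num_workers, dtype=torch.bool))

print(buffer.get_state_buffer().shape)        # torch.Size([16, 4, 4, 84, 84]) -> (buffer, worker, C, H, W)
print(buffer.get_prev_action_buffer().shape)  # torch.Size([16, 4]) -> (buffer, worker)
buffer.reset()

Note that the jaxtyping annotations above are plain type hints; they are only enforced at runtime if the methods are wrapped with a checker such as the imported beartype/jaxtyped decorators.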