Commit 8e49ebf
Add sampling strategies and longer experience buffer
1 parent 3c04726 · commit 8e49ebf
17 files changed: +228 -128 lines
(learning-rate config; file path not shown in this view)

@@ -1,4 +1,4 @@
 _target_: config.learning_rate.LinearDecayLRConfig
 
-learning_rate: 0.000723
+learning_rate: 0.0007
 decay_iters: 1500
configs/agent/model/actor_critic.yaml (+3 -5)

@@ -1,8 +1,6 @@
 _target_: config.model.ActorCriticConfig
 
 action_embedding_size: 128
-num_layers: 4
-num_filters: 128
-kernel_size: 4
-stride: 2
-fc_size: 260
+num_filters: [128, 64, 64, 32]
+kernel_sizes: [9, 5, 5, 3]
+fc_size: 256
configs/agent/ppo.yaml (+7 -5)

@@ -1,17 +1,19 @@
 defaults:
   - model: actor_critic
   - learning_rate: linear_decay
+  - sampling_strategy: dist_sample
   - _self_
 
 _target_: config.agent.PPOAgentConfig
 
-gamma: 0.903
-tau: 0.854
+gamma: 0.9
+tau: 0.85
+exp_buffer_size: 512
 epochs_per_update: 5
 total_updates: 0
-batch_size: 256
+batch_size: 512
 clip_param: 0.2
 clip_value: False
-critic_loss_weight: 0.818
-max_entropy_loss_weight: 0.00118
+critic_loss_weight: 0.8
+max_entropy_loss_weight: 0.001
 grad_clip_norm: 0.5
(three new sampling-strategy config files; paths not shown in this view)

@@ -0,0 +1,5 @@
+_target_: config.sampling_strategy.AdaptiveEpsilonGreedySamplingStrategy
+
+initial_epsilon: 0.5
+min_epsilon: 0.01
+decay_factor: 0.99

@@ -0,0 +1 @@
+_target_: config.sampling_strategy.DistSamplingStrategy

@@ -0,0 +1 @@
+_target_: config.sampling_strategy.GreedySamplingStrategy
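These configs only name the _target_ classes; the strategy objects live in config.sampling_strategy and are what the agent calls via sample_action(probs) and update(...) later in this commit. Below is a minimal sketch of one plausible shape for those classes, not the repository's actual implementation: the field names come from the YAML above, while the base class, the method bodies, and the way the adaptive strategy decays epsilon are assumptions.

# Hypothetical sketch of config.sampling_strategy -- NOT the repository code.
from dataclasses import dataclass

import torch


@dataclass
class SamplingStrategy:
    def sample_action(self, probs: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def update(self, metric: float) -> None:
        # Hook the agent calls once per update; most strategies can ignore it.
        pass


@dataclass
class DistSamplingStrategy(SamplingStrategy):
    # Sample from the categorical distribution defined by the policy probabilities.
    def sample_action(self, probs: torch.Tensor) -> torch.Tensor:
        return torch.distributions.Categorical(probs=probs).sample()


@dataclass
class GreedySamplingStrategy(SamplingStrategy):
    # Always take the most probable action.
    def sample_action(self, probs: torch.Tensor) -> torch.Tensor:
        return torch.argmax(probs, dim=-1)


@dataclass
class AdaptiveEpsilonGreedySamplingStrategy(SamplingStrategy):
    # Epsilon-greedy with an epsilon that shrinks over training.
    initial_epsilon: float = 0.5
    min_epsilon: float = 0.01
    decay_factor: float = 0.99

    def __post_init__(self) -> None:
        self.epsilon = self.initial_epsilon

    def sample_action(self, probs: torch.Tensor) -> torch.Tensor:
        greedy = torch.argmax(probs, dim=-1)
        random_actions = torch.randint(probs.size(-1), greedy.shape, device=probs.device)
        explore = torch.rand(greedy.shape, device=probs.device) < self.epsilon
        return torch.where(explore, random_actions, greedy)

    def update(self, metric: float) -> None:
        # Assumed behaviour: decay epsilon on each agent update, down to min_epsilon.
        self.epsilon = max(self.min_epsilon, self.epsilon * self.decay_factor)

With a shape like this, the agent's later line self.sampling_strategy = self.config.sampling_strategy works because Hydra has already turned the _target_ entry into a live object.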

configs/environment/base.yaml (+1 -1)

@@ -3,7 +3,7 @@ _target_: config.environment.EnvironmentConfig
 env_name: SuperMarioBros-v0
 complex_movement: True
 num_repeat_frames: 4
-num_stack_frames: 3
+num_stack_frames: 2
 clip_top: 0
 clip_bot: 0
 clip_left: 0

configs/neurio_config.yaml (+21 -23)

@@ -6,11 +6,11 @@ defaults:
   - _self_
 
 _target_: config.main_config.NeurioConfig
-level: "1-3"
+level: "1-1"
 num_workers: 32
-num_iters: 3000
+num_iters: 500
 steps_per_iter: 128
-save_frequency: 100
+save_frequency: 50
 render: False
 
 agent:
@@ -19,23 +19,21 @@ agent:
   learning_rate:
     decay_iters: ${num_iters}
 
-hydra:
-  sweeper:
-    # ToDo: Add pruner as soon as it's available in the optuna sweeper
-    sampler:
-      n_startup_trials: 10
-    direction: maximize
-    study_name: Neurio-lev-${level}
-    storage: sqlite:///optuna_studies/${hydra.sweeper.study_name}.db
-    n_trials: 1
-    n_jobs: 1
-
-    params:
-      agent.learning_rate.learning_rate: tag(log, interval(0.0002, 0.0009))
-      agent.model.num_filters: choice(128, 256)
-      # agent.model.kernel_size: range(3, 4)
-      # agent.model.fc_size: range(100, 700)
-      # agent.critic_loss_weight: interval(0.5, 1.5)
-      agent.max_entropy_loss_weight: tag(log, interval(0.0001, 0.01))
-      agent.gamma: interval(0.8, 1.0)
-      agent.tau: interval(0.8, 1.0)
+#hydra:
+#  sweeper:
+#    # ToDo: Add pruner as soon as it's available in the optuna sweeper
+#    sampler:
+#      n_startup_trials: 10
+#    direction: maximize
+#    study_name: Neurio-lev-${level}
+#    storage: sqlite:///optuna_studies/${hydra.sweeper.study_name}.db
+#    n_trials: 1
+#    n_jobs: 1
+#
+#    params:
+#      agent.learning_rate.learning_rate: tag(log, interval(0.0002, 0.0009))
+#      agent.exp_buffer_size: choice(128, 256)
+#      agent.batch_size: choice(64, 128, 256, 512)
+#      agent.max_entropy_loss_weight: tag(log, interval(0.0001, 0.01))
+#      agent.gamma: interval(0.8, 1.0)
+#      agent.tau: interval(0.8, 1.0)

src/agents/experience_buffer.py (+49 -55)

@@ -1,93 +1,87 @@
-from typing import List
+from collections import deque
+
+from beartype import beartype
+
+from jaxtyping import Float, Bool, Int, Int64, jaxtyped
 
 import torch
 
 
 class ExperienceBuffer:
-    def __init__(self, num_workers: int, device: torch.device) -> None:
+    def __init__(self, num_workers: int, device: torch.device, size: int | None = None) -> None:
         self.num_workers = num_workers
 
-        self.states: List[torch.Tensor] = []
-        self.actions: List[torch.Tensor] = []
-        self.prev_actions: List[torch.Tensor] = [
-            torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=device)
-        ]
-        self.values: List[torch.Tensor] = []
-        self.rewards: List[torch.Tensor] = []
-        self.dones: List[torch.Tensor] = []
-        self.log_probs: List[torch.Tensor] = []
+        self.states: deque[Float[torch.Tensor, "worker channels height width"]] = deque(maxlen=size)
+        self.actions: deque[Int64[torch.Tensor, "worker"]] = deque(maxlen=size)
+        self.prev_actions: deque[Int64[torch.Tensor, "worker"]] = deque(maxlen=size + 1)
+        self.prev_actions.append(torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=device))
+        self.values: deque[Float[torch.Tensor, "worker"]] = deque(maxlen=size)
+        self.rewards: deque[Float[torch.Tensor, "worker"]] = deque(maxlen=size)
+        self.dones: deque[Int64[torch.Tensor, "worker"]] = deque(maxlen=size)
+        self.log_probs: deque[Float[torch.Tensor, "worker"]] = deque(maxlen=size)
 
         self.device = device
 
-    def buffer_states(self, states: torch.Tensor) -> None:
-        assert states.dim() == 4
-        assert states.size(0) == self.num_workers
+    def buffer_states(self, states: Float[torch.Tensor, "worker channels height width"]) -> None:
         self.states.append(states.to(torch.float32))
 
-    def buffer_actions(self, actions: torch.Tensor) -> None:
-        assert actions.dim() == 1
-        assert actions.size(0) == self.num_workers
+    def buffer_actions(self, actions: Int[torch.Tensor, "worker"]) -> None:
         actions_cast = actions.to(torch.int64)
         self.prev_actions.append(actions_cast)
         self.actions.append(actions_cast)
 
-    def buffer_values(self, values: torch.Tensor) -> None:
-        assert values.dim() == 1
-        assert values.size(0) == self.num_workers
+    def buffer_values(self, values: Float[torch.Tensor, "worker"]) -> None:
         self.values.append(values.to(torch.float32))
 
-    def buffer_log_probs(self, log_probs: torch.Tensor) -> None:
-        assert log_probs.dim() == 1
-        assert log_probs.size(0) == self.num_workers
+    def buffer_log_probs(self, log_probs: Float[torch.Tensor, "worker"]) -> None:
        self.log_probs.append(log_probs.to(torch.float32))
 
-    def buffer_rewards(self, rewards: torch.Tensor) -> None:
-        assert rewards.dim() == 1
-        assert rewards.size(0) == self.num_workers
+    def buffer_rewards(self, rewards: Float[torch.Tensor, "worker"]) -> None:
         self.rewards.append(rewards.to(torch.float32))
 
-    def buffer_dones(self, dones: torch.Tensor) -> None:
-        assert dones.dim() == 1
-        assert dones.size(0) == self.num_workers
+    def buffer_dones(self, dones: Bool[torch.Tensor, "worker"]) -> None:
         dones_cast = dones.to(torch.int64)
-        self.prev_actions.append(torch.multiply(self.prev_actions.pop(-1), 1 - dones_cast))
+        self.prev_actions.append(torch.multiply(self.prev_actions.pop(), 1 - dones_cast))
         self.dones.append(dones_cast)
 
     def reset(self, forget_prev_action: bool = False) -> None:
-        self.states = []
+        self.states.clear()
         if forget_prev_action:
-            self.prev_actions = [torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=self.device)]
+            self.prev_actions.clear()
+            self.prev_actions.append(torch.zeros(size=(self.num_workers,), dtype=torch.int64, device=self.device))
         else:
-            self.prev_actions = self.prev_actions[-1:]
-        self.dones = []
-        self.actions = []
-        self.values = []
-        self.rewards = []
-        self.log_probs = []
-
-    def get_last_states(self) -> torch.Tensor:
+            final_action = self.prev_actions[-1]
+            self.prev_actions.clear()
+            self.prev_actions.append(final_action)
+        self.dones.clear()
+        self.actions.clear()
+        self.values.clear()
+        self.rewards.clear()
+        self.log_probs.clear()
+
+    def get_last_states(self) -> Float[torch.Tensor, "worker channels height width"]:
         return self.states[-1]
 
-    def get_last_actions(self) -> torch.Tensor:
+    def get_last_actions(self) -> Float[torch.Tensor, "worker"]:
         return self.prev_actions[-1]
 
-    def get_state_buffer(self) -> torch.Tensor:
-        return torch.stack(self.states)
+    def get_state_buffer(self) -> Float[torch.Tensor, "buffer worker channels height width"]:
+        return torch.stack(list(self.states))
 
-    def get_action_buffer(self) -> torch.Tensor:
-        return torch.stack(self.actions)
+    def get_action_buffer(self) -> Int64[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.actions))
 
-    def get_prev_action_buffer(self) -> torch.Tensor:
-        return torch.stack(self.prev_actions[:-1])
+    def get_prev_action_buffer(self) -> Int64[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.prev_actions)[:-1])
 
-    def get_value_buffer(self) -> torch.Tensor:
-        return torch.stack(self.values)
+    def get_value_buffer(self) -> Float[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.values))
 
-    def get_log_prob_buffer(self) -> torch.Tensor:
-        return torch.stack(self.log_probs)
+    def get_log_prob_buffer(self) -> Float[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.log_probs))
 
-    def get_reward_buffer(self) -> torch.Tensor:
-        return torch.stack(self.rewards)
+    def get_reward_buffer(self) -> Float[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.rewards))
 
-    def get_dones_buffer(self) -> torch.Tensor:
-        return torch.stack(self.dones)
+    def get_dones_buffer(self) -> Int64[torch.Tensor, "buffer worker"]:
+        return torch.stack(list(self.dones))
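The switch from plain lists to deque(maxlen=...) is what makes the buffer a rolling window: once exp_buffer_size entries are stored, appending a new step silently evicts the oldest one. A small standalone illustration (not repository code):

from collections import deque

import torch

# Standalone illustration of the deque(maxlen=...) behaviour the new buffer relies on.
size = 3
states = deque(maxlen=size)
for step in range(5):
    states.append(torch.full((2,), float(step)))  # pretend two workers per step

print(len(states))          # 3 -- only the most recent `size` entries remain
print(states[0][0].item())  # 2.0 -- steps 0 and 1 were evicted

# prev_actions uses maxlen=size + 1 so the action preceding the oldest buffered state is kept.
# Note this assumes `size` is an int: with the default size=None, `size + 1` raises a TypeError.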

src/agents/ppo_agent.py (+24 -13)

@@ -51,10 +51,14 @@ def __init__(
         device_name = torch.cuda.get_device_name(self.cpu)
         log.info(f"Using CPU {device_name}")
 
-        self.experience_buffer = ExperienceBuffer(self.num_workers, device=self.device)
+        self.experience_buffer = ExperienceBuffer(
+            self.num_workers, size=self.config.exp_buffer_size, device=self.device
+        )
 
         self.actor_critic = get_model(config=self.config.model, env_info=self.env_info).to(self.device)
 
+        self.sampling_strategy = self.config.sampling_strategy
+
         mlflow.log_text(
             str(
                 torchinfo.summary(
@@ -106,10 +110,10 @@ def _compute_probs_values(self) -> Tuple[torch.Tensor, torch.Tensor]:
     def next_actions(self, train: bool = True) -> Tuple[List[int], List[float]]:
         with torch.no_grad():
             probs, values = self._compute_probs_values()
+            actions = self.sampling_strategy.sample_action(probs)
 
-            action_dist = torch.distributions.Categorical(probs=probs)
-            actions = action_dist.sample()
-            log_probs = action_dist.log_prob(actions)
+            eps = torch.finfo(probs.dtype).eps
+            log_probs = torch.log(probs.clamp(min=eps, max=1 - eps)).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
             self.experience_buffer.buffer_values(values)
             self.experience_buffer.buffer_log_probs(log_probs)
             self.experience_buffer.buffer_actions(actions)
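With the Categorical distribution gone from next_actions, the log-probabilities are now read off directly: clamp the probabilities away from 0 and 1, take the log, and gather each worker's chosen action. A quick standalone check with toy numbers (not repository code) that this matches Categorical.log_prob up to the clamp:

import torch

# Toy check: gathering log(probs) at the sampled action indices reproduces
# Categorical.log_prob, as long as the probabilities stay away from 0 and 1.
probs = torch.tensor([[0.1, 0.6, 0.3],
                      [0.5, 0.25, 0.25]])  # [worker, action]
actions = torch.tensor([1, 0])

eps = torch.finfo(probs.dtype).eps
gathered = torch.log(probs.clamp(min=eps, max=1 - eps)).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
reference = torch.distributions.Categorical(probs=probs).log_prob(actions)

print(torch.allclose(gathered, reference))  # True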
@@ -236,10 +240,10 @@ def _current_entropy_loss_weight(self) -> float:
 
     def update(self) -> None:
         losses = {
-            "actor": 0.0,
-            "critic": 0.0,
-            "entropy": 0.0,
-            "total": 0.0,
+            "actor_loss": 0.0,
+            "critic_loss": 0.0,
+            "entropy_loss": 0.0,
+            "total_loss": 0.0,
         }
 
         dataset = self._create_dataset_from_buffers()
@@ -261,19 +265,25 @@ def update(self) -> None:
 
                 # Actor loss
                 act_loss, action_dist = self._calculate_actor_loss(probs, b_actions, b_log_probs, b_advantages)
-                losses["actor"] += act_loss.item()
+                losses["actor_loss"] += act_loss.item()
 
                 # Critic loss
                 crit_loss = self._calculate_critic_loss(v, b_returns, b_values)
-                losses["critic"] += crit_loss.item()
+                losses["critic_loss"] += crit_loss.item()
 
                 # Entropy loss
                 entropy = torch.mean(action_dist.entropy())
-                losses["entropy"] += entropy.item()
+                losses["entropy_loss"] += entropy.item()
 
                 # Total
-                loss = act_loss + self.critic_loss_weight * crit_loss - self._current_entropy_loss_weight() * entropy
-                losses["total"] += loss.item()
+                loss = torch.add(
+                    act_loss,
+                    torch.sub(
+                        torch.mul(self.critic_loss_weight, crit_loss),
+                        torch.mul(self._current_entropy_loss_weight(), entropy),
+                    ),
+                )
+                losses["total_loss"] += loss.item()
                 total_epoch_loss += loss.item()
                 total_epoch_batches += 1
 
@@ -287,6 +297,7 @@ def update(self) -> None:
             losses[key] /= len(dataloader) * self.epochs_per_update
         log.debug(f"Update finished. Losses: {losses}")
         self.scheduler.step()
+        self.sampling_strategy.update(-losses["total_loss"])
 
         mlflow.log_metrics(losses, self.update_step)
         self.update_step += 1

src/config/agent.py (+4)

@@ -2,6 +2,7 @@
 
 from .learning_rate import LRConfig
 from .model import ModelConfig
+from .sampling_strategy import SamplingStrategy
 
 
 @dataclass
@@ -18,8 +19,10 @@ class RandomAgentConfig(AgentConfig):
 class PPOAgentConfig(AgentConfig):
     model: ModelConfig
     learning_rate: LRConfig
+    sampling_strategy: SamplingStrategy
     gamma: float
     tau: float
+    exp_buffer_size: int
     epochs_per_update: int
     total_updates: int
     batch_size: int
@@ -32,6 +35,7 @@ class PPOAgentConfig(AgentConfig):
     def __post_init__(self) -> None:
         assert 0 <= self.gamma <= 1
         assert 0 <= self.tau <= 1
+        assert self.exp_buffer_size >= 1
         assert self.epochs_per_update >= 1
         assert self.total_updates >= 1
         assert self.batch_size >= 1