[WIP] Added BCQ #378

Draft pull request: wants to merge 20 commits into base: master (showing changes from all 20 commits).
4 changes: 2 additions & 2 deletions genrl/agents/__init__.py
@@ -15,6 +15,7 @@
NeuralNoiseSamplingAgent,
)
from genrl.agents.bandits.contextual.variational import VariationalAgent # noqa
from genrl.agents.bandits.multiarmed.base import MABAgent # noqa
from genrl.agents.bandits.multiarmed.bayesian import BayesianUCBMABAgent # noqa
from genrl.agents.bandits.multiarmed.bernoulli_mab import BernoulliMAB # noqa
from genrl.agents.bandits.multiarmed.epsgreedy import EpsGreedyMABAgent # noqa
@@ -41,5 +42,4 @@
from genrl.agents.deep.sac.sac import SAC # noqa
from genrl.agents.deep.td3.td3 import TD3 # noqa
from genrl.agents.deep.vpg.vpg import VPG # noqa

from genrl.agents.bandits.multiarmed.base import MABAgent # noqa; noqa; noqa
from genrl.agents.offline.bcq.bcq import BCQ # noqa
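With the re-export above, the new offline agent resolves from the package root like the existing deep agents. A minimal import sketch (nothing about the BCQ constructor is assumed here, since its implementation is not part of this diff):

```python
# Both import paths should resolve once this __init__.py change is in:
# the re-export from the package root, and the underlying module path.
from genrl.agents import BCQ, SAC, TD3  # noqa: F401  (re-exported names)
from genrl.agents.offline.bcq.bcq import BCQ as BCQImpl  # noqa: F401
```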
5 changes: 3 additions & 2 deletions genrl/agents/deep/base/base.py
@@ -17,8 +17,9 @@ class BaseAgent(ABC):
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
shared_layers(:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
seed (int): Seed for randomness
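For reference, the renamed docstring entries correspond to constructor keyword arguments. A hypothetical instantiation sketch, assuming the `VectorEnv` wrapper and kwargs forwarding used elsewhere in genrl (the exact values are placeholders):

```python
from genrl.agents import TD3
from genrl.environments import VectorEnv

env = VectorEnv("Pendulum-v0")  # any continuous-action env works here
agent = TD3(
    "mlp",
    env,
    policy_layers=(64, 64),   # hidden sizes of the actor network
    value_layers=(128, 128),  # hidden sizes of the critic network(s)
    lr_policy=1e-3,
    lr_value=1e-3,
)
```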
51 changes: 14 additions & 37 deletions genrl/agents/deep/base/offpolicy.py
@@ -5,11 +5,10 @@
from torch.nn import functional as F

from genrl.agents.deep.base import BaseAgent
from genrl.core import (

from genrl.core import ( # PrioritizedReplayBufferSamples,; ReplayBufferSamples,
PrioritizedBuffer,
PrioritizedReplayBufferSamples,
ReplayBuffer,
ReplayBufferSamples,
)


@@ -23,8 +22,9 @@ class OffPolicyAgent(BaseAgent):
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
shared_layers(:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
@@ -67,19 +67,6 @@ def update_target_model(self) -> None:
"""
raise NotImplementedError

def _reshape_batch(self, batch: List):
"""Function to reshape experiences

Can be modified for individual algorithm usage

Args:
batch (:obj:`list`): List of experiences that are being replayed

Returns:
batch (:obj:`list`): Reshaped experiences for replay
"""
return [*batch]

def sample_from_buffer(self, beta: float = None):
"""Samples experiences from the buffer and converts them into usable formats

@@ -95,18 +82,6 @@ def sample_from_buffer(self, beta: float = None):
else:
batch = self.replay_buffer.sample(self.batch_size)

states, actions, rewards, next_states, dones = self._reshape_batch(batch)

# Convert every experience to a Named Tuple. Either Replay or Prioritized Replay samples.
if isinstance(self.replay_buffer, ReplayBuffer):
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
elif isinstance(self.replay_buffer, PrioritizedBuffer):
indices, weights = batch[5], batch[6]
batch = PrioritizedReplayBufferSamples(
*[states, actions, rewards, next_states, dones, indices, weights]
)
else:
raise NotImplementedError
return batch

def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
@@ -136,8 +111,9 @@ class OffPolicyAgentAC(OffPolicyAgent):
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
shared_layers(:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
@@ -154,7 +130,7 @@ def __init__(self, *args, polyak=0.995, **kwargs):
self.doublecritic = False

def select_action(
self, state: torch.Tensor, deterministic: bool = True
self, state: torch.Tensor, deterministic: bool = True, noise: bool = True
) -> torch.Tensor:
"""Select action given state

@@ -163,6 +139,7 @@ def select_action(
Args:
state (:obj:`torch.Tensor`): Current state of the environment
deterministic (bool): Should the policy be deterministic or stochastic
noise (bool): Whether exploration noise should be added to the action

Returns:
action (:obj:`torch.Tensor`): Action taken by the agent
@@ -171,7 +148,7 @@ def select_action(
action = action.detach()

# add noise to output from policy network
if self.noise is not None:
if noise and self.noise is not None:
action += self.noise()

return torch.clamp(
@@ -210,7 +187,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten
def get_target_q_values(
self, next_states: torch.Tensor, rewards: List[float], dones: List[bool]
) -> torch.Tensor:
"""Get target Q values for the TD3
"""Get target Q values

Args:
next_states (:obj:`torch.Tensor`): Next states for which target Q-values
@@ -219,7 +196,7 @@ def get_target_q_values(
dones (:obj:`list`): Game over status for each environment

Returns:
target_q_values (:obj:`torch.Tensor`): Target Q values for the TD3
target_q_values (:obj:`torch.Tensor`): Target Q values
"""
next_target_actions = self.ac_target.get_action(next_states, True)[0]

@@ -265,7 +242,7 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
Returns:
loss (:obj:`torch.Tensor`): Calculated policy loss
"""
next_best_actions = self.ac.get_action(states, True)[0]
next_best_actions = self.select_action(states, deterministic=True, noise=False)
q_values = self.ac.get_value(torch.cat([states, next_best_actions], dim=-1))
policy_loss = -torch.mean(q_values)
return policy_loss
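The new `noise` argument lets `get_p_loss` evaluate the actor without exploration noise, while rollouts keep the old behaviour. A standalone sketch of the added control flow (the toy `policy` and `noise_fn` stand in for the agent's internals; this is not genrl's API):

```python
import torch


def select_action(policy, state, noise_fn=None, noise=True):
    """Mirror of the updated OffPolicyAgentAC.select_action flow (simplified)."""
    action = policy(state).detach()
    # Exploration noise is added only when a noise process exists AND the caller
    # asked for it; get_p_loss now calls select_action(..., noise=False).
    if noise and noise_fn is not None:
        action = action + noise_fn()
    return torch.clamp(action, -1.0, 1.0)


policy = torch.nn.Linear(3, 1)            # toy deterministic policy
noise_fn = lambda: 0.1 * torch.randn(1)   # toy exploration noise process
state = torch.randn(1, 3)

rollout_action = select_action(policy, state, noise_fn, noise=True)   # perturbed
loss_action = select_action(policy, state, noise_fn, noise=False)     # noise-free
```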
2 changes: 1 addition & 1 deletion genrl/agents/deep/base/onpolicy.py
@@ -36,7 +36,7 @@ def __init__(

if buffer_type == "rollout":
self.rollout = RolloutBuffer(
self.rollout_size, self.env, gae_lambda=gae_lambda
self.rollout_size, self.env, gae_lambda=gae_lambda, gamma=self.gamma
)
else:
raise NotImplementedError
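Passing `gamma=self.gamma` matters because the rollout buffer discounts rewards with its own `gamma` when computing returns and advantages; without it, the buffer silently falls back to its default. A toy illustration of the effect (plain Python, not genrl's buffer API):

```python
def discounted_return(rewards, gamma):
    """Backward pass over a trajectory: G_t = r_t + gamma * G_{t+1}."""
    ret = 0.0
    for r in reversed(rewards):
        ret = r + gamma * ret
    return ret


rewards = [1.0, 1.0, 1.0]
print(discounted_return(rewards, gamma=0.99))  # 2.9701  (the agent's gamma)
print(discounted_return(rewards, gamma=0.90))  # 2.71    (a mismatched default differs)
```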
4 changes: 2 additions & 2 deletions genrl/agents/deep/dqn/base.py
@@ -56,7 +56,7 @@ def __init__(
if self.create_model:
self._create_model()

def _create_model(self, *args, **kwargs) -> None:
def _create_model(self, **kwargs) -> None:
"""Function to initialize Q-value model

This will create the Q-value function of the agent.
@@ -153,7 +153,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten
q_values (:obj:`torch.Tensor`): Q values for the given states and actions
"""
q_values = self.model(states)
q_values = q_values.gather(2, actions)
q_values = q_values.gather(2, actions.unsqueeze(-1))
return q_values

def get_target_q_values(
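The `gather` fix is purely about tensor ranks: `torch.gather` needs the index tensor to have the same number of dimensions as the input, and the added `unsqueeze(-1)` implies the actions arrive as `[batch_size, n_envs]` index tensors. A shape sketch with toy tensors (sizes are placeholders):

```python
import torch

batch_size, n_envs, action_dim = 4, 2, 3
q_values = torch.randn(batch_size, n_envs, action_dim)     # stand-in for self.model(states)
actions = torch.randint(action_dim, (batch_size, n_envs))  # discrete action indices

# Add a trailing singleton dim so the index tensor is 3-D like q_values,
# then pick the Q-value of the chosen action along dim 2.
picked = q_values.gather(2, actions.unsqueeze(-1))
print(picked.shape)  # torch.Size([4, 2, 1])
```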
6 changes: 4 additions & 2 deletions genrl/agents/deep/dqn/utils.py
@@ -102,8 +102,10 @@ def categorical_q_values(agent: DQN, states: torch.Tensor, actions: torch.Tensor
# Size of q_value_dist should be [batch_size, n_envs, action_dim, num_atoms] here
# To gather the q_values of the respective actions, actions must be of the shape:
# [batch_size, n_envs, 1, num_atoms]. Its current shape is [batch_size, n_envs]
actions = actions.unsqueeze(-1).expand(
agent.batch_size, agent.env.n_envs, 1, agent.num_atoms
actions = (
actions.unsqueeze(-1)
.unsqueeze(-1)
.expand(agent.batch_size, agent.env.n_envs, 1, agent.num_atoms)
)
# Now as we gather q_values from the action_dim dimension which is at index 2
q_values = q_value_dist.gather(2, actions)
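The same idea applies to the categorical (distributional) head, except the indices must also cover the atom dimension: two singleton dims are added and then expanded across `num_atoms` before gathering along the action dimension. A shape sketch with toy sizes (the incoming `[batch_size, n_envs]` action shape is inferred from the double `unsqueeze`):

```python
import torch

batch_size, n_envs, action_dim, num_atoms = 4, 2, 3, 51
q_value_dist = torch.randn(batch_size, n_envs, action_dim, num_atoms)
actions = torch.randint(action_dim, (batch_size, n_envs))  # assumed incoming shape

# (4, 2) -> (4, 2, 1, 1) -> expand to (4, 2, 1, 51), then gather along dim 2.
idx = actions.unsqueeze(-1).unsqueeze(-1).expand(batch_size, n_envs, 1, num_atoms)
atom_values = q_value_dist.gather(2, idx)
print(atom_values.shape)  # torch.Size([4, 2, 1, 51])
```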
Empty file.
Empty file.