LSTM #147
base: main
Conversation
Nice that you're adding LSTM. However, I'm not confident that it works, because I don't know if `cell` is being used to calculate the new `hidden`. If you could point me to where that's happening in `ppo_lstm.py`, it might help me understand. Also, does this run on the parity tests? When I first developed the memory component, LSTM didn't automatically batch the first dimension, so you had to know the batch size beforehand when passing in some inputs. I'm not sure whether you found a workaround for that. My misunderstandings may also be due to my not having seen the code base for a while.
If you can show that this works on a few parity tests and that the `cell` part is being used for updating the state, and address those few minor comments, then I think we're good to go.
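For reference on the reviewer's concern, the standard LSTM update makes the dependence explicit: the new hidden is `h = o * tanh(c)`, so `cell` must feed the computation. A minimal NumPy sketch of one step (generic LSTM math, not this PR's Haiku code):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One step of a standard LSTM. W, U, b hold the stacked
    input/forget/candidate/output gate parameters (4*H rows)."""
    H = h_prev.shape[-1]
    z = W @ x + U @ h_prev + b      # pre-activations for all four gates
    i = sigmoid(z[:H])              # input gate
    f = sigmoid(z[H:2 * H])         # forget gate
    g = np.tanh(z[2 * H:3 * H])     # candidate cell
    o = sigmoid(z[3 * H:])          # output gate
    c = f * c_prev + i * g          # new cell state
    h = o * np.tanh(c)              # new hidden depends on the new cell
    return h, c
```

If `cell` were dropped, two calls differing only in `c_prev` would return the same hidden; with a correct LSTM they must differ.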
@@ -358,15 +374,27 @@ def forward_fn(
    inputs: jnp.ndarray, state: jnp.ndarray
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
    """forward function"""
    torso = hk.nets.MLP(
Why is this being removed?
behavior_values: jnp.ndarray
behavior_log_probs: jnp.ndarray

# GRU specific
Change this to "LSTM specific" or "Recurrent specific". Also, wouldn't this need `cell` as well?
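If the state container gains an LSTM option, it would presumably need a `cell` field alongside `hidden`. A hypothetical sketch (the `RecurrentState` name and fields are placeholders, not the repo's actual `TrainingState`):

```python
from typing import NamedTuple
import numpy as np

class RecurrentState(NamedTuple):
    """Hypothetical recurrent-state container for illustration only."""
    hidden: np.ndarray
    cell: np.ndarray  # meaningful for LSTM; could stay zero for GRU

def initial_state(hidden_size: int) -> RecurrentState:
    # Both components start as zeros with a leading batch dim of 1.
    return RecurrentState(
        hidden=np.zeros((1, hidden_size)),
        cell=np.zeros((1, hidden_size)),
    )
```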
seed=seed,
player_id=player_id,
)
else:
I would change this to an `elif args.ppo.rnn_type == "gru"`, then an `else` that raises an error. We wouldn't want any string other than `"lstm"` to set the rnn_type to GRU.
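The suggested dispatch might look like the sketch below (the constructor bodies are placeholders, not the repo's API; only the branching structure is the point):

```python
def make_network(rnn_type: str) -> str:
    """Dispatch on rnn_type; reject anything unrecognised instead of
    silently falling back to GRU. Return values are stand-ins for the
    real network constructors."""
    if rnn_type == "lstm":
        return "lstm-network"   # e.g. build the hk.LSTM variant here
    elif rnn_type == "gru":
        return "gru-network"    # e.g. build the hk.GRU variant here
    else:
        raise ValueError(
            f"Unknown rnn_type: {rnn_type!r} (expected 'lstm' or 'gru')"
        )
```

With this shape, a typo such as `"LSTM"` fails loudly at startup instead of training the wrong architecture.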
agent1: 'PPO'

# Environment
env_id: MountainCar-v0
The file is called `pendulum.yaml`, but it has `env_id: MountainCar-v0`. Am I missing something?
@@ -2,7 +2,7 @@

# Agents
agent1: 'PPO_memory'
-agent2: 'TitForTat'
+agent2: 'PPO_memory'
The file is called `ppo_mem_v_tft.yaml`, but it now has `agent2: PPO_memory`. Why was this changed?
key = jax.random.split(
    agent2._state.random_key, args.popsize * args.num_opps
).reshape(args.popsize, args.num_opps, -1)
if args.ppo.rnn_type == "lstm" and args.agent2 == "PPO_memory":
I need some help understanding this. If we want to use an LSTM, the initial hidden state is a Haiku `LSTMState` object that holds the `hidden` and `cell` states. And if we want a GRU, then the hidden state is in `jnp.tile(agent2._mem.hidden, (args.popsize, args.num_opps, 1, 1))`. Are these both `NamedTuple`s, and is `agent2.batch_init()` able to handle both of them?
Maybe I'm missing something that changed in how the agents and the agent methods are initialized, but I don't see any diffs for that file here.
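One way to reconcile the two cases: Haiku's `LSTMState` is a `NamedTuple` of `(hidden, cell)`, so tiling every leaf of the state handles both the bare GRU array and the LSTM pair. A NumPy sketch (the `LSTMState` here is declared locally to mirror Haiku's; in JAX the same effect comes from `jax.tree_util.tree_map` with `jnp.tile` over the state pytree):

```python
from typing import NamedTuple
import numpy as np

class LSTMState(NamedTuple):
    """Local stand-in for Haiku's LSTMState NamedTuple."""
    hidden: np.ndarray
    cell: np.ndarray

def tile_state(state, popsize: int, num_opps: int):
    """Tile either a bare GRU hidden array or an LSTMState to
    shape (popsize, num_opps, batch, hidden_size)."""
    tile = lambda x: np.tile(x, (popsize, num_opps, 1, 1))
    if isinstance(state, LSTMState):
        # Tile each leaf, keeping the NamedTuple structure intact.
        return LSTMState(tile(state.hidden), tile(state.cell))
    return tile(state)  # GRU case: a plain array
```

Whether `agent2.batch_init()` already treats its argument as a pytree (and so handles both) is exactly the question the reviewer is raising.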
hiddens: jnp.ndarray,
):
    """Surrogate loss using clipped probability ratios."""
    (distribution, values), _ = network.apply(
Since we are now using an LSTM, is it the case that `network.apply()` now requires both `hidden` and `cell`?
No, it needs the `LSTMHidden`.
initial_hidden_state=initial_hidden_state,
optimizer=optimizer,
random_key=random_key,
gru_dim=args.ppo.hidden_size,
In the PPO file, you could now change this input to something more general, such as `recurrent_dim`, rather than `gru_dim`.
@newtonkwan - can you pick this PR up and get it into main before the release?
Adding an LSTMAgent.
This adds an LSTM option to the PPO agents.
Changes to Core Features: