Skip to content

Latest commit

 

History

History
96 lines (84 loc) · 2.34 KB

File metadata and controls

96 lines (84 loc) · 2.34 KB

lstm ppoc

Shapes

def rollout():
  for i in [states,self.prev_options,masks]:
    print(i.shape)
  # batch, feat_dim
  states (9, 8)
  prev_o torch.Size([9])
  masks torch.Size([1, 9])

  for k in prediction:
    print(prediction[k].shape)
  # batch, num_o, act_dim
  mean torch.Size([9, 4, 2])
  std torch.Size([9, 4, 2])
  q_o torch.Size([9, 4])
  pi_o torch.Size([9, 4])
  log_pi_o torch.Size([9, 4])
  beta torch.Size([9, 4])

def compute_adv():
  for i in [states,self.prev_options,storage.m[-1]]:
    print(i.shape)
  # batch, feat_dim
  states (9, 8)
  prev_o torch.Size([9])
  # TODO: axes need to be switched (masks is [9, 1] here but [1, 9] in rollout)
  masks torch.Size([9, 1])

  for k in prediction:
    print(k, prediction[k].shape)
  # batch, num_o, act_dim
  mean torch.Size([9, 4, 2])
  std torch.Size([9, 4, 2])
  q_o torch.Size([9, 4])
  pi_o torch.Size([9, 4])
  log_pi_o torch.Size([9, 4])
  beta torch.Size([9, 4])

def learn():
  for i in [states,prev_options,masks]:
    print(i.shape)
  # first dim = num_workers * rollout_length
  torch.Size([1800, 8])
  torch.Size([1800, 1])
  torch.Size([1800, 1])

  for k in prediction:
    print(k, prediction[k].shape)
  # batch, num_o, act_dim
  mean torch.Size([64, 4, 2])
  std torch.Size([64, 4, 2])
  q_o torch.Size([64, 4])
  pi_o torch.Size([64, 4])
  log_pi_o torch.Size([64, 4])
  beta torch.Size([64, 4])

Possible Bugs

Batch First

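The current implementation is not right: by default, nn.LSTM expects input shaped (seq_len, batch, input_size) rather than batch-first, as in the example below.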

From https://pytorch.org/docs/stable/nn.html:

>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))

Done Mask

  • During forward stage (rollout, train_step)

Differences from the non-recurrent version

agent

  • Initialize input_lstm_states
  • Reshape data from [batch, feat_dim] to [timesteps, batch, feat_dim]
  • Update input_lstm_states

Advantage

# From https://github.com/seungeunrho/minimalRL/blob/master/ppo-lstm.py
# at time t: second_hidden = lstm(s,first_hidden)
# at time t+1: s_prime = env.step(second_hidden)
v_prime = self.v(s_prime, second_hidden).squeeze(1)
td_target = r + gamma * v_prime * done_mask
v_s = self.v(s, first_hidden).squeeze(1)
delta = td_target - v_s