Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions c_advantage.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__global__ void advantage_kernel(
__global__ void p3o_kernel(
float* reward_block, // [num_steps, horizon]
float* reward_mask, // [num_steps, horizon]
float* values_mean, // [num_steps, horizon]
Expand Down Expand Up @@ -57,7 +57,7 @@ __global__ void advantage_kernel(
reward_mask[idx] = 1.0f;
}

float bootstrap = 0.0f;
//float bootstrap = 0.0f;
//if (k == horizon-1) {
// bootstrap = buf[i*horizon + horizon - 1]*values_mean[i*horizon + horizon - 1];
//}
Expand Down Expand Up @@ -85,3 +85,26 @@ __global__ void advantage_kernel(
advantages[i] = R;
bounds[i] = k;
}


__global__ void gae_kernel(
float* values, // [num_steps, horizon]
float* rewards, // [num_steps, horizon]
float* dones, // [num_steps, horizon]
float* advantages, // [num_steps, horizon]
float gamma,
float gae_lambda,
int num_steps,
int horizon
) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
float lastgaelam = 0;
for (int t = horizon-2; t >= 0; t--) {
int idx = row*horizon + t;
int idx_next = idx + 1;
float nextnonterminal = 1.0 - dones[idx_next];
float delta = rewards[idx_next] + gamma*values[idx_next]*nextnonterminal - values[idx];
lastgaelam = delta + gamma*gae_lambda*nextnonterminal * lastgaelam;
advantages[idx] = lastgaelam;
}
}
39 changes: 20 additions & 19 deletions c_advantage.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,26 @@ def fast_rewards_and_masks(float[:, :] reward_block, float[:, :] reward_mask,

memcpy(&reward_block[i, 0], &rewards[i+1], h * sizeof(float))

def compute_gae(cnp.ndarray dones, cnp.ndarray values,
cnp.ndarray rewards, float gamma, float gae_lambda):
def compute_gae(cnp.ndarray dones, float[:, :] values,
float[:, :] rewards, float gamma, float gae_lambda):
'''Fast Cython implementation of Generalized Advantage Estimation (GAE)'''
cdef int num_steps = len(rewards)
cdef cnp.ndarray advantages = np.zeros(num_steps, dtype=np.float32)
cdef float[:] c_advantages = advantages
cdef float[:] c_dones = dones
cdef float[:] c_values = values
cdef float[:] c_rewards = rewards

cdef float lastgaelam = 0
cdef float nextnonterminal, delta
cdef int t, t_cur, t_next
for t in range(num_steps-1):
t_cur = num_steps - 2 - t
t_next = num_steps - 1 - t
nextnonterminal = 1.0 - c_dones[t_next]
delta = c_rewards[t_next] + gamma * c_values[t_next] * nextnonterminal - c_values[t_cur]
lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
c_advantages[t_cur] = lastgaelam
cdef:
float[:, :] c_dones = dones
int num_rows = dones.shape[0]
int horizon = dones.shape[1]
float lastgaelam = 0
float nextnonterminal, delta
int t, t_cur, t_next
cnp.ndarray advantages = np.zeros((num_rows, horizon), dtype=np.float32)
float[:, :] c_advantages = advantages

for row in range(num_rows-1, -1, -1):
lastgaelam = 0
for t in range(horizon-2, -1, -1):
t_next = t + 1
nextnonterminal = 1.0 - c_dones[row, t_next]
delta = rewards[row, t_next] + gamma*values[row, t_next]*nextnonterminal - values[row, t]
lastgaelam = delta + gamma*gae_lambda*nextnonterminal * lastgaelam
c_advantages[row, t] = lastgaelam

return advantages
Loading