Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pufferlib/ocean/cartpole/binding.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "cartpole.h"
#define Env Cartpole
#include "../env_binding.h"

/*
 * Environment initialisation hook for the Cartpole binding.
 *
 * Reads the "continuous" entry from `kwargs` through unpack(), stores it in
 * env->continuous, and then runs init() on the environment.
 *
 * env    - environment instance to configure (Env is #defined to Cartpole).
 * args   - positional arguments; not consulted here.
 * kwargs - keyword arguments; must supply "continuous".
 *
 * Always returns 0.
 * NOTE(review): whatever unpack() does for a missing key is not inspected
 * here - confirm how ../env_binding.h reports that case.
 */
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
    (void)args; /* only keyword arguments are used */

    env->continuous = unpack(kwargs, "continuous");
    init(env);

    return 0;
}

/*
 * Export the fields of `log` into `dict`.
 *
 * Every field is stored under a key spelled exactly like the field name, so
 * a local helper macro stringifies the name instead of repeating it. Each
 * use expands to assign_to_dict(dict, "<field>", log-><field>), and the
 * fields are written in the order listed below.
 *
 * dict - destination dictionary object.
 * log  - source log record.
 *
 * Always returns 0.
 */
#define CARTPOLE_LOG_FIELD(field) assign_to_dict(dict, #field, log->field)
static int my_log(PyObject* dict, Log* log) {
    CARTPOLE_LOG_FIELD(score);
    CARTPOLE_LOG_FIELD(perf);
    CARTPOLE_LOG_FIELD(episode_length);
    CARTPOLE_LOG_FIELD(x_threshold_termination);
    CARTPOLE_LOG_FIELD(pole_angle_termination);
    CARTPOLE_LOG_FIELD(max_steps_termination);
    CARTPOLE_LOG_FIELD(n);
    return 0;
}
#undef CARTPOLE_LOG_FIELD
44 changes: 24 additions & 20 deletions pufferlib/ocean/cartpole/cartpole.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// local compile/eval not implemented
// local compile/eval implemented for discrete actions only
// eval with python demo.py --mode eval --env puffer_cartpole --eval-mode-path <path to model>

#include <math.h>
Expand All @@ -11,8 +11,8 @@
#define NUM_WEIGHTS 133123
#define OBSERVATIONS_SIZE 4
#define ACTIONS_SIZE 2
#define CONTINUOUS 1
const char* WEIGHTS_PATH = "/puffertank/pufferlib/pufferlib/resources/cartpole/cartpole_weights.bin";
#define CONTINUOUS 0
const char* WEIGHTS_PATH = "/puffertank/test_newbind/pufferlib/pufferlib/resources/cartpole/cartpole_weights.bin";

float movement(int discrete_action, int userControlMode) {
if (userControlMode) {
Expand All @@ -26,14 +26,15 @@ void demo() {
Weights* weights = load_weights(WEIGHTS_PATH, NUM_WEIGHTS);
LinearLSTM* net;

if (CONTINUOUS) {
net = make_linearlstm_float(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE, ACTION_TYPE_FLOAT);
} else {
net = make_linearlstm(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE, ACTION_TYPE_INT);
}
// if (CONTINUOUS) {
// net = make_linearlstm_float(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE);
// } else {
// net = make_linearlstm_int(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE);
// }

CartPole env = {0};
env.continuous = CONTINUOUS;
net = make_linearlstm(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE);
Cartpole env = {0};
env.is_continuous = CONTINUOUS;
allocate(&env);
Client* client = make_client(&env);
c_reset(&env);
Expand All @@ -46,14 +47,17 @@ void demo() {
int userControlMode = IsKeyDown(KEY_LEFT_SHIFT);

if (!userControlMode) {
if (CONTINUOUS) {
forward_linearlstm_float(net, env.observations, env.actions);
env.actions[0] = tanhf(env.actions[0]);
} else {
int action_value;
forward_linearlstm_int(net, env.observations, &action_value);
env.actions[0] = movement(action_value, 0);
}
// if (CONTINUOUS) {
// forward_linearlstm_float(net, env.observations, env.actions);
// env.actions[0] = tanhf(env.actions[0]);
// } else {
// int action_value;
// forward_linearlstm_int(net, env.observations, &action_value);
// env.actions[0] = movement(action_value, 0);
// }
int action_value;
forward_linearlstm(net, env.observations, &action_value);
env.actions[0] = movement(action_value, 0);
} else {
env.actions[0] = movement(env.actions[0], userControlMode);
}
Expand All @@ -64,11 +68,11 @@ void demo() {

BeginDrawing();
ClearBackground(RAYWHITE);
c_render(client, &env);
c_render(&env);
DrawText("Evaluating policy...", 10, 160, 20, DARKGRAY);
EndDrawing();

if (env.dones[0]) {
if (env.terminals[0]) {
printf("Episode done. Steps: %d, Return: %.2f\n\n", episode_steps, episode_return);
episode_steps = 0;
episode_return = 0.0f;
Expand Down
Loading