Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pufferlib/ocean/go/binding.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "go.h"
#define Env CGo
#include "../env_binding.h"

/*
 * Fill in an environment from keyword arguments and then initialise it.
 *
 * Every field below is assigned the result of unpack() for the kwarg of the
 * same name; afterwards init() is invoked on the environment.
 *
 * env:    environment to populate.
 * args:   not used by this function.
 * kwargs: source of the values, passed through to unpack().
 *
 * Always returns 0.
 */
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
    /* Dimension-related fields. */
    env->width            = unpack(kwargs, "width");
    env->height           = unpack(kwargs, "height");
    env->grid_size        = unpack(kwargs, "grid_size");
    env->board_width      = unpack(kwargs, "board_width");
    env->board_height     = unpack(kwargs, "board_height");
    env->grid_square_size = unpack(kwargs, "grid_square_size");

    /* Game-state and scoring fields. */
    env->moves_made            = unpack(kwargs, "moves_made");
    env->komi                  = unpack(kwargs, "komi");
    env->score                 = unpack(kwargs, "score");
    env->last_capture_position = unpack(kwargs, "last_capture_position");

    /* Reward values. */
    env->reward_move_pass        = unpack(kwargs, "reward_move_pass");
    env->reward_move_invalid     = unpack(kwargs, "reward_move_invalid");
    env->reward_move_valid       = unpack(kwargs, "reward_move_valid");
    env->reward_player_capture   = unpack(kwargs, "reward_player_capture");
    env->reward_opponent_capture = unpack(kwargs, "reward_opponent_capture");

    /* All fields are set; hand over to the environment's own initialiser. */
    init(env);
    return 0;
}

/*
 * Export the fields of a Log through assign_to_dict().
 *
 * Each of perf, score, episode_length, episode_return and n is passed to
 * assign_to_dict() together with `dict` and a key equal to the field name.
 *
 * Always returns 0.
 */
static int my_log(PyObject* dict, Log* log) {
    assign_to_dict(dict, "perf",           log->perf);
    assign_to_dict(dict, "score",          log->score);
    assign_to_dict(dict, "episode_length", log->episode_length);
    assign_to_dict(dict, "episode_return", log->episode_return);
    assign_to_dict(dict, "n",              log->n);
    return 0;
}
145 changes: 0 additions & 145 deletions pufferlib/ocean/go/cy_go.pyx

This file was deleted.

111 changes: 38 additions & 73 deletions pufferlib/ocean/go/go.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#define PLAYER_WIN 1
static const int DIRECTIONS[NUM_DIRECTIONS][2] = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
// LD_LIBRARY_PATH=raylib/lib ./go
#define LOG_BUFFER_SIZE 1024

typedef struct Log Log;
struct Log {
Expand All @@ -24,55 +23,6 @@ struct Log {
float n;
};

/* Fixed-capacity collection of Log records. */
typedef struct LogBuffer {
    Log* logs;   /* array of `length` records */
    int length;  /* capacity of `logs` */
    int idx;     /* number of records stored so far (next free slot) */
} LogBuffer;

LogBuffer* allocate_logbuffer(int size) {
LogBuffer* logs = (LogBuffer*)calloc(1, sizeof(LogBuffer));
logs->logs = (Log*)calloc(size, sizeof(Log));
logs->length = size;
logs->idx = 0;
return logs;
}

/*
 * Release a LogBuffer together with its record array.
 *
 * buffer: buffer to release; NULL is accepted and ignored, mirroring free().
 */
void free_logbuffer(LogBuffer* buffer) {
    if (buffer == NULL) {
        return;
    }
    free(buffer->logs);
    free(buffer);
}

/*
 * Append a copy of `*log` to the buffer.
 *
 * When the write index has reached the buffer length the record is
 * discarded and the buffer is left untouched.
 */
void add_log(LogBuffer* logs, Log* log) {
    if (logs->idx != logs->length) {
        logs->logs[logs->idx] = *log;
        logs->idx += 1;
    }
}

/*
 * Summarise the stored records into a single Log and empty the buffer.
 *
 * episode_return, episode_length, score and perf are averaged over the
 * stored records; n is returned as a plain sum. With no stored records an
 * all-zero Log is returned. The write index is reset to 0 afterwards.
 */
Log aggregate_and_clear(LogBuffer* logs) {
    Log total = {0};
    const int count = logs->idx;

    if (count != 0) {
        for (int k = 0; k < count; k++) {
            const Log* entry = &logs->logs[k];
            total.episode_return += entry->episode_return;
            total.episode_length += entry->episode_length;
            total.n += entry->n;
            total.score += entry->score;
            total.perf += entry->perf;
        }

        // n is deliberately left as a sum; every other field becomes a mean.
        total.episode_return /= count;
        total.episode_length /= count;
        total.score /= count;
        total.perf /= count;

        logs->idx = 0;
    }

    return total;
}

typedef struct Group Group;
struct Group {
int parent;
Expand Down Expand Up @@ -114,8 +64,7 @@ struct CGo {
float* observations;
int* actions;
float* rewards;
unsigned char* dones;
LogBuffer* log_buffer;
unsigned char* terminals;
Log log;
float score;
int width;
Expand All @@ -141,8 +90,31 @@ struct CGo {
float reward_move_valid;
float reward_player_capture;
float reward_opponent_capture;
float tick;
};

/*
 * Fold the episode that just finished into the environment's running log.
 *
 * - episode_length grows by the current tick count.
 * - perf is updated as a running mean over n episodes of a win indicator:
 *   1.0 when score > 0, otherwise 0.0 (losses and ties both count as 0.0).
 * - score and episode_return grow by env->score and env->rewards[0].
 * - n is incremented last, because the perf update needs the previous n.
 */
void add_log(CGo* env) {
    Log* log = &env->log;

    log->episode_length += env->tick;

    const float win_value = (env->score > 0) ? 1.0 : 0.0;
    log->perf = (log->perf * log->n + win_value) / (log->n + 1.0);

    log->score += env->score;
    log->episode_return += env->rewards[0];
    log->n += 1.0;
}

void generate_board_positions(CGo* env) {
for (int i = 0; i < (env->grid_size-1) * (env->grid_size-1); i++) {
int row = i / (env->grid_size-1);
Expand Down Expand Up @@ -182,8 +154,7 @@ void allocate(CGo* env) {
env->observations = (float*)calloc((env->grid_size)*(env->grid_size)*2 + 2, sizeof(float));
env->actions = (int*)calloc(1, sizeof(int));
env->rewards = (float*)calloc(1, sizeof(float));
env->dones = (unsigned char*)calloc(1, sizeof(unsigned char));
env->log_buffer = allocate_logbuffer(LOG_BUFFER_SIZE);
env->terminals = (unsigned char*)calloc(1, sizeof(unsigned char));
}

void free_initialized(CGo* env) {
Expand All @@ -201,9 +172,8 @@ void free_initialized(CGo* env) {
void free_allocated(CGo* env) {
free(env->actions);
free(env->observations);
free(env->dones);
free(env->terminals);
free(env->rewards);
free_logbuffer(env->log_buffer);
free_initialized(env);
}

Expand Down Expand Up @@ -547,7 +517,7 @@ void enemy_random_move(CGo* env){
}
}
// If no move is possible, pass or end the game
env->dones[0] = 1;
env->terminals[0] = 1;
}

int find_group_liberty(CGo* env, int root){
Expand Down Expand Up @@ -649,8 +619,9 @@ void enemy_greedy_easy(CGo* env){
}

void c_reset(CGo* env) {
env->log = (Log){0};
env->dones[0] = 0;
env->tick = 0;
// We don't reset the log struct - leave it accumulating like in Pong
env->terminals[0] = 0;
env->score = 0;
for (int i = 0; i < (env->grid_size)*(env->grid_size); i++) {
env->board_states[i] = 0;
Expand All @@ -672,32 +643,26 @@ void c_reset(CGo* env) {
void end_game(CGo* env){
compute_score_tromp_taylor(env);
if (env->score > 0) {
env->rewards[0] = 1.0 ;
env->log.perf = 1.0;
env->rewards[0] = 1.0;
}
else if (env->score < 0) {
env->rewards[0] = -1.0;
env->log.perf = 0.0;
}
else {
env->rewards[0] = 0.0;
env->log.perf = 0.0;
}
env->log.score = env->score;
env->log.n++;
env->log.episode_return += env->rewards[0];
add_log(env->log_buffer, &env->log);
add_log(env);
c_reset(env);
}

void c_step(CGo* env) {
env->log.episode_length += 1;
env->tick += 1;
env->rewards[0] = 0.0;
int action = (int)env->actions[0];
// Cap on episode length. Useful for training; could probably be made a hyperparameter. Recommended to increase it for larger board sizes.
float max_moves = 3 * env->grid_size * env->grid_size;
if (env->log.episode_length > max_moves) {
env->dones[0] = 1;
if (env->tick > max_moves) {
env->terminals[0] = 1;
end_game(env);
compute_observations(env);
return;
Expand All @@ -706,7 +671,7 @@ void c_step(CGo* env) {
env->rewards[0] = env->reward_move_pass;
env->log.episode_return += env->reward_move_pass;
enemy_greedy_hard(env);
if (env->dones[0] == 1) {
if (env->terminals[0] == 1) {
end_game(env);
return;
}
Expand Down Expand Up @@ -735,7 +700,7 @@ void c_step(CGo* env) {
env->rewards[0] = -1;
}

if (env->dones[0] == 1) {
if (env->terminals[0] == 1) {
end_game(env);
return;
}
Expand Down Expand Up @@ -767,7 +732,7 @@ Client* make_client(int width, int height) {
return client;
}

void c_render(Client* client, CGo* env) {
void c_render(CGo* env) {
if (IsKeyDown(KEY_ESCAPE)) {
exit(0);
}
Expand Down
Loading