From 66d3f0c854ab92f01b332cbf8bdf3ce7db5b0360 Mon Sep 17 00:00:00 2001 From: xinpw8 Date: Sat, 10 May 2025 09:37:19 +0000 Subject: [PATCH] fixed snake. no weird stuff fr this time. tested locally. --- pufferlib/ocean/env_binding.h | 84 +++++++++++---------------------- pufferlib/ocean/snake/binding.c | 69 +++++---------------------- pufferlib/ocean/snake/snake.c | 4 +- pufferlib/ocean/snake/snake.py | 43 +++++++++++++---- setup.py | 5 +- 5 files changed, 76 insertions(+), 129 deletions(-) diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index a067c8b9d9..2230c53692 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -215,52 +215,6 @@ static VecEnv* unpack_vecenv(PyObject* args) { return vec; } -static double unpack_with_index(PyObject* kwargs, char* key, int env_index) { - PyObject* val = PyDict_GetItemString(kwargs, key); - if (val == NULL) { - // If the key doesn't exist, don't set an error - this allows optional parameters - // Just return a default value that the caller can check for - return 0.0; - } - - // If val is a list, extract the element at env_index - if (PyList_Check(val)) { - if (env_index >= 0 && env_index < PyList_Size(val)) { - val = PyList_GetItem(val, env_index); - } else { - // If index is out of bounds, use the first element - val = PyList_GetItem(val, 0); - } - } - - if (PyLong_Check(val)) { - long out = PyLong_AsLong(val); - if (out > INT_MAX || out < INT_MIN) { - char error_msg[100]; - snprintf(error_msg, sizeof(error_msg), "Value %ld of integer argument %s is out of range", out, key); - PyErr_SetString(PyExc_TypeError, error_msg); - return 1; - } - // Cast on return. Safe because double can represent all 32-bit ints exactly - return out; - } - if (PyFloat_Check(val)) { - return PyFloat_AsDouble(val); - } - if (PyBool_Check(val)) { - return PyObject_IsTrue(val); - } - char error_msg[100]; - snprintf(error_msg, sizeof(error_msg), "Failed to unpack keyword %s as int", key); - PyErr_SetString(PyExc_TypeError, error_msg); - return 1; -} - -// Original unpack function for backward compatibility -static double unpack(PyObject* kwargs, char* key) { - return unpack_with_index(kwargs, key, 0); -} - static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { if (PyTuple_Size(args) != 7) { PyErr_SetString(PyExc_TypeError, "vec_init requires 6 arguments"); @@ -374,7 +328,6 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it } - // Add an env_index to kwargs for use by my_init for (int i = 0; i < num_envs; i++) { Env* env = (Env*)calloc(1, sizeof(Env)); if (!env) { @@ -407,16 +360,6 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { } Py_DECREF(py_seed); - // Add the environment index to kwargs - PyObject* py_env_index = PyLong_FromLong(i); - if (PyDict_SetItemString(kwargs, "env_index", py_env_index) < 0) { - PyErr_SetString(PyExc_RuntimeError, "Failed to set env_index in kwargs"); - Py_DECREF(py_env_index); - Py_DECREF(kwargs); - return NULL; - } - Py_DECREF(py_env_index); - PyObject* empty_args = PyTuple_New(0); if (my_init(env, empty_args, kwargs)) { PyErr_SetString(PyExc_TypeError, "env_init failed"); @@ -584,6 +527,33 @@ static PyObject* vec_close(PyObject* self, PyObject* args) { Py_RETURN_NONE; } +static double unpack(PyObject* kwargs, char* key) { + PyObject* val = PyDict_GetItemString(kwargs, key); + if (val == NULL) { + // If the key doesn't exist, don't set an error - this allows optional parameters + // Just return a default value that the caller can check for + return 0.0; + } + if (PyLong_Check(val)) { + long out = PyLong_AsLong(val); + if (out > INT_MAX || out < INT_MIN) { + char error_msg[100]; + snprintf(error_msg, sizeof(error_msg), "Value %ld of integer argument %s is out of range", out, key); + PyErr_SetString(PyExc_TypeError, error_msg); + return 1; + } + // Cast on return. Safe because double can represent all 32-bit ints exactly + return out; + } + if (PyFloat_Check(val)) { + return PyFloat_AsDouble(val); + } + char error_msg[100]; + snprintf(error_msg, sizeof(error_msg), "Failed to unpack keyword %s as int", key); + PyErr_SetString(PyExc_TypeError, error_msg); + return 1; +} + // Method table static PyMethodDef methods[] = { {"env_init", (PyCFunction)env_init, METH_VARARGS | METH_KEYWORDS, "Init environment with observation, action, reward, terminal, truncation arrays"}, diff --git a/pufferlib/ocean/snake/binding.c b/pufferlib/ocean/snake/binding.c index ccb7d5c61b..360c021f6d 100644 --- a/pufferlib/ocean/snake/binding.c +++ b/pufferlib/ocean/snake/binding.c @@ -3,63 +3,18 @@ #define Env CSnake #include "../env_binding.h" -// Helper function to extract an int from a Python object, handling lists -static int extract_int(PyObject* kwargs, const char* key, int default_value) { - PyObject* obj = PyDict_GetItemString(kwargs, key); - if (obj != NULL) { - if (PyList_Check(obj)) { - obj = PyList_GetItem(obj, 0); - } - return PyLong_AsLong(obj); - } - return default_value; -} - -// Helper function to extract a float from a Python object, handling lists -static float extract_float(PyObject* kwargs, const char* key, float default_value) { - PyObject* obj = PyDict_GetItemString(kwargs, key); - if (obj != NULL) { - if (PyList_Check(obj)) { - obj = PyList_GetItem(obj, 0); - } - return PyFloat_AsDouble(obj); - } - return default_value; -} - -// Helper function to extract a bool from a Python object, handling lists -static int extract_bool(PyObject* kwargs, const char* key, int default_value) { - PyObject* obj = PyDict_GetItemString(kwargs, key); - if (obj != NULL) { - if (PyList_Check(obj)) { - obj = PyList_GetItem(obj, 0); - } - return PyObject_IsTrue(obj); - } - return default_value; -} - -static int my_init(Env* env, PyObject* args, PyObject* kwargs) { - // Get the environment index from kwargs - int env_index = 0; - PyObject* env_index_obj = PyDict_GetItemString(kwargs, "env_index"); - if (env_index_obj != NULL) { - env_index = PyLong_AsLong(env_index_obj); - } - - // Use unpack_with_index to properly handle lists - env->width = unpack_with_index(kwargs, "width", env_index); - env->height = unpack_with_index(kwargs, "height", env_index); - env->num_snakes = unpack_with_index(kwargs, "num_snakes", env_index); - env->vision = unpack_with_index(kwargs, "vision", env_index); - env->leave_corpse_on_death = unpack_with_index(kwargs, "leave_corpse_on_death", env_index); - env->food = unpack_with_index(kwargs, "num_food", env_index); - env->reward_food = unpack_with_index(kwargs, "reward_food", env_index); - env->reward_corpse = unpack_with_index(kwargs, "reward_corpse", env_index); - env->reward_death = unpack_with_index(kwargs, "reward_death", env_index); - env->max_snake_length = unpack_with_index(kwargs, "max_snake_length", env_index); - env->cell_size = unpack_with_index(kwargs, "cell_size", env_index); - +static int my_init(Env* env, PyObject* args, PyObject* kwargs) { + env->width = unpack(kwargs, "width"); + env->height = unpack(kwargs, "height"); + env->num_snakes = unpack(kwargs, "num_snakes"); + env->vision = unpack(kwargs, "vision"); + env->leave_corpse_on_death = unpack(kwargs, "leave_corpse_on_death"); + env->food = unpack(kwargs, "num_food"); + env->reward_food = unpack(kwargs, "reward_food"); + env->reward_corpse = unpack(kwargs, "reward_corpse"); + env->reward_death = unpack(kwargs, "reward_death"); + env->max_snake_length = unpack(kwargs, "max_snake_length"); + env->cell_size = unpack(kwargs, "cell_size"); init_csnake(env); return 0; } diff --git a/pufferlib/ocean/snake/snake.c b/pufferlib/ocean/snake/snake.c index d685915dc9..13abfaffd1 100644 --- a/pufferlib/ocean/snake/snake.c +++ b/pufferlib/ocean/snake/snake.c @@ -76,7 +76,7 @@ void test_performance(float test_time) { } int main() { - // demo(); - test_performance(30); + demo(); + // test_performance(30); return 0; } diff --git a/pufferlib/ocean/snake/snake.py b/pufferlib/ocean/snake/snake.py index 5473fd1621..5700541979 100644 --- a/pufferlib/ocean/snake/snake.py +++ b/pufferlib/ocean/snake/snake.py @@ -33,7 +33,6 @@ def __init__(self, num_envs=16, width=640, height=360, self.max_snake_length = min(max_snake_length, max_area) self.report_interval = report_interval - # This block required by advanced PufferLib env spec self.single_observation_space = gymnasium.spaces.Box( low=0, high=2, shape=(2*vision+1, 2*vision+1), dtype=np.int8) self.single_action_space = gymnasium.spaces.Discrete(4) @@ -41,18 +40,42 @@ def __init__(self, num_envs=16, width=640, height=360, self.render_mode = render_mode self.tick = 0 - # Calculate cell_size for rendering self.cell_size = int(np.ceil(1280 / max(max(width), max(height)))) super().__init__(buf) - self.c_envs = binding.vec_init(self.observations, self.actions, - self.rewards, self.terminals, self.truncations, - num_envs, seed, width=width, height=height, - num_snakes=num_snakes, num_food=num_food, vision=vision, - max_snake_length=max_snake_length, - leave_corpse_on_death=leave_corpse_on_death, - reward_food=reward_food, reward_corpse=reward_corpse, - reward_death=reward_death, cell_size=self.cell_size) + c_envs = [] + offset = 0 + for i in range(num_envs): + ns = num_snakes[i] + obs_slice = self.observations[offset:offset+ns] + act_slice = self.actions[offset:offset+ns] + rew_slice = self.rewards[offset:offset+ns] + term_slice = self.terminals[offset:offset+ns] + trunc_slice = self.truncations[offset:offset+ns] + # Seed each env uniquely: i + seed * num_envs + env_seed = i + seed * num_envs + env_id = binding.env_init( + obs_slice, + act_slice, + rew_slice, + term_slice, + trunc_slice, + env_seed, + width=width[i], + height=height[i], + num_snakes=ns, + num_food=num_food[i], + vision=vision, + leave_corpse_on_death=leave_corpse_on_death[i], + reward_food=reward_food, + reward_corpse=reward_corpse, + reward_death=reward_death, + max_snake_length=self.max_snake_length, + cell_size=self.cell_size + ) + c_envs.append(env_id) + offset += ns + self.c_envs = binding.vectorize(*c_envs) def reset(self, seed=None): self.tick = 0 diff --git a/setup.py b/setup.py index c5f8f4cb3d..ccbbf366cc 100644 --- a/setup.py +++ b/setup.py @@ -270,13 +270,12 @@ # 'pufferlib/ocean/tactical/c_tactical', #'pufferlib/ocean/squared/cy_squared', #'pufferlib/ocean/snake/cy_snake', - 'pufferlib/ocean/gpudrive/cy_gpudrive', #'pufferlib/ocean/pong/cy_pong', # 'pufferlib/ocean/breakout/cy_breakout', # 'pufferlib/ocean/cartpole/cy_cartpole', # 'pufferlib/ocean/connect4/cy_connect4', #'pufferlib/ocean/grid/cy_grid', - 'pufferlib/ocean/tripletriad/cy_tripletriad', + # 'pufferlib/ocean/tripletriad/cy_tripletriad', # 'pufferlib/ocean/go/cy_go', 'pufferlib/ocean/rware/cy_rware', 'pufferlib/ocean/trash_pickup/cy_trash_pickup', @@ -315,7 +314,7 @@ #c_args = ['-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION', '-DPLATFORM_DESKTOP', '-O2'] #c_args += "-Wsign-compare -DNDEBUG -g -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC".split() -pure_c_extensions = ['squared', 'pong', 'breakout', 'enduro', 'blastar', 'grid', 'nmmo3', 'tactical', 'go', 'cartpole', 'connect4'] +pure_c_extensions = ['squared', 'pong', 'breakout', 'enduro', 'blastar', 'grid', 'nmmo3', 'tactical', 'go', 'cartpole', 'connect4', 'snake'] extensions += [ Extension(