diff --git a/src/jaxued/environments/maze/renderer.py b/src/jaxued/environments/maze/renderer.py index 7ff68fb..480a603 100644 --- a/src/jaxued/environments/maze/renderer.py +++ b/src/jaxued/environments/maze/renderer.py @@ -34,8 +34,9 @@ def render_level(self, level, env_params): @partial(jax.jit, static_argnums=(0,)) def render_state(self, env_state: EnvState, env_params: EnvParams) -> chex.Array: tile_size = self.tile_size - nrows = self.env.max_height + 2*self.render_border - ncols = self.env.max_width + 2*self.render_border + max_height, max_width = env_state.wall_map.shape + nrows = max_height + 2*self.render_border + ncols = max_width + 2*self.render_border width_px = ncols * tile_size height_px = nrows * tile_size