diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 9652099f38..813d1970eb 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -242,7 +242,7 @@ env.offroad_behavior = 1 env.traffic_light_behavior = 0 env.reward_randomization = False env.termination_mode = 0 -env.num_agents = 512 +env.num_agents = 1024 env.target_type = "static" env.goal_speed = 3.0 env.reward_collision = 3.0 @@ -286,10 +286,12 @@ inherits = "validation_defaults" type = "multi_scenario" enabled = true render = true +render_backend = "egl" render_views = ["sim_state", "bev"] env.simulation_mode = "gigaflow" env.map_dir = "pufferlib/resources/drive/binaries/carla" env.num_maps = 8 +env.num_agents = 1024 env.min_agents_per_env = 40 env.max_agents_per_env = 40 env.scenario_length = 500 @@ -306,8 +308,8 @@ render_backend = "triage_html" env.simulation_mode = "gigaflow" env.map_dir = "pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin" env.num_maps = 1 -env.min_agents_per_env = 40 -env.max_agents_per_env = 40 +env.min_agents_per_env = 1 +env.max_agents_per_env = 1 env.scenario_length = 500 env.resample_frequency = 500 eval.num_scenarios = 32 diff --git a/pufferlib/ocean/benchmark/evaluators/base.py b/pufferlib/ocean/benchmark/evaluators/base.py index 7226d7a7e8..5288e41792 100644 --- a/pufferlib/ocean/benchmark/evaluators/base.py +++ b/pufferlib/ocean/benchmark/evaluators/base.py @@ -4,8 +4,32 @@ from dataclasses import dataclass, field from typing import ClassVar from tqdm import tqdm +from pufferlib import viz from pufferlib.ocean.drive import binding +_GALLERY_METRIC_KEYS = ( + "score", + "dnf_rate", + "episode_return", + "num_goals_reached", + "collision_rate", + "offroad_rate", + "red_light_violation_rate", + "total_infractions", + "total_distance_travelled", + "episode_length", +) + + +def _episode_metrics_from_info(info): + """Pull the gallery-sort metrics out of a `completed_episode` summary dict.""" + out = {} + for key in _GALLERY_METRIC_KEYS: + value = info.get(key) + if isinstance(value, (int, float)): + out[key] = float(value) + return out + @dataclass class EvalResult: @@ -84,6 +108,7 @@ def rollout(self, vecenv, policy, args) -> EvalResult: try: metrics = self._run_rollout_loop(vecenv, policy, args) t_metric = time.time() + self._maybe_export_episodes(args, metrics) frames = self._render_pass(vecenv, policy, args) if self.render else [] t_render = time.time() finally: @@ -92,10 +117,6 @@ def rollout(self, vecenv, policy, args) -> EvalResult: metrics["metric_seconds"] = float(t_metric - t0) metrics["render_seconds"] = float(t_render - t_metric) metrics["eval_seconds"] = float(t_render - t0) - # Opt-in per-episode CSV + coverage check (writes files, folds - # coverage_* scalars into metrics). No-op unless the evaluator set - # eval.export_episode_csv / eval.verify_coverage. - self._maybe_export_episodes(args, metrics) return EvalResult(metrics=metrics, frames=frames) def _run_rollout_loop(self, vecenv, policy, args) -> dict: @@ -451,6 +472,10 @@ def _render_pass_html(self, vecenv, policy, args) -> list: out_dir = Path(args.get("render_results_dir") or args.get("eval_results_dir") or ".") / "gif" / self.name out_dir.mkdir(parents=True, exist_ok=True) + # Per-rendered-file metrics, accumulated inline from each scenario's + # completed_episode summary so the gallery sort uses this render's + # own rollouts (the metric-pass CSV is from a different vec env). + render_file_metrics = {} epoch = int(args.get("epoch") or 0) global_step = int(args.get("global_step") or 0) @@ -511,13 +536,18 @@ def _render_pass_html(self, vecenv, policy, args) -> list: # basename: map_name is the full bin path, and an absolute # value would make `out_dir / stem` escape out_dir. map_name = os.path.basename(str(info.get("map_name") or "map")).split(".")[0] - stem = f"{map_name}_{scenario_id}{step_suffix}" + # scenario_id repeats across rollouts on the same map in + # gigaflow mode (the C side fills it with the map's short + # name), so append a monotonic counter to make every + # rendered episode land in its own file. + stem = f"{map_name}_{scenario_id}_{scenarios_done:04d}{step_suffix}" tmp_path = out_dir / f"{stem}.pkl.zlib" html_path = out_dir / f"{stem}.html" tmp_path.write_bytes(bundle_bytes) mining_viz.render_compact_replay_html(str(tmp_path), str(html_path)) tmp_path.unlink(missing_ok=True) html_paths.append(html_path) + render_file_metrics[html_path.name] = _episode_metrics_from_info(info) scenarios_done += 1 progress.update(1) if scenarios_done >= num_scenarios: @@ -529,6 +559,9 @@ def _render_pass_html(self, vecenv, policy, args) -> list: vec.close() progress.close() + if html_paths: + viz.build_gallery_index(str(out_dir), file_metrics=render_file_metrics or None) + return html_paths def _render_pass_obs(self, vecenv, policy, args) -> list: @@ -544,7 +577,6 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: import torch import pufferlib - from pufferlib import viz eval_cfg = self.config.get("eval", {}) for required in ("render_num_scenarios", "render_max_steps"): @@ -566,9 +598,13 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: render_env_kwargs = self._render_env_overrides(args) render_env_kwargs.pop("render_mode", None) # obs viz reads state, no EGL + # Per-episode summaries are needed so the gallery sort dropdown can + # show this render's actual metrics. + render_env_kwargs["emit_completed_episodes"] = True device = args["train"]["device"] html_paths = [] + render_file_metrics = {} scenarios_done = 0 progress = tqdm(total=num_scenarios * (max_steps + 1), desc=f"{self.name} obs_html", unit="step") pool_method = getattr(policy, "pool_slot_counts", None) @@ -631,6 +667,7 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: policy_std_hist = [[] for _ in range(n_in_batch)] policy_log_prob_hist = [[] for _ in range(n_in_batch)] pool_hist = None + batch_summary = None for t in range(max_steps): with torch.no_grad(): ob_t = torch.as_tensor(ob).to(device) @@ -707,7 +744,10 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: np.asarray(policy_outputs[start_obs_index:end_obs_index], dtype=np.float32).copy() ) start_obs_index = end_obs_index - ob, _, _, _, _ = vec.step(clipped_action) + ob, _, _, _, step_infos = vec.step(clipped_action) + for d in self._flatten_infos(step_infos): + if isinstance(d, dict) and d.get("summary_type") == "completed_episode": + batch_summary = d progress.update(to_render) for e in range(to_render): map_name = os.path.basename(str(scenarios[e].get("map_name") or "map")).split(".")[0] @@ -739,6 +779,8 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: compact_replay[k] = hists[e] viz.generate_interactive_replay(scenarios[e], compact_replay, filename=str(path)) html_paths.append(path) + if batch_summary is not None: + render_file_metrics[path.name] = _episode_metrics_from_info(batch_summary) scenarios_done += 1 progress.update(1) if scenarios_done >= num_scenarios: @@ -748,7 +790,7 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: progress.close() if html_paths: - viz.build_gallery_index(str(out_dir)) + viz.build_gallery_index(str(out_dir), file_metrics=render_file_metrics or None) return html_paths def _render_env_overrides(self, args) -> dict: diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index bcfd00e753..5851f32503 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -2012,6 +2012,7 @@ static int my_completed_episode_to_dict(PyObject *dict, Env *env, CompletedEpiso assign_to_dict(dict, "red_light_violation_rate", summary->red_light_violation_rate); assign_to_dict(dict, "num_goals_reached", summary->num_goals_reached); assign_to_dict(dict, "score", summary->score); + assign_to_dict(dict, "dnf_rate", summary->dnf_rate); assign_to_dict(dict, "total_distance_travelled", summary->total_distance_travelled); assign_to_dict(dict, "total_infractions", summary->total_infractions); diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index f6dd977e90..ab39886587 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -192,6 +192,7 @@ struct CompletedEpisodeSummary { float red_light_violation_rate; float num_goals_reached; float score; + float dnf_rate; float total_distance_travelled; float total_infractions; float n; @@ -2642,6 +2643,7 @@ static void add_log(Drive *env) { s->red_light_violation_rate = 0.0f; s->num_goals_reached = 0.0f; s->score = 0.0f; + s->dnf_rate = 0.0f; s->total_distance_travelled = 0.0f; s->total_infractions = 0.0f; s->n = (float) env->active_agent_count; @@ -2660,6 +2662,7 @@ static void add_log(Drive *env) { int offroad = env->logs[i].offroad_rate; int red_light = env->logs[i].red_light_violation_rate; int num_goals = env->logs[i].num_goals_reached; + int num_waypoints = env->logs[i].num_waypoints_reached; s->episode_length += env->logs[i].episode_length; s->episode_return += env->logs[i].episode_return; s->collision_rate += collided; @@ -2669,6 +2672,12 @@ static void add_log(Drive *env) { if (num_goals >= 3 && !agent_i->removed && !agent_i->stopped) { s->score += 1.0f; } + // Mirror the aggregate Log DNF predicate (see drive.h:2577): + // the agent stayed clean of infractions but never reached even + // one waypoint — i.e. wandered without making progress. + if (!offroad && !collided && !red_light && num_waypoints < 1) { + s->dnf_rate += 1.0f; + } s->total_distance_travelled += agent_i->distance_since_spawn; if (collided || offroad || red_light) { s->total_infractions += 1.0f; diff --git a/pufferlib/viz.py b/pufferlib/viz.py index 5656d7d392..212fc336a4 100644 --- a/pufferlib/viz.py +++ b/pufferlib/viz.py @@ -1217,7 +1217,10 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"): obsCtx.save(); obsCtx.translate(obsC.width/2, obsC.height/2); obsCtx.scale(scale, -scale); obsCtx.lineCap = "round"; if(showAll){ obsCtx.strokeStyle="#bbb"; obsCtx.lineWidth=1.5*px; for(const r of frame.lanes){ obsCtx.beginPath(); obsCtx.moveTo(r[0]+r[4]*r[2]/2,r[1]+r[5]*r[2]/2); obsCtx.lineTo(r[0]-r[4]*r[2]/2,r[1]-r[5]*r[2]/2); obsCtx.stroke(); } } if(showAll){ obsCtx.strokeStyle="#333"; obsCtx.lineWidth=3*px; for(const r of frame.bounds){ obsCtx.beginPath(); obsCtx.moveTo(r[0]+r[4]*r[2]/2,r[1]+r[5]*r[2]/2); obsCtx.lineTo(r[0]-r[4]*r[2]/2,r[1]-r[5]*r[2]/2); obsCtx.stroke(); } } - if(showPool){ for(const r of frame.lanes.concat(frame.bounds)){ if(r[6] > 0){ obsCtx.strokeStyle=`rgba(0,125,145,${poolAlpha(r[6])})`; obsCtx.lineWidth=(obsMode === 1 ? 2.2 : 2.0)*px; obsCtx.beginPath(); obsCtx.moveTo(r[0]+r[4]*r[2]/2,r[1]+r[5]*r[2]/2); obsCtx.lineTo(r[0]-r[4]*r[2]/2,r[1]-r[5]*r[2]/2); obsCtx.stroke(); } } } + if(showPool){ + for(const r of frame.lanes){ if(r[6] > 0){ obsCtx.strokeStyle=`rgba(0,125,145,${poolAlpha(r[6])})`; obsCtx.lineWidth=(obsMode === 1 ? 2.2 : 2.0)*px; obsCtx.beginPath(); obsCtx.moveTo(r[0]+r[4]*r[2]/2,r[1]+r[5]*r[2]/2); obsCtx.lineTo(r[0]-r[4]*r[2]/2,r[1]-r[5]*r[2]/2); obsCtx.stroke(); } } + for(const r of frame.bounds){ if(r[6] > 0){ obsCtx.strokeStyle=`rgba(200,0,0,${poolAlpha(r[6])})`; obsCtx.lineWidth=(obsMode === 1 ? 2.2 : 2.0)*px; obsCtx.beginPath(); obsCtx.moveTo(r[0]+r[4]*r[2]/2,r[1]+r[5]*r[2]/2); obsCtx.lineTo(r[0]-r[4]*r[2]/2,r[1]-r[5]*r[2]/2); obsCtx.stroke(); } } + } for(const g of frame.gps){ obsCtx.fillStyle="magenta"; obsCtx.beginPath(); obsCtx.arc(g[0],g[1],5*px,0,7); obsCtx.fill(); } for(const t of frame.traffic_controls){ if(showPool && t.pool > 0){ obsCtx.strokeStyle=`rgba(0,125,145,${poolAlpha(t.pool)})`; obsCtx.lineWidth=(obsMode === 1 ? 3.2 : 2.4)*px; obsCtx.beginPath(); obsCtx.moveTo(t.x1,t.y1); obsCtx.lineTo(t.x2,t.y2); obsCtx.stroke(); } if(showAll){ obsCtx.strokeStyle = t.type === 1 ? trafficColor({state:t.state}) : (t.type === 2 ? "#cc0000" : "#ffd700"); obsCtx.lineWidth=2.5*px; obsCtx.beginPath(); obsCtx.moveTo(t.x1,t.y1); obsCtx.lineTo(t.x2,t.y2); obsCtx.stroke(); } } for(const p of frame.partners){ if(!showAll && !(showPool && p.pool > 0)) continue; obsCtx.save(); obsCtx.translate(p.x,p.y); obsCtx.rotate(p.h); if(showAll){ obsCtx.fillStyle="rgba(136,136,136,.8)"; obsCtx.strokeStyle="#333"; obsCtx.lineWidth=1.5*px; obsCtx.beginPath(); obsCtx.rect(-p.l/2,-p.w/2,p.l,p.w); obsCtx.fill(); obsCtx.stroke(); } if(showPool && p.pool > 0){ obsCtx.strokeStyle=`rgba(0,125,145,${poolAlpha(p.pool)})`; obsCtx.lineWidth=(obsMode === 1 ? 2.4 : 2.0)*px; obsCtx.strokeRect(-p.l/2,-p.w/2,p.l,p.w); } obsCtx.restore(); } @@ -1296,22 +1299,120 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"): f.write(final_html) -def build_gallery_index(folder_path="."): - files = [f for f in os.listdir(folder_path) if f != "index.html" and re.fullmatch(r"(.+)_([0-9]+)\.html", f)] +def build_gallery_index(folder_path=".", file_metrics=None): + """Build an index.html navigator for per-episode replay HTMLs in folder_path. + + If `file_metrics` is a dict mapping ` -> {metric_name: value}`, + the index also exposes a sort dropdown so the user can flip between sort + keys (default: `score` ascending — failures bubble to the top). When + `file_metrics` is None or empty, behaves as before (filename-order + dropdown, no sort UI). + """ + files = [f for f in os.listdir(folder_path) if f != "index.html" and f.endswith(".html")] if not files: print("No matching .html files found in this directory.") return - def sort_key(filename): - match = re.fullmatch(r"(.+)_([0-9]+)\.html", filename) - env_map_name = match.group(1) - global_episode_id = int(match.group(2)) - return (global_episode_id, env_map_name) - - files.sort(key=sort_key) + # Lexicographic sort over the full filename. With the triage_html stem + # `{map}_{scenario_id}_{scenarios_done:04d}_epoch{e}_step{s}.html`, the + # zero-padded scenarios_done dominates ordering within a map. + files.sort() + + metrics_map = file_metrics or {} + has_metrics = bool(metrics_map) + + # (key, default_direction). Anything in this list with at least one + # non-null value across files gets a dropdown entry. Default direction + # is what makes triage-useful values bubble to the top. + SORT_KEYS = [ + ("score", "asc"), + ("dnf_rate", "desc"), + ("episode_return", "asc"), + ("num_goals_reached", "asc"), + ("collision_rate", "desc"), + ("offroad_rate", "desc"), + ("red_light_violation_rate", "desc"), + ("total_infractions", "desc"), + ("total_distance_travelled", "asc"), + ("episode_length", "asc"), + ] + + available_keys = [] + if has_metrics: + present = set() + for v in metrics_map.values(): + present.update(v.keys()) + for k, d in SORT_KEYS: + if k in present: + available_keys.append((k, d)) + + metrics_json = json.dumps(metrics_map, separators=(",", ":")) + defaults_json = json.dumps({k: d for k, d in available_keys}, separators=(",", ":")) + + def make_label(f): + if not has_metrics or f not in metrics_map: + return f.replace(".html", "").replace("_", " ") + bits = [f.replace(".html", "")] + for k in ("score", "dnf_rate", "num_goals_reached", "episode_return"): + if k in metrics_map[f]: + v = metrics_map[f][k] + bits.append(f"{k}={v:.2f}" if isinstance(v, float) else f"{k}={v}") + return " · ".join(bits) + + options_html = "\n".join(f'' for f in files) + + sort_ui = "" + sort_js = "" + if has_metrics and available_keys: + sort_options = "\n".join( + f'' for k, _ in available_keys + ) + sort_ui = ( + 'SORT' + f'' + '" + ) + sort_js = ( + ( + "const FILE_METRICS = __METRICS_JSON__;" + "const SORT_DEFAULTS = __DEFAULTS_JSON__;" + "const sortKeySel = document.getElementById('sortKey');" + "const sortDirSel = document.getElementById('sortDir');" + "function onSortKeyChange() {" + " const k = sortKeySel.value;" + " if (SORT_DEFAULTS[k]) sortDirSel.value = SORT_DEFAULTS[k];" + " resortFiles();" + "}" + "function resortFiles() {" + " const key = sortKeySel.value;" + " const dir = sortDirSel.value;" + " const opts = Array.from(select.options);" + " opts.sort(function (a, b) {" + " const fA = a.getAttribute('data-name');" + " const fB = b.getAttribute('data-name');" + " const mA = (FILE_METRICS[fA] || {})[key];" + " const mB = (FILE_METRICS[fB] || {})[key];" + " const nA = (mA === undefined || mA === null) ? -Infinity : mA;" + " const nB = (mB === undefined || mB === null) ? -Infinity : mB;" + " if (nA === nB) return fA.localeCompare(fB);" + " return dir === 'asc' ? nA - nB : nB - nA;" + " });" + " const current = select.value;" + " while (select.firstChild) select.removeChild(select.firstChild);" + " opts.forEach(function (o) { select.appendChild(o); });" + " select.value = current;" + " updateButtons();" + "}" + "resortFiles();" + ) + .replace("__METRICS_JSON__", metrics_json) + .replace("__DEFAULTS_JSON__", defaults_json) + ) - # 3. Build the HTML template html_content = """ @@ -1320,20 +1421,22 @@ def sort_key(filename):