diff --git a/recipe/simpletir/agent_utils.py b/recipe/simpletir/agent_utils.py index bbbb0734..cdf3336a 100644 --- a/recipe/simpletir/agent_utils.py +++ b/recipe/simpletir/agent_utils.py @@ -448,7 +448,6 @@ def run_llm_loop( [not v for v in is_void_turn], dtype=torch.bool ) active_num_list.append(active_mask.sum().item()) - turns_stats[curr_active_mask] += 1 use_code_stats += torch.tensor(code_info["use_code"], dtype=torch.int) valid_code_stats += torch.tensor(code_info["valid_code"], dtype=torch.int) success_code_lines.extend(code_info["success_code_lines"]) @@ -462,6 +461,7 @@ def run_llm_loop( next_obs[i] += self.prompt_dict["final_prompt"] if step < self.config.max_turns - 1: + turns_stats[curr_active_mask] += 1 next_obs_ids = self._process_next_obs(next_obs) rollings = self._update_rolling_state(rollings, responses_ids, next_obs_ids) original_right_side = self._update_right_side( diff --git a/requirements_sglang.txt b/requirements_sglang.txt index e859e1c9..2ae61996 100644 --- a/requirements_sglang.txt +++ b/requirements_sglang.txt @@ -17,5 +17,5 @@ torchdata torchvision transformers wandb -sglang[all]==0.4.4.post3 +sglang[all]==0.4.9 torch-memory-saver>=0.0.5 \ No newline at end of file