diff --git a/zoo/atari/envs/test_qwen_arati_env.py b/zoo/atari/envs/test_qwen_arati_env.py new file mode 100644 index 000000000..4281256e1 --- /dev/null +++ b/zoo/atari/envs/test_qwen_arati_env.py @@ -0,0 +1,375 @@ +# run_pong_qwen_ddp.py +import os, re, json, random +from dataclasses import dataclass +from collections import deque, namedtuple +from typing import List, Tuple, Union +import numpy as np +import shutil +from PIL import Image +import torch +import torch.distributed as dist + +from transformers import AutoProcessor +from transformers import Qwen2_5_VLForConditionalGeneration + +from easydict import EasyDict +from zoo.atari.envs.atari_lightzero_env import AtariEnvLightZero + + +def to_model_image(arr: Union[np.ndarray, torch.Tensor], channel_last: bool, use_pil: bool): + """ + 返回: + - use_pil=True -> PIL.Image(RGB) + - use_pil=False -> numpy HWC uint8 + """ + if isinstance(arr, torch.Tensor): + arr = arr.detach().cpu().numpy() + arr = np.asarray(arr) + + # 2D 灰度 -> HWC + if arr.ndim == 2: + arr = arr[:, :, None] + + # 统一到 HWC + if channel_last: + hwc = arr + else: + assert arr.ndim == 3 and arr.shape[0] in (1, 3), f"Expect (C,H,W) or (H,W,C), got {arr.shape}" + hwc = np.transpose(arr, (1, 2, 0)) + + # 灰度扩 3 通道 + if hwc.shape[-1] == 1: + hwc = np.repeat(hwc, 3, axis=-1) + + # 归一到 uint8 + if hwc.dtype != np.uint8: + if hwc.max() <= 1.0: + hwc = hwc * 255.0 + hwc = np.clip(hwc, 0, 255).astype(np.uint8) + + if use_pil: + return Image.fromarray(hwc, mode="RGB") + else: + return hwc + + + +def init_distributed(): + backend = "nccl" if torch.cuda.is_available() else "gloo" + if not dist.is_initialized(): + dist.init_process_group(backend=backend, init_method="env://") + rank = dist.get_rank() + world_size = dist.get_world_size() + + # 设定 device + local_rank = int(os.getenv("LOCAL_RANK", rank % max(1, torch.cuda.device_count()))) + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + + return rank, world_size, local_rank + + +Transition = namedtuple("Transition", ["step", "image", "action_str"]) + +class QwenPongPolicy: + """ + - 历史 n 帧(仅包含:图像 + 我们当时的动作字符串) + - 指令结构(中文提示语义一致,英文更利于指令稳定): + 环境描述 + 任务描述 + 当前图片 + + 可选动作(字符串列表) + + 历史轨迹(只含 历史图片 + 历史动作字符串) + 要求模型输出:单行 纯动作字符串(如 RIGHTFIRE) + - 解析失败则从 allowed 随机抽取一个字符串,再映射回动作 id + - 支持 FlashAttention-2(若不可用自动回退) + """ + # 6 个官方动作名 + ID2NAME = { + 0: "NOOP", + 1: "FIRE", + 2: "RIGHT", + 3: "LEFT", + 4: "RIGHTFIRE", + 5: "LEFTFIRE", + } + NAME2ID = {v: k for k, v in ID2NAME.items()} + + ACTION_EXPLAIN = { + "NOOP": "Do nothing (stay still).", + "FIRE": "Serve a new point(use only at the start of a rally).", + "RIGHT": "Move your RIGHT paddle UP in this Pong port.", + "LEFT": "Move your RIGHT paddle DOWN in this Pong port.", + "RIGHTFIRE": "Move UP and SERVE simultaneously (use only to start a rally).", + "LEFTFIRE": "Move DOWN and SERVE simultaneously (use only to start a rally).", + } + + + def __init__(self, model_name: str, dtype: torch.dtype, history_n: int, + use_pil: bool, channel_last: bool, device: torch.device, save_dir: str = "pong_ddp_frames", save_image=False, rank: int = 0, topk: int = 1): + self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_name, + torch_dtype=dtype, + device_map={"": device.index}, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + + self.model.eval() + + self.history_n = history_n + self.buffer: deque[Transition] = deque(maxlen=history_n) + self.use_pil = use_pil + 
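+        # The deque above is a rolling window that keeps only the most recent `history_n`
+        # (image, action) transitions; older entries are dropped automatically, and the kept
+        # ones are replayed into the prompt as per-step history. `use_pil` controls whether
+        # frames are handed to the processor as PIL RGB images or as raw HWC uint8 arrays.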
self.channel_last = channel_last + self.device = device + self.save_image = save_image + self.save_dir = save_dir + self.rank = rank + self.rank_dir = os.path.join(self.save_dir, f"rank{rank:02d}") + self.topk = topk + if os.path.exists(self.rank_dir): + shutil.rmtree(self.rank_dir) + + os.makedirs(self.rank_dir, exist_ok=True) + self.meta_path = os.path.join(self.rank_dir, "trajectory.jsonl") + + def save_pil_if_enabled(self, img: Image.Image, save_root: str, step: int): + d = os.path.join(save_root, f"rank{self.rank:02d}") + os.makedirs(d, exist_ok=True) + img.save(os.path.join(d, f"frame_{step:06d}.png")) + + + def log_step(self, step: int, action_id: int, action_str: str, reward: float, topk_seq=None): + rec = { + "step": int(step), + "action_id": int(action_id), + "action": str(action_str), + "reward": float(reward), + } + if topk_seq is not None: + rec["topk_seq"] = topk_seq # 新增:不受约束的 top-k 序列及概率 + with open(self.meta_path, "a") as f: + f.write(json.dumps(rec, ensure_ascii=False) + "\n") + + + def _build_messages_and_images(self, cur_img, allowed_names: List[str]): + """ + user.content 顺序(按你的要求): + 1) 环境描述 + 任务描述(文本) + 2) 当前图片 + 3) 可选动作(字符串列表)+ 对这 6 个动作的清晰解释 + 4) 历史轨迹(只包含:历史图片 + 对应动作字符串) + 5) 输出格式要求:只返回一行 {ACTION: } + """ + content = [] + images_for_processor = [] + + # 1) 环境 + 任务 + content.append({ + "type": "text", + "text": ( + "Environment: Atari Pong (ALE) — two paddles rally a ball.\n" + "Task: You control the right green paddle. Keep the paddle aligned vertically with the ball to return the ball and avoid losing it. The left paddle will hit the white ball, and you should try to land the white ball in the center of the right green paddle to return the ball and score.\n" + ) + }) + + # 2) 当前图片 + content.append({"type": "text", "text": "Current state image:"}) + content.append({"type": "image", "image": cur_img}) + images_for_processor.append(cur_img) + + # 3) 可选动作 + 解释 + allowed_str = ", ".join(allowed_names) + # 解释文本(只针对当前允许的动作给出说明) + explain_lines = [] + for name in allowed_names: + desc = self.ACTION_EXPLAIN.get(name, "") + if desc: + explain_lines.append(f"- {name}: {desc}") + explain_text = "\n".join(explain_lines) + + content.append({ + "type": "text", + "text": ( + f"Available actions (choose exactly one string): {allowed_str}\n" + # "Action semantics:\n" + # f"{explain_text}\n" + "Heuristic (to guide your choice): if the ball is above your paddle, choose an RIGHT action; " + "if the ball is below, choose a LEFT action; if perfectly aligned and rally is active, NOOP briefly is acceptable." + ) + }) + + # 4) 历史交互轨迹(只包含:历史图片 + 当时选择的动作字符串) + if len(self.buffer) > 0: + # 在输出历史之前,先加一句说明 + content.append({"type": "text", "text": + f"Now you are at step t. The following shows the previous {len(self.buffer)} steps of history (from most recent to oldest):\n" + }) + + # 然后再逐条列出历史轨迹 + for k, tr in enumerate(list(self.buffer)[::-1], start=1): # k=1 表示 t−1 + content.append({"type": "text", "text": f"(t−{k}) observation:"}) + content.append({"type": "image", "image": tr.image}) + content.append({"type": "text", "text": f"(t−{k}) you chose: {tr.action_str}\n"}) + images_for_processor.append(tr.image) + + # 5) 输出格式要求(只返回一行 {ACTION: }) + content.append({ + "type": "text", + "text": ( + "\nOutput requirement:\n" + "- Return EXACTLY ONE line in the form: {ACTION: }\n" + f"- MUST be one of: {allowed_str}\n" + ) + }) + + messages = [ + {"role": "system", "content": "You are a precise action selector for Atari Pong. 
Always follow the requested output format."}, + {"role": "user", "content": content}, + ] + return messages, images_for_processor + + def _parse_action_string(self, text: str, allowed_names: List[str]) -> str: + # 为避免 RIGHTFIRE 被 RIGHT 抢先匹配,按长度降序 + names_sorted = sorted(allowed_names, key=len, reverse=True) + + alt = "|".join(map(re.escape, names_sorted)) + pattern = rf"""\{{\s*"?ACTION"?\s*[::]\s*"?\s*({alt})\s*"?\s*\}}""" + + m = re.search(pattern, text, flags=re.IGNORECASE) + if m: + return m.group(1).upper() + + return random.choice(allowed_names) + + @torch.inference_mode() + def decide(self, obs_dict: dict, step: int): + allowed_ids = [i for i, v in enumerate(obs_dict.get("action_mask", [1]*6)) if int(v) == 1] + allowed_names = [self.ID2NAME[i] for i in allowed_ids] + + cur_img = to_model_image(obs_dict["observation"], channel_last=False, use_pil=self.use_pil) + + messages, images_for_processor = self._build_messages_and_images(cur_img, allowed_names) + prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) + + inputs = self.processor( + text=prompt, + images=images_for_processor, + return_tensors="pt" + ).to(self.device) + + tok = self.processor.tokenizer + gen_out = self.model.generate( + **inputs, + max_new_tokens=16, + num_beams=self.topk, + num_return_sequences=self.topk, + do_sample=False, # 纯测概率,不采样 + output_scores=True, + return_dict_in_generate=True, + length_penalty=1.0, # 不做长度归一 + early_stopping=True, + eos_token_id=getattr(tok, "eos_token_id", None), + pad_token_id=getattr(tok, "pad_token_id", getattr(tok, "eos_token_id", None)), + ) + + prompt_len = int(inputs["input_ids"].shape[1]) + seq_new = gen_out.sequences[:, prompt_len:] + decoded = self.processor.tokenizer.batch_decode(seq_new, skip_special_tokens=True) + decoded = [s.strip() for s in decoded] + + trans = self.model.compute_transition_scores( + gen_out.sequences, gen_out.scores, gen_out.beam_indices, normalize_logits=True + ) # shape: [num_return, gen_steps] + gen_steps = len(gen_out.scores) + logps = trans[:, -gen_steps:].sum(dim=1) # 每条序列的 log P + + # 在 top-k 集合里归一化成概率 + probs = torch.softmax(logps, dim=0).tolist() + + # 组装 top-3(序列 + 概率) + topk_seq = [{"text": txt, "prob": float(p)} for txt, p in zip(decoded, probs)] + + # 仍然用你已有的正则从“top-1 文本”解析动作作为实际执行的动作 + out_text = decoded[0] + action_str = self._parse_action_string(out_text, allowed_names) + action_id = self.NAME2ID.get(action_str, random.choice(allowed_ids)) + + if self.use_pil and self.save_image: + self.save_pil_if_enabled(cur_img, self.save_dir, step) + + # 返回 topk_seq,供日志写入 + return action_id, action_str, out_text, topk_seq + + def record(self, prev_obs: dict, action_id: int, step: int): + img = to_model_image(prev_obs["observation"], channel_last=False, use_pil=self.use_pil) + action_str = self.ID2NAME[action_id] + self.buffer.append(Transition(step=step, image=img, action_str=action_str)) + + +if __name__ == "__main__": + rank, world_size, local_rank = init_distributed() + device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu") + + base_seed = 12345 + random.seed(base_seed + rank) + np.random.seed(base_seed + rank) + torch.manual_seed(base_seed + rank) + + config = EasyDict(dict( + collector_env_num=8, + evaluator_env_num=3, + n_evaluator_episode=3, + env_id='PongNoFrameskip-v4', + env_type='Atari', + observation_shape=[3, 96, 96], + collect_max_episode_steps=int(1.08e5), + eval_max_episode_steps=int(1.08e5), + gray_scale=False, + frame_skip=4, + frame_stack_num=1, + 
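+        # With frame_skip=4 and frame_stack_num=1, each decision sees a single 96x96 RGB frame
+        # (observation_shape=[3, 96, 96]); gray_scale is disabled so the VLM receives RGB input.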
episode_life=True, + clip_rewards=True, + channel_last=False, + render_mode_human=False, + scale=True, + warp_frame=True, + save_video=False, + transform2string=False, + game_wrapper=True, + stop_value=int(1e6), + save_replay=False, + replay_path=None, + )) + config.max_episode_steps = config.eval_max_episode_steps + env = AtariEnvLightZero(config) + + policy = QwenPongPolicy( + model_name="/fs-computility/niuyazhe/shared/xiongjyu/model/Qwen2.5-VL-3B-Instruct", + dtype=torch.bfloat16, + history_n=2, + use_pil=True, + channel_last=config.channel_last, + device=device, + save_dir="/fs-computility/niuyazhe/shared/xiongjyu/jericho/LightZero/pong_ddp_frames", + save_image=True, + rank=rank, + topk=3 + ) + + obs = env.reset() + episode_return, steps = 0.0, 0 + + while True: + action_id, action_str, raw, topk_seq = policy.decide(obs, step=steps) + prev_obs = obs + obs, reward, done, info = env.step(action_id) + policy.log_step(steps, action_id, action_str, reward, topk_seq=topk_seq) + + policy.record(prev_obs, action_id, step=steps) + + episode_return += float(reward) + steps += 1 + + if done or steps >= config.max_episode_steps: + print(f"[RANK {rank}/{world_size}] return={episode_return}, steps={steps}, info={info}") + break diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index f0a4a675a..dd89defa1 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -14,6 +14,8 @@ from ding.envs import BaseEnv, BaseEnvTimestep from jericho import FrotzEnv +import re + @ENV_REGISTRY.register('jericho') class JerichoEnv(BaseEnv): @@ -120,75 +122,186 @@ def __init__(self, cfg: Dict[str, Any]) -> None: self.reward_space: gym.spaces.Box = gym.spaces.Box( low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32) - def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]: + # def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]: + # """ + # Overview: + # Prepare the observation for the agent, including tokenization and the creation of an action mask. + + # Arguments: + # - obs (:obj:`str`): The raw observation text. + # - return_str (:obj:`bool`, optional): If True, the observation is returned as a raw string (defaults to False). + + # Returns: + # - (:obj:`Dict[str, Any]`): A dictionary containing the observation, attention mask (if applicable), + # and action mask. For unizero, an additional "to_play" key is provided. + # """ + # if self._action_list is None: + # self._action_list = self._env.get_valid_actions() + + # # Filter available actions based on whether stuck actions are removed. + # if self.remove_stuck_actions: + # available_actions: List[str] = [a for a in self._action_list if a not in self.blocked_actions] + # if len(available_actions) < 1 and len(self._action_list) > 0: + # # Fallback to the first action if all actions are blocked. + # available_actions = [self._action_list[0]] + # self._action_list = available_actions + # else: + # available_actions = self._action_list + + # # Include player location and inventory in the observation if enabled. + # if self.add_location_and_inventory: + # player_location = self._env.get_player_location() + # inventory = self._env.get_inventory() + # full_obs: str = f"Location: {player_location}\nInventory: {inventory}{obs}\nValid actions: {available_actions}" + # else: + # full_obs = f"{obs}\nValid actions: {available_actions}" + + # full_obs_str = copy.deepcopy(full_obs) + + # # Tokenize observation if required. 
+        # if not return_str:
+        #     tokenized_output = JerichoEnv.tokenizer(
+        #         [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len)
+        #     obs_attn_mask = tokenized_output['attention_mask']
+        #     full_obs = np.array(tokenized_output['input_ids'][0], dtype=np.int32)
+        # # Create action mask based on the number of available actions.
+        # if len(available_actions) == 0:
+        #     # Avoid an all-zero action mask that can cause segmentation faults.
+        #     action_mask = [1] + [0] * (self.max_action_num - 1)
+        # elif 0 < len(available_actions) <= self.max_action_num:
+        #     action_mask = [1] * len(available_actions) + [0] * (self.max_action_num - len(available_actions))
+        # elif len(available_actions) == self.max_action_num:
+        #     action_mask = [1] * len(available_actions)
+        # else:
+        #     action_mask = [1] * self.max_action_num
+
+        # action_mask = np.array(action_mask, dtype=np.int8)
+
+        # if return_str:
+        #     if self.for_unizero:
+        #         return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+
+        #     else:
+        #         return {'observation': full_obs, 'action_mask': action_mask}
+        # else:
+        #     if self.for_unizero:
+        #         if self.save_replay:
+        #             return {'observation': full_obs, 'observation_str': full_obs_str,'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+        #         else:
+        #             return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+        #     else:
+        #         return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}
+
+
+    def _clean_jericho_object_str(self, obj_str: str) -> str:
         """
-        Overview:
-            Prepare the observation for the agent, including tokenization and the creation of an action mask.
-
-        Arguments:
-            - obs (:obj:`str`): The raw observation text.
-            - return_str (:obj:`bool`, optional): If True, the observation is returned as a raw string (defaults to False).
-
-        Returns:
-            - (:obj:`Dict[str, Any]`): A dictionary containing the observation, attention mask (if applicable),
-              and action mask. For unizero, an additional "to_play" key is provided.
+        An efficient and robust helper that extracts a clean name from Jericho's raw object string.
+
+        Examples:
+            - Input:  "Obj42: Upstairs hallway Parent0 Sibling0..."
+            - Output: "Upstairs hallway"
+            - Input:  "some other string" (does not match the pattern)
+            - Output: "some other string" (returned unchanged)
         """
-        if self._action_list is None:
-            self._action_list = self._env.get_valid_actions()
-
-        # Filter available actions based on whether stuck actions are removed.
-        if self.remove_stuck_actions:
-            available_actions: List[str] = [a for a in self._action_list if a not in self.blocked_actions]
-            if len(available_actions) < 1 and len(self._action_list) > 0:
-                # Fallback to the first action if all actions are blocked.
-                available_actions = [self._action_list[0]]
-            self._action_list = available_actions
-        else:
-            available_actions = self._action_list
-
-        # Include player location and inventory in the observation if enabled.
-        if self.add_location_and_inventory:
-            player_location = self._env.get_player_location()
-            inventory = self._env.get_inventory()
-            full_obs: str = f"Location: {player_location}\nInventory: {inventory}{obs}\nValid actions: {available_actions}"
-        else:
-            full_obs = f"{obs}\nValid actions: {available_actions}"
+        if not isinstance(obj_str, str):
+            return str(obj_str)  # If the input is not a string, return its string representation.
+
+        # Use a regular expression to capture the text between "ObjXX: " and " Parent".
+        # - `r'...'`: raw string, avoids backslash escaping issues.
+        # - `Obj\d+:`: matches "Obj" followed by one or more digits and a colon.
+        # - `\s*`: matches zero or more whitespace characters.
+        # - `(.*?)`: non-greedy capture group; the core piece that captures everything between the two patterns.
+        # - `\s*Parent`: matches optional whitespace before "Parent".
+        match = re.search(r'Obj\d+:\s*(.*?)\s*Parent', obj_str)
-        full_obs_str = copy.deepcopy(full_obs)
+        if match:
+            # group(1) returns the first capture group; .strip() removes surrounding whitespace.
+            return match.group(1).strip()
-        # Tokenize observation if required.
-        if not return_str:
-            tokenized_output = JerichoEnv.tokenizer(
-                [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len)
-            obs_attn_mask = tokenized_output['attention_mask']
-            full_obs = np.array(tokenized_output['input_ids'][0], dtype=np.int32)
-        # Create action mask based on the number of available actions.
-        if len(available_actions) == 0:
-            # Avoid an all-zero action mask that can cause segmentation faults.
-            action_mask = [1] + [0] * (self.max_action_num - 1)
-        elif 0 < len(available_actions) <= self.max_action_num:
-            action_mask = [1] * len(available_actions) + [0] * (self.max_action_num - len(available_actions))
-        elif len(available_actions) == self.max_action_num:
-            action_mask = [1] * len(available_actions)
-        else:
-            action_mask = [1] * self.max_action_num
+        # If the regex does not match, the string is not in the expected format;
+        # as a robustness fallback, return it unchanged.
+        return obj_str
-        action_mask = np.array(action_mask, dtype=np.int8)
-        if return_str:
-            if self.for_unizero:
-                return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+    def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]:
+        """
+        [Fixed] Prepare the observation; the returned dict now contains separate 'location' and 'inventory' keys.
+        """
+        if self._action_list is None:
+            self._action_list = self._env.get_valid_actions()
+
+        if self.remove_stuck_actions:
+            available_actions: List[str] = [a for a in self._action_list if a not in self.blocked_actions]
+            if not available_actions and self._action_list:
+                available_actions = [self._action_list[0]]
+            self._action_list = available_actions
+        else:
+            available_actions = self._action_list
+
+        # Initialize the returned dict.
+        result_obs = {}
+
+        # Build the observation string.
+        if self.add_location_and_inventory:
+            # 1. Fetch the raw ZObject (player location) and ZObject list (inventory).
+            raw_player_location = self._env.get_player_location()
+            raw_inventory = self._env.get_inventory()
+
+            # 2. Read clean names directly from the object attributes.
+            #    This is efficient, direct, and does not break when the string format changes.
+
+            # Location: check that the object exists, then read its .name attribute.
+            player_location = raw_player_location.name if hasattr(raw_player_location, 'name') else "N/A"
+
+            # Inventory: iterate over the list and read .name from each ZObject.
+            if raw_inventory and isinstance(raw_inventory, list):
+                # Use a list comprehension with hasattr to safely extract each item's name.
+                inventory = [item.name for item in raw_inventory if hasattr(item, 'name')]
+            else:
+                # If the inventory is empty or not a list, fall back to an empty list.
+                inventory = []
+
+            # 3. Add the cleaned structured info to the returned dict.
+            result_obs['location'] = player_location
+            # An empty list would render as '[]' in the f-string; report "empty" instead when nothing is held.
+            result_obs['inventory'] = inventory if inventory else "empty"
+
+            # 4. Build the final observation string.
+            #    The f-string formats the list automatically, e.g. Inventory: ['piece of white paper'],
+            #    which is slightly more regular than "Inventory: [piece of white paper]".
+            full_obs_str = f"Location: {result_obs['location']}\nInventory: {result_obs['inventory']}\n{obs}\nValid actions: {available_actions}"
         else:
-            return {'observation': full_obs, 'action_mask': action_mask}
-        else:
-            if self.for_unizero:
+            # Even when structured info is disabled, provide defaults to avoid a KeyError downstream.
+            result_obs['location'] = "N/A"
+            result_obs['inventory'] = "N/A"
+            full_obs_str = f"{obs}\nValid actions: {available_actions}"
+
+        result_obs['observation'] = full_obs_str
+
+        # Handle tokenization.
+        if not return_str:
+            tokenized_output = JerichoEnv.tokenizer(
+                [full_obs_str], truncation=True, padding="max_length", max_length=self.max_seq_len)
+            result_obs['observation'] = np.array(tokenized_output['input_ids'][0], dtype=np.int32)
+            result_obs['obs_attn_mask'] = tokenized_output['attention_mask']
            if self.save_replay:
-                    return {'observation': full_obs, 'observation_str': full_obs_str,'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
-                else:
-                    return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+                result_obs['observation_str'] = full_obs_str
+
+        # Build the action_mask.
+        if not available_actions:
+            action_mask = [1] + [0] * (self.max_action_num - 1)
        else:
-            return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}
+            num_actions = len(available_actions)
+            action_mask = [1] * min(num_actions, self.max_action_num) + [0] * (self.max_action_num - min(num_actions, self.max_action_num))
+
+        result_obs['action_mask'] = np.array(action_mask, dtype=np.int8)
+
+        if self.for_unizero:
+            result_obs['to_play'] = -1
+            result_obs['timestep'] = self._timestep
+
+        return result_obs

     def reset(self, return_str: bool = False) -> Dict[str, Any]:
         """
@@ -202,6 +315,8 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]:
            - (:obj:`Dict[str, Any]`): The processed observation from the environment reset.
""" initial_observation, info = self._env.reset() + + self.finished = False self._init_flag = True self._action_list = None @@ -305,6 +420,7 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) -> observation, reward, done, info = self._env.step(action_str) + self._timestep += 1 if not self.for_unizero: reward = np.array([float(reward)]) diff --git a/zoo/jericho/envs/test_qwen_v5.py b/zoo/jericho/envs/test_qwen_v5.py new file mode 100644 index 000000000..f19977d68 --- /dev/null +++ b/zoo/jericho/envs/test_qwen_v5.py @@ -0,0 +1,513 @@ +import logging +import copy +import os +import json +import re +from datetime import datetime +from typing import Any, Dict, List, Optional, Union +from easydict import EasyDict +from collections import deque +from abc import ABC, abstractmethod +import numpy as np +import torch +from transformers import AutoTokenizer + +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +import torch.distributed as dist +import torch + +from jericho_env import JerichoEnv +# --- LLM Provider Specific Imports --- +# Qwen (local transformers) +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +# Poe API (or other OpenAI-compatible APIs) +import openai + +# --- 新增的导入 --- +from sentence_transformers import SentenceTransformer +import torch.nn.functional as F + +# ===================================================================================== +# 优化后的 MemoryManager 类 (无修改) +# ===================================================================================== +class MemoryManager: + """ + 管理好/坏记忆的类,实现了基于相似度的驱逐策略。 + 【优化】: 增加了更严格的冗余检查。 + """ + def __init__(self, maxlen: int, similarity_threshold: float = 0.85, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'): + """ + 初始化 MemoryManager. + + 参数: + maxlen (int): 记忆库的最大容量。 + similarity_threshold (float): 用于判断是否冗余或替换的相似度阈值。 (稍微提高阈值) + device (str): 用于计算嵌入的设备 ('cuda' 或 'cpu')。 + """ + self.maxlen = maxlen + self.similarity_threshold = similarity_threshold + self.device = device + self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device) + self.memories: List[str] = [] + self.embeddings: Optional[torch.Tensor] = None + logging.info(f"MemoryManager initialized with maxlen={maxlen}, threshold={similarity_threshold} on device='{self.device}'") + + def get_memories(self) -> List[str]: + return self.memories + + def __len__(self) -> int: + return len(self.memories) + + def add_memory(self, new_memory_text: str) -> str: + """ + 添加一条新记忆,并根据策略进行管理。 + 【优化】: 优先检查并拒绝与现有记忆高度相似的新记忆,以防止冗余。 + 返回一个描述所执行操作的日志消息。 + """ + if not new_memory_text or not isinstance(new_memory_text, str): + return "Skipped adding empty or invalid memory." + + with torch.no_grad(): + new_embedding = self.model.encode(new_memory_text, convert_to_tensor=True, device=self.device) + + # 【核心优化 1】: 检查新记忆是否与任何现有记忆过于相似 + if self.embeddings is not None and len(self.memories) > 0: + similarities = F.cosine_similarity(new_embedding, self.embeddings) + max_similarity = torch.max(similarities).item() + if max_similarity > self.similarity_threshold: + return f"Rejected redundant memory (sim: {max_similarity:.2f} > {self.similarity_threshold}). 
New: '{new_memory_text}'" + + # 如果记忆库未满,直接添加 + if len(self.memories) < self.maxlen: + self.memories.append(new_memory_text) + if self.embeddings is None: + self.embeddings = new_embedding.unsqueeze(0) + else: + self.embeddings = torch.cat([self.embeddings, new_embedding.unsqueeze(0)], dim=0) + return f"Added new memory: '{new_memory_text}'" + + # 如果记忆库已满,执行先进先出(FIFO)替换最旧的记忆 + else: + removed_memory = self.memories.pop(0) + self.memories.append(new_memory_text) + + # 更新嵌入张量 + self.embeddings = torch.cat([self.embeddings[1:], new_embedding.unsqueeze(0)], dim=0) + return (f"FIFO eviction. " + f"Removed: '{removed_memory}', Added: '{new_memory_text}'") + + +def init_distributed(): + if not dist.is_initialized(): + dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo') + rank = dist.get_rank() + world_size = dist.get_world_size() + return rank, world_size + +# ===================================================================================== +# 1. 定义一个抽象基类 (Abstract Base Class) 作为所有 LLM 策略的通用接口 +# ===================================================================================== +class BaseLLMPolicy(ABC): + def __init__(self, reflection_history_len=9999): + self.reflection_history_len = reflection_history_len + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + def format_prompt_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None, use_structured_info: bool = True) -> str: + """ + 【修改 1】: 进一步优化提示词格式。 + - 增强了系统指令,明确了目标和思考方式。 + - 强化了决策指令,引导模型利用记忆打破循环,并增加了更强硬的指令。 + """ + system_prompt = """You are an intelligent agent playing a text-based adventure game. Your primary goal is to solve puzzles, uncover the story, and maximize your score. +Think step-by-step. First, analyze the current situation. Second, review your strategic memories. Third, decide on the single best action to make progress. 
+""" + + system_prompt += "\n--- Strategic Memory ---\n" + if good_memory: + system_prompt += "✅ Successful Strategies (Good Memories):\n" + "".join(f"- {mem}\n" for mem in good_memory) + else: + system_prompt += "✅ No successful strategies recorded yet.\n" + + if bad_memory: + system_prompt += "❌ Mistakes to Avoid (Bad Memories):\n" + "".join(f"- {mem}\n" for mem in bad_memory) + else: + system_prompt += "❌ No mistakes recorded yet.\n" + + system_prompt += "\n--- Condensed Game History (Recent Actions) ---\n" + if not history: + system_prompt += "The game has just begun.\n" + else: + history_lines = [] + for i, h in enumerate(history): + loc_info = "" + # if use_structured_info and 'location' in h and (i == 0 or h['location'] != history[i-1].get('location')): + # loc_info = f"At '{h['location']}', " + history_lines.append(f"{loc_info}I did: > {h['action']}\nOutcome: {h['obs']}") + system_prompt += "\n".join(history_lines) + "\n" + + # system_prompt += "\n--- Current Situation ---\n" + # if use_structured_info: + # system_prompt += f"[Current Location]: {location}" + # system_prompt += f"[My Inventory]: {inventory}\n" + + # system_prompt += f"[Full Scene Description]:\n{self.clean_obs(full_scene_description)}\n" + # system_prompt += f"\n--- Current Situation ---\n{self.clean_obs(full_scene_description)}\n" + system_prompt += f"\n--- Current Situation ---\n{full_scene_description}\n" + + + # 【修改 1.1】: 强化决策指令 + system_prompt += ( + "\n--- Your Decision ---\n" + "Based on the situation and my memories, what is the single most logical next action to progress? " + "If a past strategy is listed in 'Mistakes to Avoid', you MUST NOT repeat it. If a strategy is in 'Successful Strategies', consider applying a similar logic.\n" + f"Choose ONLY ONE from the following valid actions: {', '.join(valid_actions)}\n" + "Respond with the action phrase only.\n" + "> " + ) + return system_prompt + + @abstractmethod + def sample_action_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None, use_structured_info: bool = True) -> tuple[str, str]: + pass + + def clean_obs(self, obs: str) -> str: + if not isinstance(obs, str): + logging.error(f"clean_obs received a non-string input: {type(obs)}. Converting to string.") + obs = str(obs) + + obs = re.sub(r'Copyright \(c\).*reserved\.', '', obs, flags=re.DOTALL) + obs = re.sub(r'ZORK is a registered trademark.*', '', obs, flags=re.DOTALL) + obs = re.sub(r'Revision \d+ / Serial number \d+', '', obs, flags=re.DOTALL) + core_obs = obs.split('Valid actions:')[0].strip() + return '\n'.join(line.strip() for line in core_obs.split('\n') if line.strip()) + + @abstractmethod + def generate_autonomous_reflection(self, history: List[Dict[str, str]], final_score: float, use_structured_info: bool = True) -> Dict[str, str]: + pass + +# ===================================================================================== +# 2. 
为本地 Qwen 模型创建一个具体实现 +# ===================================================================================== +class QwenLocalPolicy(BaseLLMPolicy): + def __init__(self, model_path: str, reflection_history_len: int = 9999, local_rank: int = 0): + super().__init__(reflection_history_len) + self.device = torch.device(f"cuda:{local_rank}") + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to(self.device) + print(f"QwenLocalPolicy initialized on device {self.device}") + + def _generate_text(self, prompt: str, gen_config: GenerationConfig) -> str: + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + + output = self.model.generate(**model_inputs, generation_config=gen_config) + output_ids = output[0][len(model_inputs.input_ids[0]):].tolist() + content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip() + return content + + def sample_action_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None, use_structured_info: bool = True) -> tuple[str, str]: + prompt = self.format_prompt_v2(history, full_scene_description, inventory, location, valid_actions, good_memory, bad_memory, use_structured_info) + + gen_config = GenerationConfig( + temperature=0.2, top_p=0.9, do_sample=True, max_new_tokens=20, + pad_token_id=self.tokenizer.eos_token_id + ) + + content = self._generate_text(prompt, gen_config).lower() + + sorted_actions = sorted(valid_actions, key=len, reverse=True) + for va in sorted_actions: + cleaned_va = va.lower().strip() + cleaned_content = content.replace(">", "").strip() + if cleaned_va in cleaned_content or cleaned_content in cleaned_va: + return va, prompt + + safe_actions = ['look', 'inventory', 'l', 'i'] + for sa in safe_actions: + if sa in valid_actions: + return sa, prompt + + return (valid_actions[0] if valid_actions else 'look'), prompt + + def generate_autonomous_reflection(self, history: List[Dict[str, str]], final_score: float, use_structured_info: bool = True) -> Dict[str, str]: + """ + 【修改 2】: 全面重写反思提示词,使其更具针对性。 + - 聚焦于分析导致游戏结束或停滞的【最后一步】。 + - 要求提供更具体的、可操作的替代方案。 + - 完整地展示最后的游戏历史,包括导致结束的动作和观测。 + """ + recent_history = history[-self.reflection_history_len:] + # 确保历史记录包含最后一步和最终结果 + trajectory_str = "\n".join([f"Step {i+1}: I did '> {h['action']}' and the outcome was:\n{h['obs']}\n" for i, h in enumerate(recent_history)]) + + prompt = ( + "You are an expert strategy analyst for a text-based adventure game. Your task is to analyze a game trajectory, identify the critical mistake that led to a low score or game over, and produce actionable advice.\n\n" + f"--- Final Score ---\n{final_score}\n\n" + f"--- Full Gameplay Log of the Final Attempt ---\n{trajectory_str}\n" + "--- Analysis Task ---\n" + "Analyze the final steps of the gameplay log. Pinpoint the single critical mistake that ended the game or caused it to get stuck. Provide your analysis as a single, clean JSON object with two keys: `critical_mistake_analysis` and `alternative_action_suggestion`.\n" + "- `critical_mistake_analysis`: A concise sentence explaining what the final wrong move was and why it was wrong, based on the final outcome. 
Example: 'The final action 'go north' led to a dead end, wasting a turn when exploring 'east' towards the Mayor's home was a more promising option based on the description.'\n" + "- `alternative_action_suggestion`: A concrete, actionable suggestion for what to do instead of the mistaken action, phrased as a general rule for the future. Example: 'When at the 'Outside' location after moving west, the next logical step is to explore 'east' to investigate the Mayor's home mentioned in the description.'\n\n" + "Respond ONLY with the JSON object, without any surrounding text or explanations.\n" + ) + + gen_config = GenerationConfig( + temperature=0.2, top_p=0.9, do_sample=True, max_new_tokens=250, + pad_token_id=self.tokenizer.eos_token_id + ) + + content = self._generate_text(prompt, gen_config) + + reflections = {'critical_mistake_analysis': '', 'alternative_action_suggestion': ''} + try: + json_match = re.search(r'\{.*\}', content, re.DOTALL) + if json_match: + json_str = json_match.group(0).replace('\n', ' ').replace('\r', '').strip() + parsed_json = json.loads(json_str) + reflections['critical_mistake_analysis'] = parsed_json.get('critical_mistake_analysis', '').strip() + reflections['alternative_action_suggestion'] = parsed_json.get('alternative_action_suggestion', '').strip() + else: + logging.warning(f"Could not find a JSON object in the LLM's reflection output. Content: {content}") + except json.JSONDecodeError as e: + logging.error(f"Failed to decode JSON from reflection output: {e}. Content: {content}") + return reflections + +# ===================================================================================== +# 3. Poe API 实现 (同样应用优化, 此处省略,修改逻辑与 QwenLocalPolicy 完全相同) +# ===================================================================================== +class PoeAPIPolicy(BaseLLMPolicy): # ... (内部逻辑与QwenLocalPolicy的修改相同) + pass # 假设已按上述逻辑修改 + +# ===================================================================================== +# 主程序修改 +# ===================================================================================== +if __name__ == '__main__': + """ + export CUDA_VISIBLE_DEVICES=6 + torchrun --nproc_per_node=1 /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/jericho/envs/test_qwen_v5.py + """ + # --- 核心配置区 --- + LLM_PROVIDER = "Qwen" + + LLM_CONFIGS = { + "Qwen": { + "model_path": "/fs-computility/niuyazhe/shared/xiongjyu/model/Qwen2.5-7B-Instruct", + "model_name": "Qwen2.5-7B-Instruct", + }, + "PoeAPI": { + "model_name": "GPT-3.5-Turbo", + "api_key": "YOUR_POE_API_KEY", # 请替换为您的密钥 + "base_url": "https://api.poe.com/v1" + } + } + + num_episodes = 50 + max_history_len = 10 + + # --- 游戏和实验设置 --- + USE_AUTONOMOUS_REFLECTION = True + # USE_STRUCTURED_INFO = False + USE_STRUCTURED_INFO = True + + ENV_TYPE = 'detective' + + # --- 【优化】相似度阈值 --- + SIMILARITY_THRESHOLD = 0.9 # 可以适当提高阈值,因为反思质量更高了 + + # --- 初始化 --- + rank, world_size = init_distributed() + print(f"[RANK {rank}] Initialized. World size: {world_size}") + + # --- 实例化策略对象 --- + llm_policy: BaseLLMPolicy + current_config = LLM_CONFIGS[LLM_PROVIDER] + + print(f"Using LLM Provider: {LLM_PROVIDER}") + if LLM_PROVIDER == "Qwen": + llm_policy = QwenLocalPolicy( + model_path=current_config["model_path"], + local_rank=rank + ) + elif LLM_PROVIDER == "PoeAPI": + # llm_policy = PoeAPIPolicy(...) 
# 省略 + pass + else: + raise ValueError(f"Unknown LLM_PROVIDER: {LLM_PROVIDER}") + + env_cfg = EasyDict( + dict( + max_steps=100, + game_path="./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + f"{ENV_TYPE}.z5", + add_location_and_inventory=USE_STRUCTURED_INFO, + max_action_num=55, + tokenizer_path="google-bert/bert-base-uncased", + max_seq_len=512, + remove_stuck_actions=False, + for_unizero=False, + collector_env_num=1, + evaluator_env_num=1, + save_replay=False, + save_replay_path=None, + env_type=ENV_TYPE, + collect_policy_mode='expert' + ) + ) + + log_dir = f'./priorzero_log_optimized/{current_config["model_name"]}/{ENV_TYPE}_structured_{USE_STRUCTURED_INFO}_auto_reflect_{USE_AUTONOMOUS_REFLECTION}_v2' + os.makedirs(log_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = os.path.join(log_dir, f"rank_{rank}_{timestamp}.txt") + f = open(log_file, "w", encoding="utf-8") + + env = JerichoEnv(env_cfg) + + good_trial_memory = MemoryManager(maxlen=20, similarity_threshold=SIMILARITY_THRESHOLD) + bad_trial_memory = MemoryManager(maxlen=20, similarity_threshold=SIMILARITY_THRESHOLD) + + total_scores = [] + + for episode_id in range(num_episodes): + f.write(f"{'='*80}\n") + f.write(f"STARTING EPISODE: {episode_id} | Good Memories: {len(good_trial_memory)}, Bad Memories: {len(bad_trial_memory)}\n") + f.write(f"Current Good Memories: {good_trial_memory.get_memories()}\n") + f.write(f"Current Bad Memories: {bad_trial_memory.get_memories()}\n") + f.write(f"CONFIG: USE_STRUCTURED_INFO = {USE_STRUCTURED_INFO}, USE_AUTONOMOUS_REFLECTION = {USE_AUTONOMOUS_REFLECTION}\n") + f.write(f"{'='*80}\n") + f.flush() + + obs_or_dict = env.reset(return_str=True) + done = False + step_count = 0 + episode_reward = 0 + episode_history = [] + last_actions = deque(maxlen=4) + + while not done: + if isinstance(obs_or_dict, dict): + full_scene_description = obs_or_dict.get('observation', '') + inventory_str = obs_or_dict.get('inventory', 'not tracked') + location_str = obs_or_dict.get('location', 'not tracked') + else: + full_scene_description = obs_or_dict + inventory_str = "not tracked" + location_str = "not tracked" + try: + inventory_str = env._env.get_inventory() + location_str = env._env.get_player_location() + except: + pass + + valid_actions = env._env.get_valid_actions() + + if len(last_actions) == 4 and last_actions[0] == last_actions[2] and last_actions[1] == last_actions[3]: + f.write("[SYSTEM] Stuck in a 2-step loop. Forcing 'look' to re-evaluate.\n") + if 'look' in valid_actions: + action = 'look' + prompt = "[SYSTEM] Stuck in a loop. Forcing 'look' to re-evaluate." + else: + non_loop_actions = [a for a in valid_actions if a not in last_actions] + if non_loop_actions: + action = np.random.choice(non_loop_actions) + prompt = f"[SYSTEM] Stuck in a loop and 'look' is unavailable. Forcing random exploration: {action}." + else: + action = valid_actions[0] + prompt = f"[SYSTEM] Stuck in a loop and no other options. Forcing: {action}." 
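+            # No A-B-A-B action loop detected above: defer to the LLM policy, which builds the
+            # prompt from recent history plus the good/bad memory banks and returns (action, prompt).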
+ else: + action, prompt = llm_policy.sample_action_v2( + history=list(episode_history)[-max_history_len:], + full_scene_description=full_scene_description, + inventory=inventory_str, + location=location_str, + valid_actions=valid_actions, + good_memory=good_trial_memory.get_memories(), + bad_memory=bad_trial_memory.get_memories(), + use_structured_info=USE_STRUCTURED_INFO + ) + + last_actions.append(action) + + next_obs_or_dict, reward, done, info = env.step(action, return_str=True) + episode_reward += reward + + if isinstance(next_obs_or_dict, dict): + next_obs_str = next_obs_or_dict.get('observation', '') + else: + next_obs_str = next_obs_or_dict + + # cleaned_next_obs = llm_policy.clean_obs(next_obs_str) + # history_entry = {'obs': cleaned_next_obs, 'action': action} + history_entry = {'obs': next_obs_str, 'action': action} + + # if USE_STRUCTURED_INFO: + # history_entry['location'] = location_str + + episode_history.append(history_entry) + + # 【修改 3】: 增加“好记忆”的生成逻辑 + if reward > 0: + good_memory_suggestion = f"At the location '{location_str}', performing the action '{action}' was successful and yielded a positive reward." + log_msg = good_trial_memory.add_memory(good_memory_suggestion) + f.write(f"[GOOD MEMORY UPDATE]: {log_msg}\n") + print(f"[GOOD MEMORY UPDATE]: {log_msg}") + + + f.write(f"--- Step {step_count} ---\n") + f.write(f"[Prompt Sent to LLM]:\n{prompt}\n") + f.write(f"---------------------------------\n") + f.write(f"[LLM Chose Action]: {action}\n") + # 【修改 4】: 无论 done 是否为 True,都记录下最终的观测结果 + # f.write(f"[Final Observation]:\n{next_obs_str}\n") + f.write(f"[Next Observation]:\n{next_obs_str}\n") + f.write(f"---------------------------------\n") + f.write(f"[Reward]: {reward}, [Done]: {done}, [Total Score]: {episode_reward}\n") + f.write(f"---------------------------------\n\n") + f.flush() + + obs_or_dict = next_obs_or_dict + step_count += 1 + if step_count >= env_cfg.max_steps: + done = True + + final_score = info.get('eval_episode_return', episode_reward) + total_scores.append(final_score) + + # 【修改 5】: 修改反思逻辑,只在分数低时进行,并使用新的反思结果 + if USE_AUTONOMOUS_REFLECTION: # 只在分数低时反思 + f.write("[Reflection Mode]: Autonomous (V2 - Focused on Final Mistake)\n") + print("[Reflection Mode]: Autonomous (V2 - Focused on Final Mistake)") + + # 使用新的、更具针对性的反思函数 + reflections = llm_policy.generate_autonomous_reflection( + episode_history, final_score=final_score, use_structured_info=USE_STRUCTURED_INFO + ) + + f.write(f"Reflections JSON from LLM: {reflections}\n") + print(f"Reflections JSON from LLM: {reflections}") + + # 我们现在使用更有意义的“替代行动建议”作为坏记忆 + bad_reflection_suggestion = reflections.get('alternative_action_suggestion') + + if bad_reflection_suggestion: + log_msg = bad_trial_memory.add_memory(bad_reflection_suggestion) + f.write(f"[BAD MEMORY UPDATE]: {log_msg}\n") + print(f"[BAD MEMORY UPDATE]: {log_msg}") + else: + f.write(f"[BAD MEMORY UPDATE]: LLM did not provide a valid suggestion.\n") + print(f"[BAD MEMORY UPDATE]: LLM did not provide a valid suggestion.") + + f.write(f"Episode {episode_id} finished. Final Score: {final_score}\n") + print(f"Episode {episode_id} finished. Final Score: {final_score}") + + f.write(f"\n\n{'='*80}\n") + f.write(f"All episodes finished. Total Scores: {total_scores}\n") + f.write(f"Average score: {np.mean(total_scores)}\n") + f.close() + print(f"[RANK {rank}] Finished. 
Log written to {log_file}") + del env + + if dist.is_initialized(): + dist.destroy_process_group() + print(f"[RANK {rank}] Process group destroyed.") \ No newline at end of file diff --git a/zoo/jericho/envs/test_qwen_v8.py b/zoo/jericho/envs/test_qwen_v8.py new file mode 100644 index 000000000..71976e2dd --- /dev/null +++ b/zoo/jericho/envs/test_qwen_v8.py @@ -0,0 +1,507 @@ +import logging +import copy +import os +import json +import re +from datetime import datetime +from typing import Any, Dict, List, Optional, Union +from easydict import EasyDict +from collections import deque +from abc import ABC, abstractmethod +import numpy as np +import torch +from transformers import AutoTokenizer + +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +import torch.distributed as dist +import torch + +from jericho_env import JerichoEnv +# --- LLM Provider Specific Imports --- +# Qwen (local transformers) +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +# Poe API (or other OpenAI-compatible APIs) +import openai + +# --- 新增的导入 --- +from sentence_transformers import SentenceTransformer +import torch.nn.functional as F + +# ===================================================================================== +# 优化后的 MemoryManager 类 (无修改) +# ===================================================================================== +class MemoryManager: + """ + 管理好/坏记忆的类,实现了基于相似度的驱逐策略。 + 【优化】: 增加了更严格的冗余检查。 + """ + def __init__(self, maxlen: int, similarity_threshold: float = 0.85, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'): + """ + 初始化 MemoryManager. + + 参数: + maxlen (int): 记忆库的最大容量。 + similarity_threshold (float): 用于判断是否冗余或替换的相似度阈值。 (稍微提高阈值) + device (str): 用于计算嵌入的设备 ('cuda' 或 'cpu')。 + """ + self.maxlen = maxlen + self.similarity_threshold = similarity_threshold + self.device = device + self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device) + self.memories: List[str] = [] + self.embeddings: Optional[torch.Tensor] = None + logging.info(f"MemoryManager initialized with maxlen={maxlen}, threshold={similarity_threshold} on device='{self.device}'") + + def get_memories(self) -> List[str]: + return self.memories + + def __len__(self) -> int: + return len(self.memories) + + def add_memory(self, new_memory_text: str) -> str: + """ + 添加一条新记忆,并根据策略进行管理。 + 【优化】: 优先检查并拒绝与现有记忆高度相似的新记忆,以防止冗余。 + 返回一个描述所执行操作的日志消息。 + """ + if not new_memory_text or not isinstance(new_memory_text, str): + return "Skipped adding empty or invalid memory." + + with torch.no_grad(): + new_embedding = self.model.encode(new_memory_text, convert_to_tensor=True, device=self.device) + + # 【核心优化 1】: 检查新记忆是否与任何现有记忆过于相似 + if self.embeddings is not None and len(self.memories) > 0: + similarities = F.cosine_similarity(new_embedding, self.embeddings) + max_similarity = torch.max(similarities).item() + if max_similarity > self.similarity_threshold: + return f"Rejected redundant memory (sim: {max_similarity:.2f} > {self.similarity_threshold}). 
New: '{new_memory_text}'" + + # 如果记忆库未满,直接添加 + if len(self.memories) < self.maxlen: + self.memories.append(new_memory_text) + if self.embeddings is None: + self.embeddings = new_embedding.unsqueeze(0) + else: + self.embeddings = torch.cat([self.embeddings, new_embedding.unsqueeze(0)], dim=0) + return f"Added new memory: '{new_memory_text}'" + + # 如果记忆库已满,执行先进先出(FIFO)替换最旧的记忆 + else: + removed_memory = self.memories.pop(0) + self.memories.append(new_memory_text) + + # 更新嵌入张量 + self.embeddings = torch.cat([self.embeddings[1:], new_embedding.unsqueeze(0)], dim=0) + return (f"FIFO eviction. " + f"Removed: '{removed_memory}', Added: '{new_memory_text}'") + + +def init_distributed(): + if not dist.is_initialized(): + dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo') + rank = dist.get_rank() + world_size = dist.get_world_size() + return rank, world_size + +# ===================================================================================== +# 1. 定义一个抽象基类 (Abstract Base Class) 作为所有 LLM 策略的通用接口 +# ===================================================================================== +class BaseLLMPolicy(ABC): + def __init__(self, reflection_history_len=9999): + self.reflection_history_len = reflection_history_len + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + def format_prompt_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None, use_structured_info: bool = True) -> str: + """ + - 增强了系统指令,明确了目标和思考方式。 + - 强化了决策指令,引导模型利用记忆打破循环,并增加了更强硬的指令。 + """ + system_prompt = """You are an intelligent agent playing a text-based adventure game. Your primary goal is to solve puzzles, uncover the story, and maximize your score. +Think step-by-step. First, analyze the Current Obeservation and Recent Game History. Second, review your Strategic Memory. Third, decide on the single best action to make progress. +""" + + system_prompt += "\n--- Strategic Memory ---\n" + if good_memory: + system_prompt += "✅ Successful Strategies (Good Memories):\n" + "".join(f"- {mem}\n" for mem in good_memory) + else: + system_prompt += "✅ No successful strategies recorded yet.\n" + + if bad_memory: + system_prompt += "❌ Mistakes to Avoid (Bad Memories):\n" + "".join(f"- {mem}\n" for mem in bad_memory) + else: + system_prompt += "❌ No mistakes recorded yet.\n" + + system_prompt += "\n--- Recent Game History ---\n" + + if not history: + system_prompt += "The game has just begun.\n" + else: + history_lines = [] + for i, h in enumerate(history): + loc_info = "" + if i < len(history)-1: + history_lines.append(f"{loc_info}YOU DID ACTION: > {h['action']}\nTHEN YOU OBSERVED: {h['obs']}") + else: + history_lines.append(f"{loc_info}YOU DID ACTION: > {h['action']}\nTHEN YOU OBSERVED:") + + system_prompt += "\n".join(history_lines) + "\n" + + system_prompt += f"\n--- Current Obeservation ---\n{full_scene_description}\n" + + + system_prompt += ( + "\n--- Your Decision ---\n" + "Based on the Current Obeservation, Recent Game History and Strategic Memory, what is the single most promising next action to progress? " + "If a past strategy is listed in 'Mistakes to Avoid', you MUST NOT repeat it. 
If a strategy is in 'Successful Strategies', consider applying a similar logic.\n" + f"Choose ONLY ONE from the following valid actions: {', '.join(valid_actions)}\n" + "Respond with the action phrase only.\n" + "> " + ) + return system_prompt + + @abstractmethod + def sample_action_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None, use_structured_info: bool = True) -> tuple[str, str]: + pass + + def clean_obs(self, obs: str) -> str: + if not isinstance(obs, str): + logging.error(f"clean_obs received a non-string input: {type(obs)}. Converting to string.") + obs = str(obs) + + obs = re.sub(r'Copyright \(c\).*reserved\.', '', obs, flags=re.DOTALL) + obs = re.sub(r'ZORK is a registered trademark.*', '', obs, flags=re.DOTALL) + obs = re.sub(r'Revision \d+ / Serial number \d+', '', obs, flags=re.DOTALL) + core_obs = obs.split('Valid actions:')[0].strip() + return '\n'.join(line.strip() for line in core_obs.split('\n') if line.strip()) + + @abstractmethod + def generate_autonomous_reflection(self, history: List[Dict[str, str]], final_score: float, use_structured_info: bool = True) -> Dict[str, str]: + pass + +# ===================================================================================== +# 2. 为本地 Qwen 模型创建一个具体实现 +# ===================================================================================== +class QwenLocalPolicy(BaseLLMPolicy): + def __init__(self, model_path: str, reflection_history_len: int = 9999, local_rank: int = 0): + super().__init__(reflection_history_len) + self.device = torch.device(f"cuda:{local_rank}") + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to(self.device) + print(f"QwenLocalPolicy initialized on device {self.device}") + + def _generate_text(self, prompt: str, gen_config: GenerationConfig) -> str: + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + + output = self.model.generate(**model_inputs, generation_config=gen_config) + output_ids = output[0][len(model_inputs.input_ids[0]):].tolist() + content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip() + return content + + def sample_action_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None, use_structured_info: bool = True) -> tuple[str, str]: + prompt = self.format_prompt_v2(history, full_scene_description, inventory, location, valid_actions, good_memory, bad_memory, use_structured_info) + + gen_config = GenerationConfig( + temperature=0.2, top_p=0.9, do_sample=True, max_new_tokens=20, + pad_token_id=self.tokenizer.eos_token_id + ) + + content = self._generate_text(prompt, gen_config).lower() + + sorted_actions = sorted(valid_actions, key=len, reverse=True) + for va in sorted_actions: + cleaned_va = va.lower().strip() + cleaned_content = content.replace(">", "").strip() + if cleaned_va in cleaned_content or cleaned_content in cleaned_va: + return va, prompt + + safe_actions = ['look', 'inventory', 'l', 'i'] + for sa in safe_actions: + if sa in valid_actions: + return sa, prompt + + return (valid_actions[0] if valid_actions else 'look'), prompt + + 
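+    # Action resolution above falls back in three stages: (1) longest-first substring match between
+    # the generated text and the valid actions, (2) a safe action such as 'look' or 'inventory' when
+    # nothing matches, (3) the first valid action. A minimal usage sketch (illustrative only;
+    # `policy`, `obs` and `env` are assumed to exist):
+    #
+    #     action, prompt = policy.sample_action_v2(
+    #         history=[], full_scene_description=obs, inventory="not tracked",
+    #         location="not tracked", valid_actions=env._env.get_valid_actions())
+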
def generate_autonomous_reflection(self, history: List[Dict[str, str]], final_score: float, use_structured_info: bool = True) -> Dict[str, str]: + """ + - 聚焦于分析导致游戏结束或停滞的【最后一步】。 + - 要求提供更具体的、可操作的替代方案。 + - 完整地展示最后的游戏历史,包括导致结束的动作和观测。 + """ + recent_history = history[-self.reflection_history_len:] + # 确保历史记录包含最后一步和最终结果 + # TODO + trajectory_str = "\n".join([f"Step {i+1}: I did '> {h['action']}' and the outcome was:\n{h['obs']}\n" for i, h in enumerate(recent_history)]) + + prompt = ( + "You are an expert strategy analyst for a text-based adventure game. Your task is to analyze a game trajectory, identify the critical mistake that led to a low score or game over, and produce actionable advice.\n\n" + f"--- Final Score ---\n{final_score}\n\n" + f"--- Full Gameplay Log of the Final Attempt ---\n{trajectory_str}\n" + "--- Analysis Task ---\n" + "Analyze the final steps of the gameplay log. Pinpoint the single critical mistake that ended the game or caused it to get stuck. Provide your analysis as a single, clean JSON object with two keys: `critical_mistake_analysis` and `alternative_action_suggestion`.\n" + "- `critical_mistake_analysis`: A concise sentence explaining what the final wrong move was and why it was wrong, based on the final outcome. Example: 'The final action 'go north' led to a dead end, wasting a turn when exploring 'east' towards the Mayor's home was a more promising option based on the description.'\n" + "- `alternative_action_suggestion`: A concrete, actionable suggestion for what to do instead of the mistaken action, phrased as a general rule for the future. Example: 'When at the 'Outside' location after moving west, the next logical step is to explore 'east' to investigate the Mayor's home mentioned in the description.'\n\n" + "Respond ONLY with the JSON object, without any surrounding text or explanations.\n" + ) + + gen_config = GenerationConfig( + temperature=0.2, top_p=0.9, do_sample=True, max_new_tokens=250, + pad_token_id=self.tokenizer.eos_token_id + ) + + content = self._generate_text(prompt, gen_config) + + reflections = {'critical_mistake_analysis': '', 'alternative_action_suggestion': ''} + try: + json_match = re.search(r'\{.*\}', content, re.DOTALL) + if json_match: + json_str = json_match.group(0).replace('\n', ' ').replace('\r', '').strip() + parsed_json = json.loads(json_str) + reflections['critical_mistake_analysis'] = parsed_json.get('critical_mistake_analysis', '').strip() + reflections['alternative_action_suggestion'] = parsed_json.get('alternative_action_suggestion', '').strip() + else: + logging.warning(f"Could not find a JSON object in the LLM's reflection output. Content: {content}") + except json.JSONDecodeError as e: + logging.error(f"Failed to decode JSON from reflection output: {e}. Content: {content}") + return reflections + +# ===================================================================================== +# 3. Poe API 实现 (同样应用优化, 此处省略,修改逻辑与 QwenLocalPolicy 完全相同) +# ===================================================================================== +class PoeAPIPolicy(BaseLLMPolicy): # ... 
(内部逻辑与QwenLocalPolicy的修改相同) + pass # 假设已按上述逻辑修改 + +# ===================================================================================== +# 主程序修改 +# ===================================================================================== +if __name__ == '__main__': + """ + export CUDA_VISIBLE_DEVICES=6 + torchrun --nproc_per_node=1 --master-port 20092 /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/jericho/envs/test_qwen_v8.py + """ + # --- 核心配置区 --- + LLM_PROVIDER = "Qwen" + + LLM_CONFIGS = { + "Qwen": { + "model_path": "/fs-computility/niuyazhe/shared/xiongjyu/model/Qwen2.5-7B-Instruct", + "model_name": "Qwen2.5-7B-Instruct", + }, + "PoeAPI": { + "model_name": "GPT-3.5-Turbo", + "api_key": "YOUR_POE_API_KEY", # 请替换为您的密钥 + "base_url": "https://api.poe.com/v1" + } + } + + num_episodes = 20 + max_history_len = 10 + good_trial_memory_maxlen = 20 + bad_trial_memory_maxlen = 20 + + # --- 游戏和实验设置 --- + USE_AUTONOMOUS_REFLECTION = True + USE_STRUCTURED_INFO = True + ENV_TYPE = 'detective' + # ENV_TYPE = 'zork1' + + + # --- 【优化】相似度阈值 --- + SIMILARITY_THRESHOLD = 0.9 # 可以适当提高阈值,因为反思质量更高了 + + # --- 初始化 --- + rank, world_size = init_distributed() + print(f"[RANK {rank}] Initialized. World size: {world_size}") + + # --- 实例化策略对象 --- + llm_policy: BaseLLMPolicy + current_config = LLM_CONFIGS[LLM_PROVIDER] + + print(f"Using LLM Provider: {LLM_PROVIDER}") + if LLM_PROVIDER == "Qwen": + llm_policy = QwenLocalPolicy( + model_path=current_config["model_path"], + local_rank=rank + ) + elif LLM_PROVIDER == "PoeAPI": + # llm_policy = PoeAPIPolicy(...) # 省略 + pass + else: + raise ValueError(f"Unknown LLM_PROVIDER: {LLM_PROVIDER}") + + env_cfg = EasyDict( + dict( + max_steps=100, # for detective + # max_steps=400, # for zork1 + game_path="./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + f"{ENV_TYPE}.z5", + add_location_and_inventory=USE_STRUCTURED_INFO, + max_action_num=55, # TODO + tokenizer_path="google-bert/bert-base-uncased", + max_seq_len=512, + remove_stuck_actions=False, + for_unizero=False, + collector_env_num=1, + evaluator_env_num=1, + save_replay=False, + save_replay_path=None, + env_type=ENV_TYPE, + collect_policy_mode='expert' + ) + ) + + log_dir = f'./priorzero_log_optimized/{current_config["model_name"]}/{ENV_TYPE}_structured_{USE_STRUCTURED_INFO}_auto_reflect_{USE_AUTONOMOUS_REFLECTION}_v2' + os.makedirs(log_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = os.path.join(log_dir, f"rank_{rank}_{timestamp}.txt") + f = open(log_file, "w", encoding="utf-8") + + env = JerichoEnv(env_cfg) + + good_trial_memory = MemoryManager(maxlen=good_trial_memory_maxlen, similarity_threshold=SIMILARITY_THRESHOLD) + bad_trial_memory = MemoryManager(maxlen=bad_trial_memory_maxlen, similarity_threshold=SIMILARITY_THRESHOLD) + + total_scores = [] + + for episode_id in range(num_episodes): + f.write(f"{'='*80}\n") + f.write(f"STARTING EPISODE: {episode_id} | Good Memories: {len(good_trial_memory)}, Bad Memories: {len(bad_trial_memory)}\n") + f.write(f"Current Good Memories: {good_trial_memory.get_memories()}\n") + f.write(f"Current Bad Memories: {bad_trial_memory.get_memories()}\n") + f.write(f"CONFIG: USE_STRUCTURED_INFO = {USE_STRUCTURED_INFO}, USE_AUTONOMOUS_REFLECTION = {USE_AUTONOMOUS_REFLECTION}\n") + f.write(f"{'='*80}\n") + f.flush() + + obs_or_dict = env.reset(return_str=True) + done = False + step_count = 0 + episode_reward = 0 + episode_history = [] + last_actions = deque(maxlen=4) + + while not done: + if isinstance(obs_or_dict, dict): + 
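+                # Structured path: the updated prepare_obs (see jericho_env.py above) returns a dict
+                # with 'observation', 'location' and 'inventory' keys; the string branch below is a fallback.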
full_scene_description = obs_or_dict.get('observation', '') +                 inventory_str = obs_or_dict.get('inventory', 'not tracked') +                 location_str = obs_or_dict.get('location', 'not tracked') +             else: +                 full_scene_description = obs_or_dict +                 inventory_str = "not tracked" +                 location_str = "not tracked" +                 try: +                     inventory_str = env._env.get_inventory() +                     location_str = env._env.get_player_location() +                 except Exception: +                     # Structured info is optional; keep the "not tracked" fallback if the backend cannot provide it. +                     pass + +             valid_actions = env._env.get_valid_actions() + +             if len(last_actions) == 4 and last_actions[0] == last_actions[2] and last_actions[1] == last_actions[3]: +                 # Detected an A-B-A-B action loop; force a different action to escape it. +                 f.write("[SYSTEM] Stuck in a 2-step loop. Forcing 'look' to re-evaluate.\n") +                 if 'look' in valid_actions: +                     action = 'look' +                     prompt = "[SYSTEM] Stuck in a loop. Forcing 'look' to re-evaluate." +                 else: +                     non_loop_actions = [a for a in valid_actions if a not in last_actions] +                     if non_loop_actions: +                         action = np.random.choice(non_loop_actions) +                         prompt = f"[SYSTEM] Stuck in a loop and 'look' is unavailable. Forcing random exploration: {action}." +                     else: +                         action = valid_actions[0] +                         prompt = f"[SYSTEM] Stuck in a loop and no other options. Forcing: {action}." +             else: +                 action, prompt = llm_policy.sample_action_v2( +                     history=list(episode_history)[-max_history_len:], +                     full_scene_description=full_scene_description, +                     inventory=inventory_str, +                     location=location_str, +                     valid_actions=valid_actions, +                     good_memory=good_trial_memory.get_memories(), +                     bad_memory=bad_trial_memory.get_memories(), +                     use_structured_info=USE_STRUCTURED_INFO +                 ) + +             last_actions.append(action) + +             next_obs_or_dict, reward, done, info = env.step(action, return_str=True) +             episode_reward += reward + +             if isinstance(next_obs_or_dict, dict): +                 next_obs_str = next_obs_or_dict.get('observation', '') +             else: +                 next_obs_str = next_obs_or_dict + +             history_entry = {'action': action, 'obs': next_obs_str} +             episode_history.append(history_entry) + +             # [Change 3]: generate a "good memory" whenever a step yields a positive reward +             if reward > 0: +                 good_memory_suggestion = f"At the location '{location_str}', performing the action '{action}' was successful and yielded a positive reward."
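The loop check above hard-codes the period-2 pattern (last_actions[0] == last_actions[2] and last_actions[1] == last_actions[3]). A minimal sketch of the same idea generalized to an arbitrary period; the helper name is_action_loop is hypothetical and not part of this diff:

from collections import deque

def is_action_loop(actions, period: int = 2) -> bool:
    # True if the last 2*period actions are the same length-`period` pattern repeated twice,
    # which for period == 2 is exactly the A-B-A-B case the script guards against.
    if len(actions) < 2 * period:
        return False
    recent = list(actions)[-2 * period:]
    return recent[:period] == recent[period:]

# Mirrors the script's deque(maxlen=4) bookkeeping:
last_actions = deque(["go north", "go south", "go north", "go south"], maxlen=4)
assert is_action_loop(last_actions)            # A-B-A-B loop detected
assert not is_action_loop(deque(["look"]))     # too little history to decide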
+ log_msg = good_trial_memory.add_memory(good_memory_suggestion) + f.write(f"[GOOD MEMORY UPDATE]: {log_msg}\n") + print(f"[GOOD MEMORY UPDATE]: {log_msg}") + + + f.write(f"--- Step {step_count} ---\n") + f.write(f"[Prompt Sent to LLM]:\n==============================================\n{prompt}\n") + f.write(f"==============================================\n") + f.write(f"[LLM Chose Action]: {action}\n") + if done: + f.write(f"[Final Observation]:\n{next_obs_str}\n") + else: + f.write(f"[Next Observation]:\n{next_obs_str}\n") + + f.write(f"---------------------------------\n") + f.write(f"[Reward]: {reward}, [Done]: {done}, [Total Score]: {episode_reward}\n") + f.write(f"---------------------------------\n\n") + f.flush() + + obs_or_dict = next_obs_or_dict + step_count += 1 + if step_count >= env_cfg.max_steps: + done = True + + final_score = info.get('eval_episode_return', episode_reward) + total_scores.append(final_score) + + # 【修改 5】: 修改反思逻辑,只在分数低时进行,并使用新的反思结果 + if USE_AUTONOMOUS_REFLECTION: # 只在分数低时反思 + f.write("[Reflection Mode]: Autonomous\n") + print("[Reflection Mode]: Autonomous") + + # 使用新的、更具针对性的反思函数 + reflections = llm_policy.generate_autonomous_reflection( + episode_history, final_score=final_score, use_structured_info=USE_STRUCTURED_INFO + ) + + f.write(f"Reflections JSON from LLM: {reflections}\n") + print(f"Reflections JSON from LLM: {reflections}") + + # 我们现在使用更有意义的“替代行动建议”作为坏记忆 + bad_reflection_suggestion = reflections.get('alternative_action_suggestion') + + if bad_reflection_suggestion: + log_msg = bad_trial_memory.add_memory(bad_reflection_suggestion) + f.write(f"[BAD MEMORY UPDATE]: {log_msg}\n") + print(f"[BAD MEMORY UPDATE]: {log_msg}") + else: + f.write(f"[BAD MEMORY UPDATE]: LLM did not provide a valid suggestion.\n") + print(f"[BAD MEMORY UPDATE]: LLM did not provide a valid suggestion.") + + f.write(f"Episode {episode_id} finished. Final Score: {final_score}\n") + print(f"Episode {episode_id} finished. Final Score: {final_score}") + + f.write(f"\n\n{'='*80}\n") + f.write(f"All episodes finished. Total Scores: {total_scores}\n") + f.write(f"Average score: {np.mean(total_scores)}\n") + f.close() + print(f"[RANK {rank}] Finished. Log written to {log_file}") + del env + + if dist.is_initialized(): + dist.destroy_process_group() + print(f"[RANK {rank}] Process group destroyed.") \ No newline at end of file diff --git a/zoo/jericho/envs/test_qwen_v9.py b/zoo/jericho/envs/test_qwen_v9.py new file mode 100644 index 000000000..48fddbcd1 --- /dev/null +++ b/zoo/jericho/envs/test_qwen_v9.py @@ -0,0 +1,596 @@ +# -*- coding: utf-8 -*- +""" +使用大型语言模型(LLM)作为智能体来玩 Jericho 文本冒险游戏的脚本。 + +该脚本实现了以下核心功能: +1. 一个基于 LLM 的策略智能体,能够理解游戏状态并做出决策。 +2. 一个记忆管理器(MemoryManager),用于存储和管理游戏过程中的“好记忆”(成功策略)和“坏记忆”(失败教训),并使用语义相似度来避免冗余。 +3. 一个自主反思(Autonomous Reflection)机制,在每轮游戏结束后,让 LLM 分析游戏过程,总结失败原因并提炼出可操作的经验。 +4. 支持多种 LLM 后端,例如本地部署的 Qwen 模型。 +5. 
通过命令行参数高度可配置,方便进行实验和调整。 + +如何运行: +- 使用 torchrun 进行分布式训练(即使是单节点单GPU)。 +- 示例命令: + export CUDA_VISIBLE_DEVICES=1 + cd /fs-computility/niuyazhe/puyuan/code/LightZero + torchrun --nproc_per_node=1 --master-port 20092 /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/jericho/envs/test_qwen_v9.py \ + --model-path /fs-computility/niuyazhe/shared/xiongjyu/model/Qwen2.5-7B-Instruct \ + --env-type detective \ + --max-steps 100 \ + --num-episodes 20 \ + --good-memory-maxlen 20 \ + --temperature 0.01 \ + --log-dir ./priorzero_log_optimized +""" + +import logging +import copy +import os +import json +import re +import argparse +from datetime import datetime +from typing import Any, Dict, List, Optional, Union +from easydict import EasyDict +from collections import deque +from abc import ABC, abstractmethod + +import numpy as np +import torch +import torch.distributed as dist +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from sentence_transformers import SentenceTransformer +import torch.nn.functional as F + +from jericho_env import JerichoEnv + + +# ===================================================================================== +# 1. 记忆管理模块 (MemoryManager) +# ===================================================================================== +class MemoryManager: + """ + 管理智能体的长期记忆,包括成功经验(好记忆)和失败教训(坏记忆)。 + + 该类使用 SentenceTransformer 来计算记忆文本的嵌入向量,并实现了一套 + 基于余弦相似度的驱逐策略,以防止记忆库中出现过多语义重复的内容。 + """ + def __init__(self, maxlen: int, similarity_threshold: float = 0.85, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'): + """ + 初始化 MemoryManager. + + Args: + maxlen (int): 记忆库的最大容量。 + similarity_threshold (float): 用于判断记忆是否冗余的相似度阈值。 + device (str): 用于计算嵌入向量的设备 ('cuda' 或 'cpu')。 + """ + self.maxlen = maxlen + self.similarity_threshold = similarity_threshold + self.device = device + self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device) + self.memories: List[str] = [] + self.embeddings: Optional[torch.Tensor] = None + logging.info(f"MemoryManager initialized with maxlen={maxlen}, threshold={similarity_threshold} on device='{self.device}'") + + def get_memories(self) -> List[str]: + """返回当前所有的记忆文本。""" + return self.memories + + def __len__(self) -> int: + """返回当前记忆的数量。""" + return len(self.memories) + + def add_memory(self, new_memory_text: str) -> str: + """ + 添加一条新记忆,并根据冗余检查和容量限制进行管理。 + + 该方法首先检查新记忆是否与现有记忆高度相似,如果相似度超过阈值,则拒绝添加。 + 如果记忆库已满,则采用先进先出(FIFO)策略移除最旧的记忆。 + + Args: + new_memory_text (str): 要添加的新记忆文本。 + + Returns: + str: 描述本次操作的日志信息。 + """ + if not new_memory_text or not isinstance(new_memory_text, str): + return "Skipped adding empty or invalid memory." + + with torch.no_grad(): + new_embedding = self.model.encode(new_memory_text, convert_to_tensor=True, device=self.device) + + # 检查新记忆是否与任何现有记忆过于相似 + if self.embeddings is not None and len(self.memories) > 0: + similarities = F.cosine_similarity(new_embedding, self.embeddings) + max_similarity = torch.max(similarities).item() + if max_similarity > self.similarity_threshold: + return f"Rejected redundant memory (sim: {max_similarity:.2f} > {self.similarity_threshold}). 
New: '{new_memory_text}'" + + # 如果记忆库未满,直接添加 + if len(self.memories) < self.maxlen: + self.memories.append(new_memory_text) + if self.embeddings is None: + self.embeddings = new_embedding.unsqueeze(0) + else: + self.embeddings = torch.cat([self.embeddings, new_embedding.unsqueeze(0)], dim=0) + return f"Added new memory: '{new_memory_text}'" + + # 如果记忆库已满,执行FIFO替换 + else: + removed_memory = self.memories.pop(0) + self.memories.append(new_memory_text) + + # 更新嵌入张量 + self.embeddings = torch.cat([self.embeddings[1:], new_embedding.unsqueeze(0)], dim=0) + return (f"FIFO eviction. " + f"Removed: '{removed_memory}', Added: '{new_memory_text}'") + +# ===================================================================================== +# 2. 分布式环境初始化 +# ===================================================================================== +def init_distributed(): + """初始化 PyTorch 分布式环境。""" + if not dist.is_initialized(): + dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo') + rank = dist.get_rank() + world_size = dist.get_world_size() + return rank, world_size + +# ===================================================================================== +# 3. LLM 策略抽象基类 +# ===================================================================================== +class BaseLLMPolicy(ABC): + """ + 所有 LLM 策略的抽象基类,定义了通用接口。 + """ + def __init__(self, reflection_history_len=9999): + self.reflection_history_len = reflection_history_len + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + def format_prompt_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None) -> str: + """ + 构建用于动作选择的 Prompt。 + + 该 Prompt 结构清晰,包含系统指令、战略记忆、近期历史、当前观察和决策指令, + 旨在引导 LLM 做出更具策略性的决策。 + + Args: + history: 近期的动作和观察历史。 + full_scene_description: 当前的完整场景描述。 + inventory: 当前的物品栏。 + location: 当前的地点。 + valid_actions: 当前可用的动作列表。 + good_memory: 成功策略列表。 + bad_memory: 需要避免的错误列表。 + + Returns: + str: 格式化后的完整 Prompt 字符串。 + """ + system_prompt = """You are an intelligent agent playing a text-based adventure game. Your primary goal is to solve puzzles, uncover the story, and maximize your score. +Think step-by-step. First, analyze the Current Obeservation and Recent Game History. Second, review your Strategic Memory. Third, decide on the single best action to make progress. 
+""" + + system_prompt += "\n--- Strategic Memory ---\n" + if good_memory: + system_prompt += "✅ Successful Strategies (Good Memories):\n" + "".join(f"- {mem}\n" for mem in good_memory) + else: + system_prompt += "✅ No successful strategies recorded yet.\n" + + if bad_memory: + system_prompt += "❌ Mistakes to Avoid (Bad Memories):\n" + "".join(f"- {mem}\n" for mem in bad_memory) + else: + system_prompt += "❌ No mistakes recorded yet.\n" + + system_prompt += "\n--- Recent Game History ---\n" + + if not history: + system_prompt += "The game has just begun.\n" + else: + history_lines = [] + for i, h in enumerate(history): + loc_info = "" + if i < len(history)-1: + history_lines.append(f"{loc_info}YOU DID ACTION: > {h['action']}\nTHEN YOU OBSERVED: {h['obs']}") + else: + history_lines.append(f"{loc_info}YOU DID ACTION: > {h['action']}\nTHEN YOU OBSERVED:") + + system_prompt += "\n".join(history_lines) + "\n" + + system_prompt += f"\n--- Current Obeservation ---\n{full_scene_description}\n" + + system_prompt += ( + "\n--- Your Decision ---\n" + "Based on the Current Obeservation, Recent Game History and Strategic Memory, what is the single most promising next action to progress? " + "If a past strategy is listed in 'Mistakes to Avoid', you MUST NOT repeat it. If a strategy is in 'Successful Strategies', consider applying a similar logic.\n" + f"Choose ONLY ONE from the following valid actions: {', '.join(valid_actions)}\n" + "Respond with the action phrase only.\n" + "> " + ) + return system_prompt + + @abstractmethod + def sample_action_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None) -> tuple[str, str]: + """ + 根据当前状态采样一个动作。这是一个抽象方法,需要由子类实现。 + + Returns: + tuple[str, str]: (选择的动作, 用于生成的 Prompt) + """ + pass + + + @abstractmethod + def generate_autonomous_reflection(self, history: List[Dict[str, str]], final_score: float) -> Dict[str, str]: + """ + 生成自主反思。这是一个抽象方法,需要由子类实现。 + + Returns: + Dict[str, str]: 包含 'critical_mistake_analysis' 和 'alternative_action_suggestion' 的字典。 + """ + pass + +# ===================================================================================== +# 4. 
+# ===================================================================================== +# 4. Local Qwen model policy implementation +# ===================================================================================== +class QwenLocalPolicy(BaseLLMPolicy): +     """ +     Policy implementation backed by a locally deployed Qwen-family model. +     """ +     def __init__(self, model_path: str, reflection_history_len: int = 9999, local_rank: int = 0): +         super().__init__(reflection_history_len) +         self.device = torch.device(f"cuda:{local_rank}") +         self.tokenizer = AutoTokenizer.from_pretrained(model_path) +         self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to(self.device) +         print(f"QwenLocalPolicy initialized on device {self.device}") + +     def _generate_text(self, prompt: str, gen_config: GenerationConfig) -> str: +         """Internal helper that runs the model to generate text for a single user prompt.""" +         messages = [{"role": "user", "content": prompt}] +         text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) +         model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + +         output = self.model.generate(**model_inputs, generation_config=gen_config) +         output_ids = output[0][len(model_inputs.input_ids[0]):].tolist() +         content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip() +         return content + +     def sample_action_v2(self, history: List[Dict[str, str]], full_scene_description: str, inventory: str, location: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None) -> tuple[str, str]: +         """ +         Action-sampling logic. + +         Builds the prompt, asks the LLM for an action string, and matches the generated text against the list of valid actions to find the best fit. If no match is found, a safe action (such as 'look') or the first valid action is used as a fallback. +         """ +         prompt = self.format_prompt_v2(history, full_scene_description, inventory, location, valid_actions, good_memory, bad_memory) + +         gen_config = GenerationConfig( +             temperature=TEMPERATURE, top_p=0.9, do_sample=True, max_new_tokens=20, +             pad_token_id=self.tokenizer.eos_token_id +         ) + +         content = self._generate_text(prompt, gen_config).lower() + +         # Prefer longer, more specific actions when matching. +         sorted_actions = sorted(valid_actions, key=len, reverse=True) +         for va in sorted_actions: +             cleaned_va = va.lower().strip() +             cleaned_content = content.replace(">", "").strip() +             if cleaned_va in cleaned_content or cleaned_content in cleaned_va: +                 return va, prompt + +         # Fallback safe actions. +         safe_actions = ['look', 'inventory', 'l', 'i'] +         for sa in safe_actions: +             if sa in valid_actions: +                 return sa, prompt + +         return (valid_actions[0] if valid_actions else 'look'), prompt + +     def generate_autonomous_reflection(self, history: List[Dict[str, str]], final_score: float) -> Dict[str, str]: +         """ +         Autonomous reflection logic. + +         Builds a dedicated prompt asking the LLM to analyze the game trajectory, in particular the final steps, identify the critical mistake that ended or stalled the game, and return the analysis plus an improvement suggestion as JSON. +         """ +         recent_history = history[-self.reflection_history_len:] +         # Make sure the history includes the final step and the final outcome. +         # TODO: test the impact of adding the per-step reward to the trajectory +         trajectory_str = "\n".join([f"Step {i+1}: YOU DID ACTION: > {h['action']}, THEN YOU OBSERVED: \n{h['obs']}\n" for i, h in enumerate(recent_history)]) + +         # trajectory_str = "\n".join([f"STEP {i+1}: ACTION: > {h['action']}', OBSERVATION: \n{h['obs']}\n" for i, h in enumerate(recent_history)]) +         # trajectory_str = "\n".join([f"Step {i+1}: I did '> {h['action']}' and the outcome was:\n{h['obs']}\n" for i, h in enumerate(recent_history)]) + +         # TODO: test the impact of different prompts +         # 20eps reward_mean:110.5 Total Scores: [array([60.]), array([140.]), array([90.]), array([90.]), array([140.]), array([160.]), array([90.]), array([90.]), array([90.]), array([140.]), array([160.]), array([90.]), array([90.]), array([140.]), array([90.]), array([140.]), array([140.]), array([90.]), array([90.]), array([90.])] +         prompt = ( +             "You are an
expert strategy analyst for a text-based adventure game. Your task is to analyze a game trajectory, identify the critical mistake that led to a low score or game over, and produce actionable advice.\n\n" + f"--- Final Score ---\n{final_score}\n\n" + f"--- Full Gameplay Log of the Final Attempt ---\n{trajectory_str}\n" + "--- Analysis Task ---\n" + "Analyze the final steps of the gameplay log. Pinpoint the single critical mistake that ended the game or caused it to get stuck. Provide your analysis as a single, clean JSON object with two keys: `critical_mistake_analysis` and `alternative_action_suggestion`.\n" + # "- `critical_mistake_analysis`: A concise sentence explaining what the final wrong move was and why it was wrong, based on the final outcome. Example: 'The final action 'go north' led to a dead end, wasting a turn when exploring 'east' towards the Mayor's home was a more promising option based on the description.'\n" + # "- `alternative_action_suggestion`: A concrete, actionable suggestion for what to do instead of the mistaken action, phrased as a general rule for the future. Example: 'When at the 'Outside' location after moving west, the next logical step is to explore 'east' to investigate the Mayor's home mentioned in the description.'\n\n" + "- `critical_mistake_analysis`: A concise sentence explaining what the final wrong move was and why it was wrong, based on the final outcome. Example: 'The final action 'north' led to a dead end, while exploring 'east' towards the Mayor's home was a more promising option.'\n" + "- `alternative_action_suggestion`: A concrete, actionable suggestion for what to do instead of the mistaken action, phrased as a general rule for the future. Example: 'When at the 'Outside' location after moving west, the next promising step is to explore 'east' to investigate the Mayor's home.'\n\n" + "Respond ONLY with the JSON object, without any surrounding text or explanations.\n" + ) + # TODO(pu): 为什么提示词差一点,性能就差很多? + + # prompt = ( + # "You are an expert strategy analyst for a text-based adventure game. Your task is to analyze a game trajectory, identify the critical mistake that led to a low score or game over, and produce actionable advice.\n\n" + # f"--- Final Score ---\n{final_score}\n\n" + # f"--- Full Gameplay Log of the Final Attempt ---\n{trajectory_str}\n" + # "--- Analysis Task ---\n" + # "Analyze the final steps of the gameplay log. Pinpoint the single critical mistake that ended the game or caused it to get stuck. Provide your analysis as a single, clean JSON object with two keys: `critical_mistake_analysis` and `alternative_action_suggestion`.\n" + # "- `critical_mistake_analysis`: A concise sentence explaining what the final wrong move was and why it was wrong, based on the final outcome. Example: 'The ACTION in the Location repeatedly led to the player getting stuck, as there was no useful purpose for the paper in that location, and it prevented further exploration. \n" + # "- `alternative_action_suggestion`: A concrete, actionable suggestion for what to do instead of the mistaken action, phrased as a general rule for the future. Example: 'When at the Location , avoid taking or interacting with items unless there is a clear purpose or hint that it will lead to progress. 
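All of the reflection prompt variants above (the active one and the commented-out alternatives) share the same output contract: a single JSON object with the keys critical_mistake_analysis and alternative_action_suggestion. A minimal parsing sketch for that contract; the helper name parse_reflection is hypothetical and not part of this diff. It tries the raw output first and only then falls back to the first {...} span, with newline flattening as a last resort, which is the same recovery strategy the in-file parser uses further down:

import json
import re
from typing import Dict

def parse_reflection(content: str) -> Dict[str, str]:
    empty = {'critical_mistake_analysis': '', 'alternative_action_suggestion': ''}
    # Candidate strings to parse, from strictest to most permissive.
    candidates = [content.strip()]
    m = re.search(r'\{.*\}', content, re.DOTALL)
    if m:
        span = m.group(0)
        candidates.append(span)
        candidates.append(span.replace('\n', ' ').replace('\r', ' '))
    for cand in candidates:
        try:
            parsed = json.loads(cand)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return {
                'critical_mistake_analysis': str(parsed.get('critical_mistake_analysis', '')).strip(),
                'alternative_action_suggestion': str(parsed.get('alternative_action_suggestion', '')).strip(),
            }
    return empty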
Instead, explore other directions or examine the environment for clues.'\n\n" + # "Respond ONLY with the JSON object, without any surrounding text or explanations.\n" + # ) + + # prompt = f""" + # # Role and Mission + # You are an expert strategy analyst for a text-based adventure game. + # Your mission is to analyze the gameplay log provided below, identify the **single critical mistake** that resulted in a low score or a game-over state, and provide specific, actionable advice for improvement. + # --- + # # Game Data + # ## Final Score + # {final_score} + # ## Full Gameplay Log of the Final Attempt + # {trajectory_str} + # --- + # # Analysis Task + # Analyze the final steps of the gameplay log to pinpoint the **one critical mistake** that halted progress or ended the game. + # Your analysis must be delivered as a **single, clean JSON object**, containing only the following two keys: `critical_mistake_analysis` and `alternative_action_suggestion`. + # ## JSON Structure Definition + # 1. `critical_mistake_analysis`: (Type: String) + # - **Content Requirement**: A concise sentence explaining **what the final wrong move was** and **why it was wrong**, based on the game's outcome. + # - **Good Example**: "After finding the key, the player's final command 'go north' was a redundant and ineffective move because the logical next step was to use the key with the command 'unlock door with key'." + # 2. `alternative_action_suggestion`: (Type: String) + # - **Content Requirement**: A concrete, actionable suggestion for what to do instead. This advice should be framed as a **general rule for future gameplay**. + # - **Good Example**: "When a new item (like a key) is acquired, immediately consider its purpose and try related interaction commands (e.g., 'unlock door') rather than repeating previously failed actions." + # --- + # # Output Format Requirement + # **Strictly adhere to this rule**: Your response must **ONLY** be the JSON object itself. It must **NOT** be enclosed in markdown code blocks (like ```json ... ```) or contain any surrounding text, explanations, or comments. + # """ + + + gen_config = GenerationConfig( + temperature=TEMPERATURE, top_p=0.9, do_sample=True, max_new_tokens=250, + pad_token_id=self.tokenizer.eos_token_id + ) + + content = self._generate_text(prompt, gen_config) + + reflections = {'critical_mistake_analysis': '', 'alternative_action_suggestion': ''} + try: + json_match = re.search(r'\{.*\}', content, re.DOTALL) + if json_match: + json_str = json_match.group(0).replace('\n', ' ').replace('\r', '').strip() + parsed_json = json.loads(json_str) + reflections['critical_mistake_analysis'] = parsed_json.get('critical_mistake_analysis', '').strip() + reflections['alternative_action_suggestion'] = parsed_json.get('alternative_action_suggestion', '').strip() + else: + logging.warning(f"Could not find a JSON object in the LLM's reflection output. Content: {content}") + except json.JSONDecodeError as e: + logging.error(f"Failed to decode JSON from reflection output: {e}. Content: {content}") + return reflections + +# ===================================================================================== +# 5. 
Argument parsing +# ===================================================================================== +def parse_args(): +     """Parse command-line arguments.""" +     parser = argparse.ArgumentParser(description="Run Jericho text-based game with an LLM agent.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +     # --- LLM and model configuration --- +     parser.add_argument('--llm-provider', type=str, default='Qwen', choices=['Qwen', 'PoeAPI'], help='The LLM provider to use.') +     parser.add_argument('--model-path', type=str, default='/fs-computility/niuyazhe/shared/xiongjyu/model/Qwen2.5-7B-Instruct', help='Path to the local LLM model (e.g., for Qwen).') +     parser.add_argument('--model-name', type=str, default='Qwen2.5-7B-Instruct', help='Identifier for the model, used for logging.') + +     # --- Experiment and game configuration --- +     parser.add_argument('--env-type', type=str, default='detective', help='The Jericho game to play (e.g., zork1, detective).') +     parser.add_argument('--num-episodes', type=int, default=20, help='Number of episodes to run.') +     parser.add_argument('--max-steps', type=int, default=100, help='Maximum number of steps per episode.') +     parser.add_argument('--log-dir', type=str, default='./priorzero_log_optimized', help='Base directory for saving logs.') + +     # --- Agent behavior configuration --- +     parser.add_argument('--use-autonomous-reflection', action=argparse.BooleanOptionalAction, default=True, help='Enable or disable autonomous reflection after each episode.') +     parser.add_argument('--use-structured-info', action=argparse.BooleanOptionalAction, default=True, help='Provide structured info (location, inventory) to the LLM.') +     parser.add_argument('--max-history-len', type=int, default=10, help='Maximum number of recent history steps to include in the prompt.') + +     # --- Memory system configuration --- +     parser.add_argument('--good-memory-maxlen', type=int, default=20, help='Maximum capacity for good memories.') +     parser.add_argument('--bad-memory-maxlen', type=int, default=20, help='Maximum capacity for bad memories.') +     parser.add_argument('--similarity-threshold', type=float, default=0.9, help='Cosine-similarity threshold above which a new memory is rejected as redundant.') +     parser.add_argument('--temperature', type=float, default=0.01, help='Sampling temperature for LLM generation.') + + +     return parser.parse_args() + +# ===================================================================================== +# 6. Main execution logic +# ===================================================================================== +def main(args): +     """Main entry point: runs the game loop and the agent-environment interaction.""" +     # --- Initialize the distributed environment and logging --- +     rank, world_size = init_distributed() +     print(f"[RANK {rank}] Initialized. 
World size: {world_size}") + + log_dir = os.path.join(args.log_dir, args.model_name, f"{args.env_type}_structured_{args.use_structured_info}_auto_reflect_{args.use_autonomous_reflection}_v2") + os.makedirs(log_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = os.path.join(log_dir, f"rank_{rank}_{timestamp}.txt") + f = open(log_file, "w", encoding="utf-8") + + global TEMPERATURE;TEMPERATURE=args.temperature + + # --- 实例化策略对象 --- + llm_policy: BaseLLMPolicy + print(f"Using LLM Provider: {args.llm_provider}") + if args.llm_provider == "Qwen": + llm_policy = QwenLocalPolicy( + model_path=args.model_path, + local_rank=rank + ) + # elif args.llm_provider == "PoeAPI": + # # PoeAPIPolicy 的实现可以放在这里 + # raise NotImplementedError("PoeAPIPolicy is not fully implemented in this example.") + else: + raise ValueError(f"Unknown LLM_PROVIDER: {args.llm_provider}") + + # --- 配置并初始化游戏环境 --- + env_cfg = EasyDict( + dict( + max_steps=args.max_steps, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{args.env_type}.z5", + add_location_and_inventory=args.use_structured_info, + env_type=args.env_type, + # 以下为保持原有配置的参数 + max_action_num=55, + tokenizer_path="google-bert/bert-base-uncased", + max_seq_len=512, + remove_stuck_actions=False, + for_unizero=False, + collector_env_num=1, + evaluator_env_num=1, + save_replay=False, + save_replay_path=None, + collect_policy_mode='expert' + ) + ) + env = JerichoEnv(env_cfg) + + # --- 初始化记忆管理器 --- + good_trial_memory = MemoryManager(maxlen=args.good_memory_maxlen, similarity_threshold=args.similarity_threshold) + bad_trial_memory = MemoryManager(maxlen=args.bad_memory_maxlen, similarity_threshold=args.similarity_threshold) + + total_scores = [] + + # --- 游戏主循环 --- + for episode_id in range(args.num_episodes): + f.write(f"{'='*80}\n") + f.write(f"STARTING EPISODE: {episode_id} | Good Memories: {len(good_trial_memory)}, Bad Memories: {len(bad_trial_memory)}\n") + f.write(f"Current Good Memories: {good_trial_memory.get_memories()}\n") + f.write(f"Current Bad Memories: {bad_trial_memory.get_memories()}\n") + f.write(f"CONFIG: USE_STRUCTURED_INFO = {args.use_structured_info}, USE_AUTONOMOUS_REFLECTION = {args.use_autonomous_reflection}\n") + f.write(f"{'='*80}\n") + f.flush() + + obs_or_dict = env.reset(return_str=True) + done = False + step_count = 0 + episode_reward = 0 + episode_history = [] + last_actions = deque(maxlen=4) + + while not done: + # --- 解析观察值 --- + if isinstance(obs_or_dict, dict): + full_scene_description = obs_or_dict.get('observation', '') + inventory_str = obs_or_dict.get('inventory', 'not tracked') + location_str = obs_or_dict.get('location', 'not tracked') + else: + full_scene_description = obs_or_dict + inventory_str, location_str = "not tracked", "not tracked" + try: + inventory_str = env._env.get_inventory() + location_str = env._env.get_player_location() + except Exception: + pass + + valid_actions = env._env.get_valid_actions() + + # --- 动作选择:处理循环并调用 LLM --- + # TODO(pu): + # if len(last_actions) == 4 and last_actions[0] == last_actions[2] and last_actions[1] == last_actions[3]: + # # 20eps return_mean=102.5 + # f.write("[SYSTEM] Stuck in a 2-step loop. Forcing 'look' to re-evaluate.\n") + # action, prompt = ('look', "[SYSTEM] Stuck in a loop. Forcing 'look'.") if 'look' in valid_actions else (np.random.choice([a for a in valid_actions if a not in last_actions] or valid_actions), "[SYSTEM] Stuck in a loop. 
Forcing random action.") + # else: # 20eps return_mean=154.0 + action, prompt = llm_policy.sample_action_v2( + history=list(episode_history)[-args.max_history_len:], + full_scene_description=full_scene_description, + inventory=inventory_str, + location=location_str, + valid_actions=valid_actions, + good_memory=good_trial_memory.get_memories(), + bad_memory=bad_trial_memory.get_memories() + ) + + last_actions.append(action) + + # --- 与环境交互 --- + next_obs_or_dict, reward, done, info = env.step(action, return_str=True) + episode_reward += reward + + next_obs_str = next_obs_or_dict.get('observation', '') if isinstance(next_obs_or_dict, dict) else next_obs_or_dict + episode_history.append({'action': action, 'obs': next_obs_str}) + + # --- 更新“好记忆” --- + if reward > 0: + good_memory_suggestion = f"At the location '{location_str}', performing the action '{action}' was successful and yielded a positive reward." + log_msg = good_trial_memory.add_memory(good_memory_suggestion) + f.write(f"[GOOD MEMORY UPDATE]: {log_msg}\n") + print(f"[GOOD MEMORY UPDATE]: {log_msg}") + + # --- 记录日志 --- + f.write(f"--- Step {step_count} ---\n") + f.write(f"[Prompt Sent to LLM]:\n==============================================\n{prompt}\n") + f.write(f"==============================================\n") + f.write(f"[LLM Chose Action]: {action}\n") + if done: + f.write(f"[Final Observation]:\n{next_obs_str}\n") + else: + f.write(f"[Next Observation]:\n{next_obs_str}\n") + f.write(f"---------------------------------\n") + f.write(f"[Reward]: {reward}, [Done]: {done}, [Total Score]: {episode_reward}\n") + f.write(f"---------------------------------\n\n") + f.flush() + + obs_or_dict = next_obs_or_dict + step_count += 1 + if step_count >= env_cfg.max_steps: + done = True + + final_score = info.get('eval_episode_return', episode_reward) + total_scores.append(final_score) + + # --- 回合结束后的自主反思 --- + if args.use_autonomous_reflection: + f.write("[Reflection Mode]: Autonomous\n") + print("[Reflection Mode]: Autonomous") + + reflections = llm_policy.generate_autonomous_reflection(episode_history, final_score=final_score) + f.write(f"Reflections JSON from LLM: {reflections}\n") + print(f"Reflections JSON from LLM: {reflections}") + + bad_reflection_suggestion = reflections.get('alternative_action_suggestion') + if bad_reflection_suggestion: + log_msg = bad_trial_memory.add_memory(bad_reflection_suggestion) + f.write(f"[BAD MEMORY UPDATE]: {log_msg}\n") + print(f"[BAD MEMORY UPDATE]: {log_msg}") + else: + f.write(f"[BAD MEMORY UPDATE]: LLM did not provide a valid suggestion.\n") + print(f"[BAD MEMORY UPDATE]: LLM did not provide a valid suggestion.") + + f.write(f"Episode {episode_id} finished. Final Score: {final_score}\n") + print(f"Episode {episode_id} finished. Final Score: {final_score}") + + # --- 实验结束,清理和总结 --- + f.write(f"\n\n{'='*80}\n") + f.write(f"All episodes finished. Total Scores: {total_scores}\n") + f.write(f"Average score: {np.mean(total_scores)}\n") + f.close() + print(f"[RANK {rank}] Finished. Log written to {log_file}") + del env + + if dist.is_initialized(): + dist.destroy_process_group() + print(f"[RANK {rank}] Process group destroyed.") + +if __name__ == '__main__': + args = parse_args() + main(args) \ No newline at end of file