Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions backend/app/api/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,17 @@ def progress_callback(stage, progress, message, **kwargs):
progress_callback=progress_callback,
parallel_profile_count=parallel_profile_count
)

# 任务完成
task_manager.complete_task(
task_id,
result=result_state.to_simple_dict()
)

if result_state.status == SimulationStatus.FAILED:
task_manager.fail_task(
task_id,
result_state.error or "模拟准备失败"
)
else:
task_manager.complete_task(
task_id,
result=result_state.to_simple_dict()
)

except Exception as e:
logger.error(f"准备模拟失败: {str(e)}")
Expand Down Expand Up @@ -1095,6 +1100,8 @@ def get_simulation_profiles_realtime(simulation_id: str):
# 检查是否正在生成(通过 state.json 判断)
is_generating = False
total_expected = None
status = None
error = None

state_file = os.path.join(sim_dir, "state.json")
if os.path.exists(state_file):
Expand All @@ -1104,6 +1111,7 @@ def get_simulation_profiles_realtime(simulation_id: str):
status = state_data.get("status", "")
is_generating = status == "preparing"
total_expected = state_data.get("entities_count")
error = state_data.get("error")
except Exception:
pass

Expand All @@ -1115,6 +1123,8 @@ def get_simulation_profiles_realtime(simulation_id: str):
"count": len(profiles),
"total_expected": total_expected,
"is_generating": is_generating,
"status": status,
"error": error,
"file_exists": file_exists,
"file_modified_at": file_modified_at,
"profiles": profiles
Expand Down Expand Up @@ -1190,6 +1200,9 @@ def get_simulation_config_realtime(simulation_id: str):
# 检查是否正在生成(通过 state.json 判断)
is_generating = False
generation_stage = None
status = None
error = None
profiles_generated = False
config_generated = False

state_file = os.path.join(sim_dir, "state.json")
Expand All @@ -1198,17 +1211,21 @@ def get_simulation_config_realtime(simulation_id: str):
with open(state_file, 'r', encoding='utf-8') as f:
state_data = json.load(f)
status = state_data.get("status", "")
error = state_data.get("error")
is_generating = status == "preparing"
profiles_generated = state_data.get("profiles_generated", False)
config_generated = state_data.get("config_generated", False)

# 判断当前阶段
if is_generating:
if state_data.get("profiles_generated", False):
if profiles_generated:
generation_stage = "generating_config"
else:
generation_stage = "generating_profiles"
elif status == "ready":
generation_stage = "completed"
elif status == "failed":
generation_stage = "failed"
except Exception:
pass

Expand All @@ -1218,7 +1235,10 @@ def get_simulation_config_realtime(simulation_id: str):
"file_exists": file_exists,
"file_modified_at": file_modified_at,
"is_generating": is_generating,
"status": status,
"error": error,
"generation_stage": generation_stage,
"profiles_generated": profiles_generated,
"config_generated": config_generated,
"config": config
}
Expand Down
12 changes: 11 additions & 1 deletion backend/app/services/simulation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class SimulationState:
entity_types: List[str] = field(default_factory=list)

# 配置生成信息
profiles_generated: bool = False
config_generated: bool = False
config_reasoning: str = ""

Expand Down Expand Up @@ -86,6 +87,7 @@ def to_dict(self) -> Dict[str, Any]:
"entities_count": self.entities_count,
"profiles_count": self.profiles_count,
"entity_types": self.entity_types,
"profiles_generated": self.profiles_generated,
"config_generated": self.config_generated,
"config_reasoning": self.config_reasoning,
"current_round": self.current_round,
Expand All @@ -106,6 +108,7 @@ def to_simple_dict(self) -> Dict[str, Any]:
"entities_count": self.entities_count,
"profiles_count": self.profiles_count,
"entity_types": self.entity_types,
"profiles_generated": self.profiles_generated,
"config_generated": self.config_generated,
"error": self.error,
}
Expand Down Expand Up @@ -177,6 +180,7 @@ def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState
entities_count=data.get("entities_count", 0),
profiles_count=data.get("profiles_count", 0),
entity_types=data.get("entity_types", []),
profiles_generated=data.get("profiles_generated", False),
config_generated=data.get("config_generated", False),
config_reasoning=data.get("config_reasoning", ""),
current_round=data.get("current_round", 0),
Expand Down Expand Up @@ -264,6 +268,10 @@ def prepare_simulation(

try:
state.status = SimulationStatus.PREPARING
state.error = None
state.profiles_generated = False
state.config_generated = False
state.config_reasoning = ""
self._save_simulation_state(state)

sim_dir = self._get_simulation_dir(simulation_id)
Expand Down Expand Up @@ -298,7 +306,7 @@ def prepare_simulation(
state.status = SimulationStatus.FAILED
state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
self._save_simulation_state(state)
return state
raise ValueError(state.error)

# ========== 阶段2: 生成Agent Profile ==========
total_entities = len(filtered.entities)
Expand Down Expand Up @@ -346,6 +354,8 @@ def profile_progress(current, total, msg):
)

state.profiles_count = len(profiles)
state.profiles_generated = len(profiles) > 0
self._save_simulation_state(state)

# 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式)
# Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性
Expand Down
62 changes: 50 additions & 12 deletions backend/app/services/zep_entity_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,35 @@ def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
except Exception as e:
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
return []

def _infer_entity_type(self, node: Dict[str, Any]) -> Optional[str]:
"""
为缺少 labels 的节点推断一个可用于模拟的粗粒度实体类型。

仅在图谱完全没有业务标签时兜底使用,避免把纯概念节点全部纳入模拟。
"""
name_text = (node.get("name", "") or "").lower()
summary_text = (node.get("summary", "") or "").lower()

keyword_groups = [
("GovernmentAgency", ["监管", "政府", "政务", "公共服务", "部门", "官方", "监管机构"]),
("Consumer", ["公众", "消费者", "用户", "居民", "市民", "网民", "家长", "患者"]),
("MedicalInstitution", ["医疗", "医院", "卫健", "诊所", "医药"]),
("EducationalInstitution", ["教育", "学校", "大学", "高校", "学院", "科研"]),
("MediaOutlet", ["媒体", "新闻", "内容产业", "记者", "传媒"]),
("LegalInstitution", ["法律", "法院", "律所", "司法"]),
("Organization", ["企业", "公司", "制造业", "金融", "零售", "电商", "物流", "交通", "平台", "行业"]),
]

for entity_type, keywords in keyword_groups:
if any(keyword in name_text for keyword in keywords):
return entity_type

for entity_type, keywords in keyword_groups:
if any(keyword in summary_text for keyword in keywords):
return entity_type

return None

def filter_defined_entities(
self,
Expand Down Expand Up @@ -251,30 +280,39 @@ def filter_defined_entities(

for node in all_nodes:
labels = node.get("labels", [])

# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]

if not custom_labels:
# 只有默认标签,跳过
continue

# 如果指定了预定义类型,检查是否匹配
inferred_entity_type = None

if custom_labels:
candidate_labels = custom_labels
else:
inferred_entity_type = self._infer_entity_type(node)
if not inferred_entity_type:
continue
candidate_labels = [inferred_entity_type]

if defined_entity_types:
matching_labels = [l for l in custom_labels if l in defined_entity_types]
matching_labels = [l for l in candidate_labels if l in defined_entity_types]
if not matching_labels:
continue
entity_type = matching_labels[0]
else:
entity_type = custom_labels[0]
entity_type = candidate_labels[0]

entity_types_found.add(entity_type)

effective_labels = list(labels)
if inferred_entity_type and inferred_entity_type not in effective_labels:
effective_labels.append(inferred_entity_type)
logger.info(
f"节点 {node['name']} 缺少业务标签,推断实体类型为 {inferred_entity_type}"
)

# 创建实体节点对象
entity = EntityNode(
uuid=node["uuid"],
name=node["name"],
labels=labels,
labels=effective_labels,
summary=node["summary"],
attributes=node["attributes"],
)
Expand Down
45 changes: 32 additions & 13 deletions frontend/src/components/Step2EnvSetup.vue
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,15 @@ const addLog = (msg) => {
emit('add-log', msg)
}

const handlePrepareFailure = (message) => {
const errorMessage = message || '模拟环境准备失败'
stopPolling()
stopProfilesPolling()
stopConfigPolling()
addLog(`✗ ${errorMessage}`)
emit('update-status', 'error')
}

// 处理开始模拟按钮点击
const handleStartSimulation = () => {
// 构建传递给父组件的参数
Expand Down Expand Up @@ -895,9 +904,7 @@ const pollPrepareStatus = async () => {
stopProfilesPolling()
await loadPreparedData()
} else if (data.status === 'failed') {
addLog(`✗ 准备失败: ${data.error || '未知错误'}`)
stopPolling()
stopProfilesPolling()
handlePrepareFailure(`准备失败: ${data.error || '未知错误'}`)
}
}
} catch (err) {
Expand Down Expand Up @@ -969,6 +976,11 @@ const fetchConfigRealtime = async () => {

if (res.success && res.data) {
const data = res.data

if (data.status === 'failed' || data.error) {
handlePrepareFailure(data.error || '配置生成失败')
return
}

// 输出配置生成阶段日志(避免重复)
if (data.generation_stage && data.generation_stage !== lastLoggedConfigStage) {
Expand Down Expand Up @@ -1029,29 +1041,36 @@ const loadPreparedData = async () => {
try {
const res = await getSimulationConfigRealtime(props.simulationId)
if (res.success && res.data) {
if (res.data.config_generated && res.data.config) {
simulationConfig.value = res.data.config
const configState = res.data

if (configState.status === 'failed' || configState.error) {
handlePrepareFailure(configState.error || '配置生成失败')
return
}

if (configState.config_generated && configState.config) {
simulationConfig.value = configState.config
addLog('✓ 模拟配置加载成功')

// 显示详细配置摘要
if (res.data.summary) {
addLog(` ├─ Agent数量: ${res.data.summary.total_agents}个`)
addLog(` ├─ 模拟时长: ${res.data.summary.simulation_hours}小时`)
addLog(` └─ 初始帖子: ${res.data.summary.initial_posts_count}条`)
if (configState.summary) {
addLog(` ├─ Agent数量: ${configState.summary.total_agents}个`)
addLog(` ├─ 模拟时长: ${configState.summary.simulation_hours}小时`)
addLog(` └─ 初始帖子: ${configState.summary.initial_posts_count}条`)
}

addLog('✓ 环境搭建完成,可以开始模拟')
phase.value = 4
emit('update-status', 'completed')
} else {
// 配置尚未生成,开始轮询
} else if (configState.is_generating) {
addLog('配置生成中,开始轮询等待...')
startConfigPolling()
} else {
handlePrepareFailure('配置尚未生成,且后端未处于生成中状态')
}
}
} catch (err) {
addLog(`加载配置失败: ${err.message}`)
emit('update-status', 'error')
handlePrepareFailure(`加载配置失败: ${err.message}`)
}
}

Expand Down