Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.dev
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ MINIO_ENDPOINT=127.0.0.1:9000
SQLALCHEMY_DATABASE_URI=postgresql+psycopg2://aix_db:1@127.0.0.1:15432/aix_db

# LangFuse 配置 默认关闭 (可选)
LANGFUSE_TRACING_ENABLED="true"
LANGFUSE_TRACING_ENABLED="false"
LANGFUSE_SECRET_KEY = "sk-lf-4bf2a844-4a9c-4626-af69-0cae99bf2bfb"
LANGFUSE_PUBLIC_KEY = "pk-lf-8aff3c29-3239-4a52-8028-bacc185f6c22"
LANGFUSE_BASE_URL = "http://localhost:3000"
13 changes: 10 additions & 3 deletions agent/text2sql/analysis/data_render_antv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import json
import logging
import os
import traceback
from decimal import Decimal
from datetime import datetime, date
Expand Down Expand Up @@ -37,6 +38,9 @@
"excel": "postgres", # Excel 使用 PostgreSQL 规则
}

# Maximum number of rows included in the front-end table preview. Returning the
# whole result set at once makes the page sluggish, so only a prefix is sent.
# Override with the TABLE_PREVIEW_MAX_ROWS environment variable. A value that is
# not a positive integer falls back to the default: a non-numeric value would
# otherwise raise at import time, and zero / negative values would turn the
# preview slice `data[:TABLE_PREVIEW_MAX_ROWS]` into an empty or tail-dropping
# slice.
_DEFAULT_TABLE_PREVIEW_MAX_ROWS = 100
try:
    TABLE_PREVIEW_MAX_ROWS = int(
        os.getenv("TABLE_PREVIEW_MAX_ROWS", str(_DEFAULT_TABLE_PREVIEW_MAX_ROWS))
    )
except ValueError:
    # Sentinel only; the range check below replaces it with the default.
    TABLE_PREVIEW_MAX_ROWS = 0
if TABLE_PREVIEW_MAX_ROWS <= 0:
    logging.getLogger(__name__).warning(
        "Invalid TABLE_PREVIEW_MAX_ROWS=%r; falling back to %d",
        os.getenv("TABLE_PREVIEW_MAX_ROWS"),
        _DEFAULT_TABLE_PREVIEW_MAX_ROWS,
    )
    TABLE_PREVIEW_MAX_ROWS = _DEFAULT_TABLE_PREVIEW_MAX_ROWS


def convert_value(v):
"""转换数据类型"""
Expand Down Expand Up @@ -513,8 +517,11 @@ async def data_render_ant(state: AgentState):
except Exception as e:
logger.warning(f"获取数据源类型失败: {e},使用默认值 mysql")

# 视图侧仅预览前 TABLE_PREVIEW_MAX_ROWS 行,避免一次性渲染全部数据导致前端卡顿
preview_data = data[:TABLE_PREVIEW_MAX_ROWS]

# 获取实际的列名(从第一条数据中提取)
actual_columns = list(data[0].keys()) if data else []
actual_columns = list(preview_data[0].keys()) if preview_data else []

if not actual_columns:
logger.warning("无法从数据中提取列名,跳过数据渲染")
Expand All @@ -536,9 +543,9 @@ async def data_render_ant(state: AgentState):
else:
logger.warning(f"列名映射失败或返回空,使用原始列名。actual_columns={actual_columns[:3]}")

# 转换数据格式: 将英文列名映射为中文列名
# 转换数据格式: 将英文列名映射为中文列名(仅对预览数据执行)
formatted_data = []
for row in data:
for row in preview_data:
formatted_row = {}
for col_name, value in row.items():
chinese_col_name = column_mapping.get(col_name, col_name)
Expand Down
14 changes: 11 additions & 3 deletions agent/text2sql/analysis/llm_summarizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from datetime import datetime, date
import json
import re
Expand All @@ -11,6 +12,11 @@
from agent.text2sql.template.prompt_builder import PromptBuilder

logger = logging.getLogger(__name__)

def _positive_int_from_env(name: str, default: int) -> int:
    """Return environment variable *name* as a positive integer.

    Falls back to *default* (with a warning) when the variable is set but is
    not a positive integer, so a bad deployment setting cannot crash the
    module at import time or produce an empty / negative slice bound.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or invalid.

    Returns:
        The parsed positive integer, or *default*.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        value = 0  # treated as invalid by the range check below
    if value <= 0:
        logging.getLogger(__name__).warning(
            "Invalid %s=%r; falling back to %d", name, raw, default
        )
        return default
    return value


# Maximum number of data rows passed into the summary prompt, so an oversized
# prompt does not slow the LLM down (configurable via SUMMARIZE_MAX_ROWS).
SUMMARIZE_MAX_ROWS = _positive_int_from_env("SUMMARIZE_MAX_ROWS", 25)
# Maximum character length of the serialized data in the summary prompt; longer
# text is truncated and annotated (configurable via SUMMARIZE_MAX_CHARS).
SUMMARIZE_MAX_CHARS = _positive_int_from_env("SUMMARIZE_MAX_CHARS", 12000)
"""
大模型数据总结节点
"""
Expand Down Expand Up @@ -65,14 +71,16 @@ def summarize(state: AgentState):
prompt_builder = PromptBuilder()

try:
# 获取数据结果
# 获取数据结果,限制行数与长度以加快 LLM 总结
data_result = state["execution_result"].data

# 如果数据是字典或列表,转换为JSON字符串
if isinstance(data_result, list) and len(data_result) > SUMMARIZE_MAX_ROWS:
data_result = data_result[:SUMMARIZE_MAX_ROWS]
if isinstance(data_result, (dict, list)):
data_result_str = json.dumps(data_result, ensure_ascii=False, indent=2, cls=DecimalEncoder)
else:
data_result_str = str(data_result)
if len(data_result_str) > SUMMARIZE_MAX_CHARS:
data_result_str = data_result_str[:SUMMARIZE_MAX_CHARS] + "\n\n...(数据已截断,仅展示前一部分供总结)"

# 获取当前时间
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
Expand Down
2 changes: 1 addition & 1 deletion agent/text2sql/analysis/parallel_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def parallel_collect(state: AgentState, tasks: list[str] = None) -> AgentState:

for task, future in futures.items():
try:
result_state = future.result(timeout=180) # 最多等待60秒
result_state = future.result(timeout=60) # 单任务最多等待 60 秒,避免整体过慢
results[task] = result_state
logger.info(f"✅ 任务完成: {task}")
except Exception as e:
Expand Down
34 changes: 30 additions & 4 deletions agent/text2sql/analysis/unified_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,37 @@ async def unified_collect(state: AgentState) -> AgentState:
# chart_config 为空,检查是否是因为查询结果为空
execution_result = state.get("execution_result")
if execution_result and execution_result.success and not execution_result.data:
# SQL执行成功但无数据,生成空结果卡片,让前端显示空结果提示并可查看SQL
logger.info("📊 SQL执行成功但无数据,生成空结果卡片")
# SQL执行成功但无数据,仍然返回表格模板(temp01),这样前端可以显示表格结构和分页控件
logger.info("📊 SQL执行成功但无数据,生成空表格")
# 尝试从 SQL 中提取列名
columns = []
generated_sql = state.get("generated_sql", "") or state.get("filtered_sql", "")
if generated_sql:
try:
# 简单的列名提取:从 SELECT 和 FROM 之间提取
import re
select_match = re.search(r'SELECT\s+(.*?)\s+FROM', generated_sql, re.IGNORECASE | re.DOTALL)
if select_match:
select_clause = select_match.group(1)
# 分割列名(简单处理,不考虑复杂的子查询)
col_parts = [c.strip() for c in select_clause.split(',')]
for col in col_parts:
# 提取别名或列名
if ' AS ' in col.upper():
alias = col.split(' AS ')[-1].strip().strip('`').strip('"').strip("'")
columns.append(alias)
else:
# 提取最后一个点号后面的部分(表名.列名 -> 列名)
col_name = col.strip().strip('`').strip('"').strip("'")
if '.' in col_name:
col_name = col_name.split('.')[-1]
columns.append(col_name)
except Exception as e:
logger.warning(f"从 SQL 提取列名失败: {e}")

state["render_data"] = {
"template_code": "temp05",
"columns": [],
"template_code": "temp01", # 改为 temp01,显示表格
"columns": columns if columns else [],
"data": [],
}
else:
Expand Down
Loading