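Automated Reporting Agent: Data Extraction, Analysis, and Visualization
An automated reporting Agent interprets data sources, generates query logic, analyzes trends in the data, and recommends suitable visualizations, shortening the path from raw data to insight.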
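In enterprise data applications, getting from raw data to business insight usually takes several steps: connecting to data sources, writing queries, cleaning data, computing metrics, choosing chart types, and producing a visual report. Traditionally this process depends heavily on the skills of data analysts and engineers, which makes it slow and error-prone.
The core value of an automated reporting Agent is that it automates this whole workflow: it interprets the business requirement, generates SQL queries, checks data quality, identifies patterns in the data, and recommends a suitable visualization. This improves efficiency and lowers the barrier to data analysis, so that more business users can get insight directly from data.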
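Core Concepts and Architecture
Multi-Agent Collaboration
The automated reporting Agent uses a Supervisor/Worker pattern: one coordinating Agent assigns tasks and aggregates results, and several specialized Agents handle the individual subtasks. This keeps task handling flexible while giving each stage its own quality check.
The coordinating Agent is responsible for:
- Parsing the user requirement and splitting it into executable subtasks
- Ordering the execution of the specialized Agents
- Collecting intermediate results and validating their quality
- Handling failures by rolling back or correcting
- Aggregating the final results into a deliverable report
The specialized Agents are:
- Data source connection Agent: manages database connections, table structure discovery, and permission checks
- SQL generation Agent: generates optimized SQL queries from the requirement
- Data quality Agent: checks completeness, consistency, and outliers
- Analysis logic Agent: generates aggregation, grouping, and metric calculation logic
- Visualization Agent: recommends chart types based on the characteristics of the data
- Report generation Agent: assembles the final report and adds explanatory notes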
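The division of labor above can be sketched in a few lines. This is a minimal illustration, not the framework implemented later in this article: the `Supervisor` class, the `register`/`run` methods, and the two toy workers are hypothetical names introduced here, and each worker is assumed to be a plain callable that reads a shared context dictionary.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple

# Assumption: a worker is any callable that reads the shared context
# and returns its own result.
Worker = Callable[[Dict[str, Any]], Any]


@dataclass
class Supervisor:
    """Hypothetical coordinator: runs workers in order and collects results."""
    workers: List[Tuple[str, Worker]] = field(default_factory=list)

    def register(self, name: str, worker: Worker) -> None:
        self.workers.append((name, worker))

    def run(self, requirement: str) -> Dict[str, Any]:
        context: Dict[str, Any] = {"requirement": requirement, "steps": []}
        for name, worker in self.workers:
            try:
                # Each worker's result is stored under its own name so that
                # later workers can read it.
                context[name] = worker(context)
                context["steps"].append({"agent": name, "status": "ok"})
            except Exception as exc:
                # Stop at the first failure and record which agent failed.
                context["steps"].append(
                    {"agent": name, "status": "error", "detail": str(exc)}
                )
                context["status"] = "error"
                return context
        context["status"] = "success"
        return context


def sql_worker(ctx: Dict[str, Any]) -> str:
    # Toy stand-in for the SQL generation Agent.
    return "SELECT category, SUM(amount) FROM orders GROUP BY category"


def quality_worker(ctx: Dict[str, Any]) -> Dict[str, bool]:
    # Toy stand-in for a validation step that reads the previous result.
    if not ctx["sql"].upper().startswith("SELECT"):
        raise ValueError("only SELECT statements are allowed")
    return {"checked": True}


supervisor = Supervisor()
supervisor.register("sql", sql_worker)
supervisor.register("quality", quality_worker)
result = supervisor.run("total sales by category")
print(result["status"], [step["agent"] for step in result["steps"]])
```

Stopping at the first failed worker keeps the sketch short; a real coordinator would more likely retry or route the failure to a correction step, as the responsibility list describes.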
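Data Understanding and Metadata Management
The Agent first needs to understand the structure and semantics of the data source: table structures, column types, relationships, and business meaning. Metadata management is the foundation of the whole system and directly affects the accuracy of the SQL generation and analysis logic that follow.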
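As a small illustration of collecting structural metadata, the sketch below reads table and column information from a SQLite database using only the standard library and renders it as text that could be placed in a prompt. The helper names `describe_schema` and `schema_to_text` are hypothetical, and the sketch covers structure only; business semantics would have to come from a separate catalog.

```python
import sqlite3
from typing import Any, Dict, List


def describe_schema(conn: sqlite3.Connection) -> Dict[str, List[Dict[str, Any]]]:
    """Collect column name, type, and nullability for every user table."""
    tables = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%'"
        )
    ]
    schema: Dict[str, List[Dict[str, Any]]] = {}
    for table in tables:
        # PRAGMA table_info returns rows of
        # (cid, name, type, notnull, dflt_value, pk).
        columns = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
        schema[table] = [
            {
                "name": col[1],
                "type": col[2],
                "nullable": not col[3],
                "primary_key": bool(col[5]),
            }
            for col in columns
        ]
    return schema


def schema_to_text(schema: Dict[str, List[Dict[str, Any]]]) -> str:
    """Render the schema as one line per table."""
    lines = []
    for table, columns in schema.items():
        cols = ", ".join(f"{c['name']} ({c['type']})" for c in columns)
        lines.append(f"{table}: {cols}")
    return "\n".join(lines)


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE orders (id INTEGER PRIMARY KEY, amount REAL NOT NULL, category TEXT)"
)
schema = describe_schema(conn)
schema_text = schema_to_text(schema)
print(schema_text)  # orders: id (INTEGER), amount (REAL), category (TEXT)
```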
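SQL Generation and Optimization
SQL generation is the core capability of automated reporting. The Agent has to turn a natural-language requirement into a SQL query that is both correct and efficient, which involves intent recognition, entity mapping, query construction, and performance optimization.
The SQL generation Agent works in stages:
- Requirement parsing: identify the query target, filter conditions, aggregation dimensions, and sort order
- Entity mapping: map business terms to actual table and column names
- Query construction: build the SELECT, FROM, WHERE, GROUP BY, HAVING, and ORDER BY clauses
- Optimization: apply index hints, avoid full table scans, and optimize JOIN order
- Validation: check SQL syntax, verify that the referenced columns exist, and estimate execution cost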
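The validation stage can be illustrated with a short, hedged sketch. It assumes a SQLite connection and uses `EXPLAIN QUERY PLAN`, which compiles a statement without running it, to catch syntax errors and unknown columns. The function name `validate_select` is hypothetical, and the semicolon check is deliberately crude (it would also reject a semicolon inside a string literal).

```python
import sqlite3
from typing import List, Tuple


def validate_select(conn: sqlite3.Connection, sql: str) -> Tuple[bool, List[str]]:
    """Return (is_valid, messages) without running the query itself."""
    statement = sql.strip().rstrip(";")
    if not statement.upper().startswith("SELECT"):
        return False, ["only SELECT statements are accepted"]
    if ";" in statement:
        return False, ["multiple statements are not accepted"]
    try:
        # EXPLAIN QUERY PLAN compiles the statement, so syntax errors and
        # unknown tables or columns surface here as sqlite3.Error.
        plan = conn.execute(f"EXPLAIN QUERY PLAN {statement}").fetchall()
    except sqlite3.Error as exc:
        return False, [str(exc)]
    # The last field of each plan row is a human-readable description.
    return True, [row[-1] for row in plan]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, amount REAL, category TEXT)")

ok, plan_lines = validate_select(
    conn, "SELECT category, SUM(amount) FROM orders GROUP BY category"
)
bad, errors = validate_select(conn, "SELECT missing_col FROM orders")
print(ok, plan_lines)
print(bad, errors)
```

Other databases expose the same idea through their own `EXPLAIN` variants, but the plan format differs per engine, so any plan parsing has to be written per database.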
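Key Implementation Techniques
Core Report Generation Framework
The following code implements a simplified end-to-end reporting framework that covers data source connection, SQL generation, data analysis, and visualization recommendation: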
import os
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.engine import Engine
from datetime import datetime, timedelta
import json
import re
class DataSourceType(Enum):
MYSQL = "mysql"
POSTGRESQL = "postgresql"
CLICKHOUSE = "clickhouse"
SQLITE = "sqlite"
@dataclass
class DataSource:
name: str
type: DataSourceType
host: str
port: int
database: str
username: str
password: str
schema: Optional[str] = None
@dataclass
class TableMetadata:
name: str
columns: List[Dict[str, Any]]
indexes: List[Dict[str, Any]]
row_count: int
sample_data: pd.DataFrame
@dataclass
class SQLQuery:
query: str
explanation: str
estimated_rows: int
execution_plan: str
@dataclass
class DataQualityReport:
total_rows: int
null_counts: Dict[str, int]
duplicate_count: int
outlier_count: int
completeness_score: float
consistency_score: float
@dataclass
class AnalysisInsight:
metric: str
value: Any
trend: str # "increasing", "decreasing", "stable"
change_percent: float
description: str
@dataclass
class VisualizationRecommendation:
chart_type: str # "line", "bar", "pie", "scatter", "heatmap"
x_axis: str
y_axis: str
color_by: Optional[str]
title: str
rationale: str
class DataSourceManager:
"""管理数据源连接和元数据"""
def __init__(self):
self.connections: Dict[str, Engine] = {}
self.metadata_cache: Dict[str, Dict[str, TableMetadata]] = {}
def create_connection(self, source: DataSource) -> Engine:
"""创建数据库连接"""
if source.name in self.connections:
return self.connections[source.name]
connection_string = self._build_connection_string(source)
engine = create_engine(connection_string, pool_pre_ping=True)
self.connections[source.name] = engine
return engine
def _build_connection_string(self, source: DataSource) -> str:
"""构建数据库连接字符串"""
if source.type == DataSourceType.MYSQL:
return f"mysql+pymysql://{source.username}:{source.password}@{source.host}:{source.port}/{source.database}"
elif source.type == DataSourceType.POSTGRESQL:
return f"postgresql://{source.username}:{source.password}@{source.host}:{source.port}/{source.database}"
elif source.type == DataSourceType.CLICKHOUSE:
return f"clickhouse://{source.username}:{source.password}@{source.host}:{source.port}/{source.database}"
elif source.type == DataSourceType.SQLITE:
return f"sqlite:///{source.host}"
else:
raise ValueError(f"Unsupported data source type: {source.type}")
def get_table_metadata(self, source: DataSource, table_name: str) -> TableMetadata:
"""获取表的元数据"""
cache_key = f"{source.name}:{table_name}"
if cache_key in self.metadata_cache:
return self.metadata_cache[cache_key]
engine = self.create_connection(source)
inspector = inspect(engine)
# 获取列信息
columns = []
for column in inspector.get_columns(table_name):
columns.append({
"name": column["name"],
"type": str(column["type"]),
"nullable": column.get("nullable", True),
"default": column.get("default"),
"autoincrement": column.get("autoincrement", False)
})
# 获取索引信息
indexes = []
for index in inspector.get_indexes(table_name):
indexes.append({
"name": index["name"],
"columns": index["column_names"],
"unique": index.get("unique", False)
})
# 获取行数和样本数据
with engine.connect() as conn:
result = conn.execute(text(f"SELECT COUNT(*) FROM {table_name}"))
row_count = result.scalar()
sample_query = text(f"SELECT * FROM {table_name} LIMIT 100")
sample_data = pd.read_sql(sample_query, conn)
metadata = TableMetadata(
name=table_name,
columns=columns,
indexes=indexes,
row_count=row_count,
sample_data=sample_data
)
self.metadata_cache[cache_key] = metadata
return metadata
def execute_query(self, source: DataSource, query: str) -> pd.DataFrame:
"""执行 SQL 查询"""
engine = self.create_connection(source)
with engine.connect() as conn:
return pd.read_sql(text(query), conn)
class SQLGeneratorAgent:
"""SQL 生成 Agent"""
def __init__(self, source_manager: DataSourceManager):
self.source_manager = source_manager
self.entity_mappings: Dict[str, Dict[str, str]] = {}
def understand_requirement(self, requirement: str, source: DataSource) -> Dict[str, Any]:
"""理解用户需求并提取关键信息"""
# 简化的需求理解,实际应用中应该使用 LLM
entities = {
"tables": self._extract_tables(requirement, source),
"metrics": self._extract_metrics(requirement),
"dimensions": self._extract_dimensions(requirement),
"filters": self._extract_filters(requirement),
"time_range": self._extract_time_range(requirement),
"aggregation": self._extract_aggregation(requirement)
}
return entities
def _extract_tables(self, requirement: str, source: DataSource) -> List[str]:
"""从需求中提取表名"""
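        # Simplified: a real system would search a metadata catalog.
        # get_table_metadata() caches entries under "<source>:<table>" keys,
        # so recover the known table names from those keys.
        prefix = f"{source.name}:"
        available_tables = [
            key[len(prefix):]
            for key in self.source_manager.metadata_cache
            if key.startswith(prefix)
        ]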
# 关键词匹配
table_keywords = {
"订单": ["orders", "order"],
"用户": ["users", "user", "customers"],
"商品": ["products", "items", "goods"],
"销售": ["sales", "transactions"]
}
matched_tables = []
for keyword, table_names in table_keywords.items():
if keyword in requirement:
for table_name in table_names:
if table_name in available_tables:
matched_tables.append(table_name)
break
# 如果没有匹配到,使用第一个可用表
if not matched_tables and available_tables:
matched_tables = [available_tables[0]]
return matched_tables
def _extract_metrics(self, requirement: str) -> List[str]:
"""提取指标"""
metric_keywords = {
"销售额": ["amount", "price", "total"],
"数量": ["quantity", "count", "num"],
"利润": ["profit", "margin"],
"成本": ["cost", "expense"],
"转化率": ["conversion_rate", "conversion"]
}
metrics = []
for metric_name, field_names in metric_keywords.items():
if metric_name in requirement:
metrics.append(metric_name)
return metrics
def _extract_dimensions(self, requirement: str) -> List[str]:
"""提取维度"""
dimension_keywords = ["按", "分组", "类别", "类型", "地区", "时间", "日期"]
dimensions = []
for keyword in dimension_keywords:
if keyword in requirement:
dimensions.append(keyword)
return dimensions
def _extract_filters(self, requirement: str) -> List[Dict[str, Any]]:
"""提取过滤条件"""
filters = []
# 简化的过滤条件提取
if "大于" in requirement:
match = re.search(r'(\w+)大于(\d+)', requirement)
if match:
filters.append({
"field": match.group(1),
"operator": ">",
"value": match.group(2)
})
if "小于" in requirement:
match = re.search(r'(\w+)小于(\d+)', requirement)
if match:
filters.append({
"field": match.group(1),
"operator": "<",
"value": match.group(2)
})
return filters
def _extract_time_range(self, requirement: str) -> Optional[Dict[str, str]]:
"""提取时间范围"""
        # These expressions use PostgreSQL syntax (interval literals must be
        # quoted). MySQL, ClickHouse, and SQLite need their own date functions.
        time_patterns = {
            "今天": ("NOW()::DATE", "NOW()::DATE + INTERVAL '1 day'"),
            "昨天": ("NOW()::DATE - INTERVAL '1 day'", "NOW()::DATE"),
            "本周": ("date_trunc('week', NOW())", "date_trunc('week', NOW()) + INTERVAL '1 week'"),
            "本月": ("date_trunc('month', NOW())", "date_trunc('month', NOW()) + INTERVAL '1 month'"),
            "最近7天": ("NOW()::DATE - INTERVAL '7 days'", "NOW()::DATE"),
            "最近30天": ("NOW()::DATE - INTERVAL '30 days'", "NOW()::DATE")
        }
for keyword, (start_date, end_date) in time_patterns.items():
if keyword in requirement:
return {"start": start_date, "end": end_date}
return None
def _extract_aggregation(self, requirement: str) -> str:
"""提取聚合类型"""
if "总和" in requirement or "总计" in requirement:
return "SUM"
elif "平均" in requirement:
return "AVG"
elif "数量" in requirement or "次数" in requirement:
return "COUNT"
elif "最大" in requirement:
return "MAX"
elif "最小" in requirement:
return "MIN"
else:
return "SUM"
def generate_sql(self, requirement: str, source: DataSource) -> SQLQuery:
"""生成 SQL 查询"""
entities = self.understand_requirement(requirement, source)
if not entities["tables"]:
raise ValueError("无法从需求中识别出相关表")
table = entities["tables"][0]
metadata = self.source_manager.get_table_metadata(source, table)
# 构建 SELECT 子句
select_clauses = []
group_by_clauses = []
# 添加维度字段
if entities["dimensions"]:
for dim in entities["dimensions"]:
# 简化:假设维度字段名包含关键词
matching_columns = [
col["name"] for col in metadata.columns
if dim in col["name"] or col["name"] in ["category", "type", "region", "date", "created_at"]
]
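                for col_name in matching_columns[:3]:  # at most 3 columns per dimension keyword
                    # Several keywords can match the same column; add it only once
                    if col_name not in group_by_clauses:
                        select_clauses.append(col_name)
                        group_by_clauses.append(col_name)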
# 添加指标字段
metric_field = "amount" # 默认字段
if entities["metrics"]:
for metric in entities["metrics"]:
matching_columns = [
col["name"] for col in metadata.columns
if metric in col["name"] or col["name"] in ["amount", "price", "total", "count", "quantity"]
]
if matching_columns:
metric_field = matching_columns[0]
break
agg_func = entities["aggregation"]
select_clauses.append(f"{agg_func}({metric_field}) as {agg_func.lower()}_{metric_field}")
# 构建 WHERE 子句
where_clauses = []
if entities["time_range"]:
date_column = "created_at" # 假设日期字段名
where_clauses.append(f"{date_column} >= {entities['time_range']['start']}")
where_clauses.append(f"{date_column} < {entities['time_range']['end']}")
for filter_cond in entities["filters"]:
where_clauses.append(f"{filter_cond['field']} {filter_cond['operator']} {filter_cond['value']}")
# 构建 SQL
query_parts = [
f"SELECT {', '.join(select_clauses)}",
f"FROM {table}"
]
if where_clauses:
query_parts.append(f"WHERE {' AND '.join(where_clauses)}")
if group_by_clauses:
query_parts.append(f"GROUP BY {', '.join(group_by_clauses)}")
query_parts.append(f"ORDER BY {agg_func.lower()}_{metric_field} DESC")
query = " ".join(query_parts)
# 生成解释
explanation = self._generate_explanation(entities, agg_func, metric_field)
return SQLQuery(
query=query,
explanation=explanation,
estimated_rows=metadata.row_count // 10, # 简化估算
execution_plan="" # 实际应用中应该执行 EXPLAIN
)
def _generate_explanation(self, entities: Dict[str, Any], agg_func: str, metric_field: str) -> str:
"""生成 SQL 解释"""
parts = []
parts.append(f"查询 {entities['tables'][0]} 表")
if entities["dimensions"]:
parts.append(f"按 {', '.join(entities['dimensions'])} 分组")
parts.append(f"计算 {metric_field} 的 {agg_func} 值")
if entities["time_range"]:
parts.append(f"时间范围: {entities['time_range']['start']} 至 {entities['time_range']['end']}")
if entities["filters"]:
filter_desc = ", ".join([f"{f['field']} {f['operator']} {f['value']}" for f in entities["filters"]])
parts.append(f"过滤条件: {filter_desc}")
return ",".join(parts) + "。"
class DataQualityAgent:
"""数据质量检查 Agent"""
def check_quality(self, data: pd.DataFrame, metadata: TableMetadata) -> DataQualityReport:
"""检查数据质量"""
total_rows = len(data)
# 检查空值
null_counts = data.isnull().sum().to_dict()
# 检查重复行
duplicate_count = data.duplicated().sum()
# 检查异常值(简化:使用 IQR 方法)
outlier_count = 0
numeric_columns = data.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
Q1 = data[col].quantile(0.25)
Q3 = data[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
outliers = ((data[col] < lower_bound) | (data[col] > upper_bound)).sum()
outlier_count += outliers
# 计算完整性评分
total_cells = len(data.columns) * total_rows
null_cells = sum(null_counts.values())
completeness_score = 1.0 - (null_cells / total_cells) if total_cells > 0 else 1.0
# 计算一致性评分(简化)
consistency_score = 1.0 - (duplicate_count / total_rows) if total_rows > 0 else 1.0
return DataQualityReport(
total_rows=total_rows,
null_counts=null_counts,
duplicate_count=duplicate_count,
outlier_count=outlier_count,
completeness_score=completeness_score,
consistency_score=consistency_score
)
def generate_quality_report(self, report: DataQualityReport) -> str:
"""生成质量报告"""
parts = []
parts.append(f"数据质量检查报告")
parts.append(f"总行数: {report.total_rows}")
parts.append(f"完整性评分: {report.completeness_score:.2%}")
parts.append(f"一致性评分: {report.consistency_score:.2%}")
parts.append(f"重复行数: {report.duplicate_count}")
parts.append(f"异常值数量: {report.outlier_count}")
if report.null_counts:
parts.append("空值统计:")
for col, count in report.null_counts.items():
if count > 0:
parts.append(f" {col}: {count}")
return "\n".join(parts)
class AnalysisLogicAgent:
"""分析逻辑生成 Agent"""
def generate_insights(self, data: pd.DataFrame, metadata: TableMetadata) -> List[AnalysisInsight]:
"""生成分析洞察"""
insights = []
numeric_columns = data.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
# 计算基础指标
current_value = data[col].iloc[-1] if len(data) > 0 else None
if current_value is None:
continue
# 计算趋势
if len(data) >= 2:
prev_value = data[col].iloc[-2]
change_percent = ((current_value - prev_value) / prev_value * 100) if prev_value != 0 else 0
if change_percent > 5:
trend = "increasing"
elif change_percent < -5:
trend = "decreasing"
else:
trend = "stable"
insight = AnalysisInsight(
metric=col,
value=current_value,
trend=trend,
change_percent=change_percent,
description=f"{col} 当前值为 {current_value:.2f},较上期{'增长' if change_percent > 0 else '下降'} {abs(change_percent):.2f}%"
)
insights.append(insight)
return insights
def generate_summary(self, insights: List[AnalysisInsight]) -> str:
"""生成摘要"""
if not insights:
return "无显著洞察。"
parts = ["关键洞察:"]
for insight in insights:
parts.append(f"- {insight.description}")
return "\n".join(parts)
class VisualizationAgent:
"""可视化推荐 Agent"""
def recommend_visualization(
self,
data: pd.DataFrame,
requirement: str,
metadata: TableMetadata
) -> List[VisualizationRecommendation]:
"""推荐可视化方案"""
recommendations = []
# 分析数据特征
numeric_columns = data.select_dtypes(include=[np.number]).columns.tolist()
categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
time_columns = data.select_dtypes(include=['datetime64']).columns.tolist()
# 场景1: 时间序列数据
if time_columns and numeric_columns:
time_col = time_columns[0]
metric_col = numeric_columns[0]
recommendations.append(VisualizationRecommendation(
chart_type="line",
x_axis=time_col,
y_axis=metric_col,
color_by=None,
title=f"{metric_col} 趋势图",
rationale="时间序列数据适合用折线图展示趋势变化"
))
# 场景2: 分类对比
if categorical_columns and numeric_columns:
cat_col = categorical_columns[0]
metric_col = numeric_columns[0]
if len(data[cat_col].unique()) <= 10: # 类别数量适中
recommendations.append(VisualizationRecommendation(
chart_type="bar",
x_axis=cat_col,
y_axis=metric_col,
color_by=None,
title=f"{cat_col} vs {metric_col}",
rationale="分类数据适合用柱状图展示对比"
))
# 场景3: 部分与整体
if categorical_columns and len(data[categorical_columns[0]].unique()) <= 8:
cat_col = categorical_columns[0]
metric_col = numeric_columns[0] if numeric_columns else None
if metric_col:
recommendations.append(VisualizationRecommendation(
chart_type="pie",
x_axis=cat_col,
y_axis=metric_col,
color_by=None,
title=f"{cat_col} 占比分布",
rationale="少量分类适合用饼图展示占比"
))
# 场景4: 相关性分析
if len(numeric_columns) >= 2:
recommendations.append(VisualizationRecommendation(
chart_type="scatter",
x_axis=numeric_columns[0],
y_axis=numeric_columns[1],
color_by=None,
title=f"{numeric_columns[0]} vs {numeric_columns[1]}",
rationale="两个数值变量适合用散点图探索相关性"
))
return recommendations
class ReportGenerationAgent:
"""报告生成 Agent"""
def __init__(self):
self.source_manager = DataSourceManager()
self.sql_agent = SQLGeneratorAgent(self.source_manager)
self.quality_agent = DataQualityAgent()
self.analysis_agent = AnalysisLogicAgent()
self.viz_agent = VisualizationAgent()
def generate_report(
self,
requirement: str,
source: DataSource,
include_visualization: bool = True
) -> Dict[str, Any]:
"""生成完整报告"""
report = {
"requirement": requirement,
"timestamp": datetime.now().isoformat(),
"steps": []
}
try:
# 步骤1: 生成 SQL
report["steps"].append({"step": 1, "action": "生成 SQL 查询"})
sql_result = self.sql_agent.generate_sql(requirement, source)
            # Store a plain dict: the exporters read report["sql"]["query"]
            report["sql"] = sql_result.__dict__
# 步骤2: 执行查询
report["steps"].append({"step": 2, "action": "执行数据查询"})
data = self.source_manager.execute_query(source, sql_result.query)
report["data_summary"] = {
"rows": len(data),
"columns": len(data.columns),
"column_names": data.columns.tolist()
}
# 步骤3: 数据质量检查
report["steps"].append({"step": 3, "action": "数据质量检查"})
table_name = self.sql_agent.understand_requirement(requirement, source)["tables"][0]
metadata = self.source_manager.get_table_metadata(source, table_name)
quality_report = self.quality_agent.check_quality(data, metadata)
report["quality_report"] = self.quality_agent.generate_quality_report(quality_report)
# 步骤4: 生成分析洞察
report["steps"].append({"step": 4, "action": "生成分析洞察"})
insights = self.analysis_agent.generate_insights(data, metadata)
report["insights"] = [insight.__dict__ for insight in insights]
report["summary"] = self.analysis_agent.generate_summary(insights)
# 步骤5: 推荐可视化
if include_visualization:
report["steps"].append({"step": 5, "action": "推荐可视化方案"})
viz_recommendations = self.viz_agent.recommend_visualization(data, requirement, metadata)
report["visualizations"] = [viz.__dict__ for viz in viz_recommendations]
report["status"] = "success"
report["message"] = "报告生成成功"
except Exception as e:
report["status"] = "error"
report["message"] = f"报告生成失败: {str(e)}"
report["error"] = str(e)
return report
def export_report(self, report: Dict[str, Any], format: str = "json") -> str:
"""导出报告"""
if format == "json":
return json.dumps(report, ensure_ascii=False, indent=2, default=str)
elif format == "markdown":
return self._export_markdown(report)
else:
raise ValueError(f"不支持的格式: {format}")
def _export_markdown(self, report: Dict[str, Any]) -> str:
"""导出为 Markdown 格式"""
lines = []
lines.append("# 自动化报表")
lines.append(f"**需求**: {report['requirement']}")
lines.append(f"**生成时间**: {report['timestamp']}")
lines.append("")
if report.get("sql"):
lines.append("## SQL 查询")
lines.append("```sql")
lines.append(report["sql"]["query"])
lines.append("```")
lines.append(f"**说明**: {report['sql']['explanation']}")
lines.append("")
if report.get("data_summary"):
lines.append("## 数据概览")
lines.append(f"- 行数: {report['data_summary']['rows']}")
lines.append(f"- 列数: {report['data_summary']['columns']}")
lines.append(f"- 字段: {', '.join(report['data_summary']['column_names'])}")
lines.append("")
if report.get("quality_report"):
lines.append("## 数据质量")
lines.append("```")
lines.append(report["quality_report"])
lines.append("```")
lines.append("")
if report.get("insights"):
lines.append("## 分析洞察")
for insight in report["insights"]:
lines.append(f"- {insight['description']}")
lines.append("")
lines.append(f"**摘要**: {report.get('summary', '')}")
lines.append("")
if report.get("visualizations"):
lines.append("## 可视化推荐")
for viz in report["visualizations"]:
lines.append(f"### {viz['title']}")
lines.append(f"- 图表类型: {viz['chart_type']}")
lines.append(f"- X轴: {viz['x_axis']}")
lines.append(f"- Y轴: {viz['y_axis']}")
lines.append(f"- 说明: {viz['rationale']}")
lines.append("")
return "\n".join(lines)
# Usage example
if __name__ == "__main__":
    import random
    import sqlite3
    import tempfile

    # Use a file-backed SQLite database. An in-memory database (":memory:") is
    # private to the connection that created it, so the SQLAlchemy engine
    # would not see tables created through this sqlite3 connection.
    db_path = os.path.join(tempfile.mkdtemp(), "example.db")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Create a sample table
    cursor.execute("""
        CREATE TABLE orders (
            id INTEGER PRIMARY KEY,
            user_id INTEGER,
            product_id INTEGER,
            amount REAL,
            quantity INTEGER,
            category TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)

    # Insert sample rows
    categories = ["Electronics", "Clothing", "Food", "Home"]
    for i in range(1000):
        created_at = datetime.now() - timedelta(days=random.randint(0, 30))
        cursor.execute("""
            INSERT INTO orders (user_id, product_id, amount, quantity, category, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (
            random.randint(1, 100),
            random.randint(1, 50),
            random.uniform(10, 1000),
            random.randint(1, 5),
            random.choice(categories),
            created_at.isoformat(sep=" ")  # store as text explicitly
        ))
    conn.commit()
    conn.close()

    # Describe the data source; for SQLite, `host` carries the file path
    source = DataSource(
        name="example_db",
        type=DataSourceType.SQLITE,
        host=db_path,
        port=0,
        database="example",
        username="",
        password=""
    )

    # Initialize the report generation Agent
    report_agent = ReportGenerationAgent()

    # Preload metadata so that table lookup can find "orders"
    report_agent.source_manager.metadata_cache[source.name] = {
        "orders": report_agent.source_manager.get_table_metadata(source, "orders")
    }

    # Generate the report. The requirement stays in Chinese because the keyword
    # rules above match Chinese terms ("total sales by category"). No time
    # range is requested: the generated date filters are not valid SQLite.
    requirement = "按类别统计销售额总和"
    report = report_agent.generate_report(requirement, source)

    # Print the result
    print("=== Report ===")
    print(report_agent.export_report(report, format="markdown"))

    if report["status"] == "success":
        print("\n=== Generated SQL ===")
        print(report["sql"]["query"])
        print("\nExplanation:", report["sql"]["explanation"])
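This implementation demonstrates the core functions of an automated reporting Agent:
- Data source management: supports several database types and reads table structure and metadata automatically
- SQL generation: builds SQL queries from natural-language requirements using keyword rules
- Data quality checks: detects nulls, duplicate rows, and outliers
- Insight generation: identifies trends and changes automatically
- Visualization recommendation: recommends chart types based on the characteristics of the data
Enhanced SQL Generation
For more complex scenarios, an LLM can be integrated to strengthen SQL generation: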
import openai


class LLMEnhancedSQLGenerator:
    """SQL generator backed by an LLM"""

    def __init__(self, api_key: str):
        self.client = openai.OpenAI(api_key=api_key)
        self.prompt_template = """
You are an expert SQL author. Write a SQL query from the information below.

Table schemas:
{schema}

User requirement:
{requirement}

Reply in exactly this format, with nothing else:
SQL:
<one SQL query>
Explanation:
<a short explanation of the query logic>
"""

    def generate_sql_with_llm(
        self,
        requirement: str,
        table_schemas: Dict[str, List[Dict[str, Any]]]
    ) -> Tuple[str, str]:
        """Generate SQL with the LLM; returns (sql, explanation)"""
        # Describe each table as "Table name: col (type), ..."
        schema_desc = []
        for table_name, columns in table_schemas.items():
            col_desc = ", ".join([f"{col['name']} ({col['type']})" for col in columns])
            schema_desc.append(f"Table {table_name}: {col_desc}")
        schema_text = "\n".join(schema_desc)

        # Call the LLM
        prompt = self.prompt_template.format(
            schema=schema_text,
            requirement=requirement
        )
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are an expert SQL author."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1
        )
        result = (response.choices[0].message.content or "").strip()

        # Split the reply at the "Explanation:" marker. Everything before it is
        # treated as SQL, so multi-line queries and CTEs (WITH ...) are kept whole.
        sql_part, _, explanation = result.partition("Explanation:")
        sql = sql_part.replace("SQL:", "", 1).strip()
        # Models sometimes wrap the query in a Markdown code fence; strip it
        fence = "`" * 3
        sql = sql.replace(fence + "sql", "").replace(fence, "").strip()
        return sql, explanation.strip()
def validate_and_optimize_sql(
self,
sql: str,
source: DataSource,
source_manager: DataSourceManager
) -> Tuple[str, List[str]]:
"""验证并优化 SQL"""
warnings = []
try:
# 语法检查
engine = source_manager.create_connection(source)
with engine.connect() as conn:
# 执行 EXPLAIN
explain_result = conn.execute(text(f"EXPLAIN {sql}"))
plan = explain_result.fetchall()
# 分析执行计划
plan_text = "\n".join([str(row) for row in plan])
if "Seq Scan" in plan_text and "Index" not in plan_text:
warnings.append("警告: 查询可能进行全表扫描,建议添加索引")
if "Nested Loop" in plan_text:
warnings.append("提示: 存在嵌套循环,可能影响性能")
if plan_text.count("Index Scan") > 5:
warnings.append("注意: 使用了多个索引扫描,检查是否可以优化")
except Exception as e:
raise ValueError(f"SQL 验证失败: {str(e)}")
return sql, warnings
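Best Practices and Common Pitfalls
Production Best Practices
1. Caching
Frequently requested reports should use several cache layers:
- Metadata cache: table structures, index information, and other data that rarely changes
- Query result cache: results keyed by the query and its parameters
- SQL template cache: generated SQL for common query patterns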
from functools import lru_cache
import hashlib
import pickle
from pathlib import Path
from typing import Optional
class QueryCache:
"""查询结果缓存"""
def __init__(self, cache_dir: str = ".cache"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
def _get_cache_key(self, query: str, params: Dict = None) -> str:
"""生成缓存键"""
key_str = query + str(sorted(params.items()) if params else "")
return hashlib.md5(key_str.encode()).hexdigest()
def get(self, query: str, params: Dict = None) -> Optional[pd.DataFrame]:
"""获取缓存"""
cache_key = self._get_cache_key(query, params)
cache_file = self.cache_dir / f"{cache_key}.pkl"
if cache_file.exists():
with open(cache_file, 'rb') as f:
cached_data = pickle.load(f)
# 检查是否过期(24小时)
if (datetime.now() - cached_data["timestamp"]).total_seconds() < 86400:
return cached_data["data"]
return None
def set(self, query: str, data: pd.DataFrame, params: Dict = None):
"""设置缓存"""
cache_key = self._get_cache_key(query, params)
cache_file = self.cache_dir / f"{cache_key}.pkl"
with open(cache_file, 'wb') as f:
pickle.dump({
"timestamp": datetime.now(),
"data": data,
"query": query,
"params": params
}, f)
def clear(self):
"""清空缓存"""
for cache_file in self.cache_dir.glob("*.pkl"):
cache_file.unlink()
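2. Access Control
Strict access control is essential in production:
- Database connection privileges: use read-only accounts to prevent accidental writes
- Sensitive field masking: mask sensitive data before it reaches the report
- Query complexity limits: stop overly complex queries from overloading the system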
class PermissionManager:
    """Permission manager"""

    def __init__(self, rules: List[Dict[str, Any]]):
        self.rules = rules

    def check_table_access(self, user: str, table: str) -> bool:
        """Check whether the user may read the table"""
        for rule in self.rules:
            if rule["user"] == user and table in rule["allowed_tables"]:
                return True
        return False

    def check_query_complexity(self, query: str) -> Tuple[bool, str]:
        """Reject write statements and overly complex queries"""
        upper_query = query.upper()

        # Match whole words only, so a column such as updated_at is not flagged
        dangerous_keywords = ["DROP", "DELETE", "TRUNCATE", "UPDATE", "INSERT"]
        for keyword in dangerous_keywords:
            if re.search(rf"\b{keyword}\b", upper_query):
                return False, f"Query contains a dangerous operation: {keyword}"

        # Limit the number of JOINs
        join_count = len(re.findall(r"\bJOIN\b", upper_query))
        if join_count > 5:
            return False, f"Too many JOINs ({join_count}); performance may suffer"

        # Approximate subquery depth by the deepest parenthesis nesting.
        # Function calls also count, so this is an upper bound.
        depth = 0
        max_depth = 0
        for ch in query:
            if ch == "(":
                depth += 1
                max_depth = max(max_depth, depth)
            elif ch == ")":
                depth -= 1
        if max_depth > 3:
            return False, f"Subquery nesting is too deep ({max_depth})"

        return True, ""

    def mask_sensitive_data(
        self,
        data: pd.DataFrame,
        table: str,
        sensitive_fields: List[str]
    ) -> pd.DataFrame:
        """Mask sensitive columns"""
        masked_data = data.copy()
        for field in sensitive_fields:
            if field in masked_data.columns:
                # Simple masking: keep the first 3 and last 2 characters
                masked_data[field] = masked_data[field].apply(
                    lambda x: str(x)[:3] + "****" + str(x)[-2:] if pd.notna(x) else x
                )
        return masked_data
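3. Error Handling and Fallback
Thorough error handling keeps the system stable:
- SQL syntax errors: return a readable error message with suggestions
- Execution timeouts: set a reasonable timeout
- Unavailable data sources: provide a fallback path and a retry mechanism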
from tenacity import retry, stop_after_attempt, wait_exponential
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
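class RobustReportGenerator(ReportGenerationAgent):
    """Report generator with retries and fallback"""

    def __init__(self):
        # Inherit source_manager and generate_report() from ReportGenerationAgent
        super().__init__()
        self.logger = logger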
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=10)
)
def execute_query_with_retry(
self,
source: DataSource,
query: str,
timeout: int = 30
) -> pd.DataFrame:
"""带重试的查询执行"""
try:
engine = self.source_manager.create_connection(source)
with engine.connect() as conn:
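                # NOTE: `timeout` is not enforced here. Statement timeouts are
                # driver-specific (for example PostgreSQL's statement_timeout
                # setting), so configure one that matches your database.
                result = pd.read_sql(text(query), conn)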
return result
except Exception as e:
self.logger.error(f"查询执行失败: {str(e)}")
raise
def generate_report_with_fallback(
self,
requirement: str,
source: DataSource
) -> Dict[str, Any]:
"""带降级的报告生成"""
try:
# 尝试生成完整报告
return self.generate_report(requirement, source)
except TimeoutError:
self.logger.warning("查询超时,尝试简化查询")
# 降级:生成简化报告
simplified_requirement = self._simplify_requirement(requirement)
return self.generate_report(simplified_requirement, source)
except Exception as e:
self.logger.error(f"报告生成失败: {str(e)}")
return {
"status": "error",
"message": f"报告生成失败: {str(e)}",
"suggestion": "请简化查询需求或联系管理员"
}
def _simplify_requirement(self, requirement: str) -> str:
"""简化需求"""
# 移除复杂的过滤条件
simplified = re.sub(r'并且.*', '', requirement)
return simplified
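Common Pitfalls and Solutions
Pitfall 1: Generated SQL performs poorly
Problem: the SQL generated by the Agent may not use any index, which leads to full table scans.
Solutions:
- Include index information in the metadata
- Validate the SQL with an execution plan analysis
- Produce optimization suggestions for the generated SQL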
class SQLValidator:
"""SQL 验证器"""
def __init__(self, source_manager: DataSourceManager):
self.source_manager = source_manager
def validate_sql_performance(
self,
sql: str,
source: DataSource
) -> Dict[str, Any]:
"""验证 SQL 性能"""
engine = self.source_manager.create_connection(source)
with engine.connect() as conn:
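            # Fetch the execution plan. In PostgreSQL, EXPLAIN ANALYZE actually
            # runs the query; use plain EXPLAIN for expensive statements.
            explain_result = conn.execute(text(f"EXPLAIN ANALYZE {sql}"))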
plan = explain_result.fetchall()
plan_text = "\n".join([str(row) for row in plan])
# 分析性能问题
issues = []
suggestions = []
if "Seq Scan" in plan_text:
issues.append("存在全表扫描")
suggestions.append("考虑添加索引或优化 WHERE 条件")
if "Nested Loop" in plan_text:
issues.append("存在低效的嵌套循环")
suggestions.append("考虑改用 Hash Join 或添加索引")
if plan_text.count("Index Scan") > 3:
issues.append("使用了多个索引扫描")
suggestions.append("考虑使用复合索引")
return {
"execution_plan": plan_text,
"issues": issues,
"suggestions": suggestions,
"is_optimized": len(issues) == 0
}
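Pitfall 2: Data quality problems go unnoticed
Problem: the report is built on faulty data and leads to wrong decisions.
Solutions:
- Run comprehensive data quality checks
- Flag and handle anomalous data
- Include a data quality report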
class AdvancedDataQualityChecker:
"""高级数据质量检查器"""
def check_data_completeness(
self,
data: pd.DataFrame,
required_columns: List[str],
rules: Dict[str, Any]
) -> Dict[str, Any]:
"""检查数据完整性"""
report = {
"total_rows": len(data),
"completeness": {},
"issues": []
}
for column in required_columns:
if column not in data.columns:
report["issues"].append(f"缺少必需列: {column}")
continue
            null_count = data[column].isnull().sum()
            # Guard against an empty result set to avoid division by zero
            completeness = 1.0 - (null_count / len(data)) if len(data) > 0 else 1.0
report["completeness"][column] = {
"completeness": completeness,
"null_count": null_count,
"threshold": rules.get(column, {}).get("completeness_threshold", 0.95)
}
if completeness < rules.get(column, {}).get("completeness_threshold", 0.95):
report["issues"].append(
f"列 {column} 完整性不足: {completeness:.2%} "
f"(阈值: {rules.get(column, {}).get('completeness_threshold', 0.95):.2%})"
)
return report
def detect_data_drift(
self,
current_data: pd.DataFrame,
historical_stats: Dict[str, Dict[str, float]]
) -> Dict[str, Any]:
"""检测数据漂移"""
drift_report = {
"drift_detected": False,
"drifted_columns": [],
"details": {}
}
numeric_columns = current_data.select_dtypes(include=[np.number]).columns
for column in numeric_columns:
if column not in historical_stats:
continue
current_mean = current_data[column].mean()
historical_mean = historical_stats[column].get("mean", 0)
# 计算漂移程度
drift_ratio = abs(current_mean - historical_mean) / (abs(historical_mean) + 1e-6)
drift_report["details"][column] = {
"current_mean": current_mean,
"historical_mean": historical_mean,
"drift_ratio": drift_ratio
}
if drift_ratio > 0.2: # 漂移超过20%
drift_report["drift_detected"] = True
drift_report["drifted_columns"].append(column)
return drift_report
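Pitfall 3: Poor visualization choices
Problem: a chart type is chosen that does not suit the type of data.
Solutions:
- Recommend visualizations based on both the data characteristics and the business scenario
- Offer several visualization options for the user to choose from
- Explain why each visualization was recommended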
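The three points above can be sketched as a function that returns several chart options, each with its own rationale, instead of a single answer. This is a standalone illustration under stated assumptions: `ColumnProfile`, `ChartOption`, and `chart_options` are hypothetical names, and the thresholds (20 bars, 8 pie slices) are illustrative defaults rather than established rules.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ColumnProfile:
    name: str
    kind: str             # "time", "category", or "number"
    distinct_values: int  # number of distinct values in the result set


@dataclass
class ChartOption:
    chart_type: str
    x_axis: str
    y_axis: str
    rationale: str


def first_of_kind(columns: List[ColumnProfile], kind: str) -> Optional[ColumnProfile]:
    """Return the first column of the given kind, or None."""
    return next((c for c in columns if c.kind == kind), None)


def chart_options(
    columns: List[ColumnProfile],
    max_bars: int = 20,        # assumed readability limit for bar charts
    max_pie_slices: int = 8,   # assumed readability limit for pie charts
) -> List[ChartOption]:
    """Return every chart type that suits the columns, with a reason for each."""
    numbers = [c for c in columns if c.kind == "number"]
    if not numbers:
        return []  # nothing to plot without a numeric measure
    metric = numbers[0]
    time_col = first_of_kind(columns, "time")
    cat_col = first_of_kind(columns, "category")

    options: List[ChartOption] = []
    if time_col:
        options.append(ChartOption(
            "line", time_col.name, metric.name,
            "a time column is present, so a line chart shows the trend"))
    if cat_col and cat_col.distinct_values <= max_bars:
        options.append(ChartOption(
            "bar", cat_col.name, metric.name,
            f"{cat_col.distinct_values} categories can be compared side by side"))
    if cat_col and cat_col.distinct_values <= max_pie_slices:
        options.append(ChartOption(
            "pie", cat_col.name, metric.name,
            f"{cat_col.distinct_values} categories are few enough to show shares"))
    if len(numbers) >= 2:
        options.append(ChartOption(
            "scatter", numbers[0].name, numbers[1].name,
            "two numeric columns can be checked for correlation"))
    return options


columns = [
    ColumnProfile("category", "category", 4),
    ColumnProfile("sum_amount", "number", 950),
]
options = chart_options(columns)
for option in options:
    print(option.chart_type, "-", option.rationale)
```

Returning a list leaves the final choice to the user, and the rationale string gives the report something to display next to each option.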
Performance Considerations
Query Performance
1. SQL Optimization
class SQLOptimizer:
"""SQL 优化器"""
def __init__(self, source_manager: DataSourceManager):
self.source_manager = source_manager
def optimize_query(
self,
sql: str,
source: DataSource
) -> Tuple[str, List[str]]:
"""优化查询"""
optimizations = []
optimized_sql = sql
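        # Optimization 1: cap the result size when the query has no LIMIT.
        # Strip a trailing semicolon first so the appended clause stays valid.
        if not re.search(r"\bLIMIT\b", optimized_sql, re.IGNORECASE):
            optimized_sql = optimized_sql.rstrip().rstrip(";") + " LIMIT 10000"
            optimizations.append("Added LIMIT to cap the result size")
        # Optimization 2: reorder JOINs (reported only if the SQL changed)
        if re.search(r"\bJOIN\b", optimized_sql, re.IGNORECASE):
            reordered_sql = self._optimize_join_order(optimized_sql, source)
            if reordered_sql != optimized_sql:
                optimized_sql = reordered_sql
                optimizations.append("Reordered JOINs")
        # Optimization 3: add index hints (reported only if the SQL changed)
        hinted_sql = self._add_index_hints(optimized_sql, source)
        if hinted_sql != optimized_sql:
            optimized_sql = hinted_sql
            optimizations.append("Added index hints")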
return optimized_sql, optimizations
def _optimize_join_order(self, sql: str, source: DataSource) -> str:
"""优化 JOIN 顺序(简化实现)"""
# 实际实现需要基于表的统计信息
return sql
def _add_index_hints(self, sql: str, source: DataSource) -> str:
"""添加索引提示(简化实现)"""
# 实际实现需要基于表的索引信息
return sql
2. Parallel Query Execution
from concurrent.futures import ThreadPoolExecutor, as_completed


class ParallelQueryExecutor:
    """Parallel query executor"""

    def __init__(self, source_manager: DataSourceManager, max_workers: int = 4):
        # Reuse the manager's pooled engines instead of creating one per query
        self.source_manager = source_manager
        self.max_workers = max_workers

    def execute_queries_in_parallel(
        self,
        queries: List[Tuple[DataSource, str]]
    ) -> List[Optional[pd.DataFrame]]:
        """Run several queries in parallel; a failed query yields None"""
        # Create engines up front, in this thread, so the workers only read
        # the connection registry
        for source, _ in queries:
            self.source_manager.create_connection(source)

        results: List[Optional[pd.DataFrame]] = [None] * len(queries)
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(self._execute_single_query, source, query): idx
                for idx, (source, query) in enumerate(queries)
            }
            for future in as_completed(futures):
                idx = futures[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    print(f"Query {idx} failed: {str(e)}")
                    results[idx] = None
        return results

    def _execute_single_query(self, source: DataSource, query: str) -> pd.DataFrame:
        """Run a single query through the shared DataSourceManager"""
        return self.source_manager.execute_query(source, query)
Cache Strategy
1. Smart Cache Invalidation
class IntelligentCache:
    """In-memory cache with TTL expiry and usage-based eviction"""

    def __init__(self):
        self.cache = {}
        self.access_count = {}
        self.last_access = {}

    def _delete(self, key: str):
        """Remove a key from the cache and its bookkeeping"""
        self.cache.pop(key, None)
        self.access_count.pop(key, None)
        self.last_access.pop(key, None)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None if it is missing or expired"""
        entry = self.cache.get(key)
        if entry is None:
            return None
        age = (datetime.now() - entry["created_at"]).total_seconds()
        if age > entry["ttl"]:
            self._delete(key)
            return None
        self.access_count[key] = self.access_count.get(key, 0) + 1
        self.last_access[key] = datetime.now()
        # Return the stored value, not the internal wrapper dict
        return entry["value"]

    def set(self, key: str, value: Any, ttl: int = 3600):
        """Store a value with a time-to-live in seconds"""
        self.cache[key] = {
            "value": value,
            "created_at": datetime.now(),
            "ttl": ttl
        }
        self.access_count[key] = 1
        self.last_access[key] = datetime.now()

    def cleanup_expired(self):
        """Remove expired entries"""
        now = datetime.now()
        expired_keys = [
            key for key, data in self.cache.items()
            if (now - data["created_at"]).total_seconds() > data["ttl"]
        ]
        for key in expired_keys:
            self._delete(key)

    def evict_lru(self, count: int = 10):
        """Evict the least-used entries.

        Entries are ordered by access count first and last access time second,
        so this is closer to LFU with an LRU tie-break than to strict LRU.
        """
        sorted_keys = sorted(
            self.access_count.keys(),
            key=lambda k: (self.access_count[k], self.last_access[k])
        )
        for key in sorted_keys[:count]:
            self._delete(key)
2. Precomputing Common Reports
class ReportPrecomputer:
"""报表预计算"""
def __init__(self, report_generator: ReportGenerationAgent):
self.report_generator = report_generator
self.precomputed_reports = {}
def precompute_popular_reports(
self,
popular_queries: List[str],
source: DataSource
):
"""预计算热门报表"""
for query in popular_queries:
try:
report = self.report_generator.generate_report(query, source)
self.precomputed_reports[query] = {
"report": report,
"computed_at": datetime.now(),
"ttl": 3600 # 1小时有效期
}
except Exception as e:
print(f"预计算失败: {query}, 错误: {str(e)}")
def get_precomputed_report(self, query: str) -> Optional[Dict[str, Any]]:
"""获取预计算的报表"""
if query in self.precomputed_reports:
cached = self.precomputed_reports[query]
age = (datetime.now() - cached["computed_at"]).total_seconds()
if age < cached["ttl"]:
return cached["report"]
else:
# 缓存过期,删除
del self.precomputed_reports[query]
return None
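An automated reporting Agent is a key component of intelligent data systems: it combines the language understanding of LLMs with traditional data analysis tools and lowers the barrier to data insight. As the technology matures, data analysis should become smarter, more efficient, and more widely accessible.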