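Automated Reporting Agent: Data Extraction, Analysis, and Visualization

An automated reporting agent interprets data sources, generates query logic, analyzes trends in the data, and recommends suitable visualizations, shortening the path from raw data to insight.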

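In enterprise data work, getting from raw data to business insight usually takes several steps: connecting to data sources, writing query logic, cleaning data, computing metrics, choosing chart types, and producing a visual report. Traditionally this process depends heavily on the skills of data analysts and engineers, which makes it slow and error-prone.

The core value of an automated reporting agent is to automate that whole workflow: it interprets the business requirement, generates SQL queries, runs data quality checks, identifies patterns in the data, and recommends a suitable visualization. This improves efficiency and lowers the barrier to data analysis, so more business users can draw insights from data directly.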

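Core Concepts and Architecture

Multi-Agent Collaboration

The automated reporting agent uses a Supervisor/Worker pattern: a coordinating agent assigns tasks and aggregates results, while several specialist agents handle the individual subtasks. This keeps task handling flexible while giving each stage its own quality control.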

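The coordinating agent is responsible for:

  • Parsing the user's requirement and splitting it into executable subtasks
  • Ordering the execution of the specialist agents
  • Collecting intermediate results and validating their quality
  • Handling failures by rolling back or correcting the affected step
  • Aggregating the final results into a deliverable report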

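The specialist agents are:

  • Data source connection agent: manages database connections, schema understanding, and permission checks
  • SQL generation agent: generates SQL queries from the requirement
  • Data quality agent: checks completeness, consistency, and outliers
  • Analysis logic agent: generates the aggregation, grouping, and metric calculation logic
  • Visualization agent: recommends chart types based on the characteristics of the data
  • Report generation agent: assembles the final report and adds explanatory notes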
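
The division of labour above can be illustrated with a minimal sketch. The `Supervisor` class, the `Worker` signature, and the two toy workers are hypothetical names chosen for illustration; they are not part of the framework shown later. The sketch assumes workers run sequentially and share a context dictionary.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple

# A worker reads the shared context and returns the keys it wants to add.
Worker = Callable[[Dict[str, Any]], Dict[str, Any]]


@dataclass
class Supervisor:
    """Hypothetical coordinator: runs workers in order and aggregates results."""
    workers: List[Tuple[str, Worker]] = field(default_factory=list)

    def register(self, name: str, worker: Worker) -> None:
        self.workers.append((name, worker))

    def run(self, requirement: str) -> Dict[str, Any]:
        context: Dict[str, Any] = {"requirement": requirement, "steps": []}
        for name, worker in self.workers:
            try:
                context.update(worker(context))
                context["steps"].append({"agent": name, "status": "ok"})
            except Exception as exc:
                # Stop at the first failure: later workers depend on earlier output.
                context["steps"].append(
                    {"agent": name, "status": "error", "detail": str(exc)}
                )
                context["status"] = "error"
                return context
        context["status"] = "success"
        return context


def sql_worker(ctx: Dict[str, Any]) -> Dict[str, Any]:
    # Toy stand-in for the SQL generation agent.
    return {"sql": "SELECT category, SUM(amount) FROM orders GROUP BY category"}


def quality_worker(ctx: Dict[str, Any]) -> Dict[str, Any]:
    # Toy stand-in for the data quality agent; it needs the SQL step's output.
    if "sql" not in ctx:
        raise ValueError("no SQL available to check")
    return {"quality": "ok"}


supervisor = Supervisor()
supervisor.register("sql", sql_worker)
supervisor.register("quality", quality_worker)
result = supervisor.run("total sales by category")
print(result["status"], [step["agent"] for step in result["steps"]])
```

Stopping at the first failed worker is one possible policy; a fuller coordinator might retry the step or ask the worker to correct its output before giving up.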

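Data Understanding and Metadata Management

The agent first needs to understand the structure and semantics of the data source: table schemas, column types, relationships, and business meaning. Metadata management is the foundation of the whole system and directly affects the accuracy of the SQL generation and analysis logic that follow.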
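
One way to hold the business semantics alongside the physical schema is a small catalog that maps the terms users type to column names. The sketch below is an assumption for illustration only: `ColumnMeta`, `TableCatalog`, and the synonym lists are hypothetical and do not appear in the framework that follows.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ColumnMeta:
    name: str
    sql_type: str
    # Synonyms a business user might type for this column (illustrative).
    business_terms: List[str] = field(default_factory=list)


@dataclass
class TableCatalog:
    """Hypothetical in-memory catalog for a single table."""
    table: str
    columns: List[ColumnMeta]

    def resolve(self, term: str) -> Optional[str]:
        """Return the physical column name for a business term, or None."""
        needle = term.strip().lower()
        for col in self.columns:
            terms = [t.lower() for t in col.business_terms]
            if needle == col.name.lower() or needle in terms:
                return col.name
        return None


orders_catalog = TableCatalog(
    table="orders",
    columns=[
        ColumnMeta("amount", "REAL", ["销售额", "revenue", "sales amount"]),
        ColumnMeta("category", "TEXT", ["类别", "product category"]),
        ColumnMeta("created_at", "TIMESTAMP", ["日期", "order date"]),
    ],
)

print(orders_catalog.resolve("销售额"))
print(orders_catalog.resolve("unknown term"))
```

Keeping synonyms next to each column means the mapping can be reviewed and extended without touching the SQL generation code.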

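SQL Generation and Optimization

SQL generation is the core capability of automated reporting. The agent has to turn a natural-language requirement into an accurate, efficient SQL query, which involves intent recognition, entity mapping, logic construction, and performance optimization.

The SQL generation agent works in stages:

  1. Requirement parsing: identify the query target, filter conditions, aggregation dimensions, and sort order
  2. Entity mapping: map business terms to actual table and column names
  3. Logic construction: build the SELECT, FROM, WHERE, GROUP BY, HAVING, and ORDER BY clauses
  4. Optimization: apply index hints, avoid full table scans, and optimize JOIN order
  5. Validation: check SQL syntax, verify that the referenced columns exist, and estimate execution cost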
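
The validation stage can be sketched with nothing but the standard library. The example below assumes SQLite, where `EXPLAIN QUERY PLAN` prepares a statement without running it; `validate_query` and the demo schema are hypothetical names for illustration, and other databases expose their plans in different formats.

```python
import sqlite3
from typing import List, Tuple


def validate_query(conn: sqlite3.Connection, sql: str) -> Tuple[bool, List[str]]:
    """Return (is_valid, notes) without executing the query itself."""
    try:
        plan = conn.execute(f"EXPLAIN QUERY PLAN {sql}").fetchall()
    except sqlite3.Error as exc:
        # Syntax errors and unknown tables or columns are reported at prepare time.
        return False, [f"invalid SQL: {exc}"]
    notes = []
    for row in plan:
        detail = row[-1]  # the last column holds the readable plan step
        if detail.startswith("SCAN"):
            notes.append(f"full scan: {detail}")
    return True, notes


# Demo schema (hypothetical)
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER, amount REAL)"
)

print(validate_query(conn, "SELECT amount FROM orders WHERE user_id = 5"))
print(validate_query(conn, "SELECT missing_col FROM orders"))
```

A full-scan note is only a hint: on small tables a scan is often the cheapest plan, so the note is better treated as input to a cost estimate than as an error.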

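Key Implementation Techniques

Core Report Generation Framework

The following implements an end-to-end report generation framework covering data source connections, SQL generation, data analysis, and visualization recommendations: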

import os
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.engine import Engine
from datetime import datetime, timedelta
import json
import re

class DataSourceType(Enum):
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
    CLICKHOUSE = "clickhouse"
    SQLITE = "sqlite"

@dataclass
class DataSource:
    name: str
    type: DataSourceType
    host: str
    port: int
    database: str
    username: str
    password: str
    schema: Optional[str] = None

@dataclass
class TableMetadata:
    name: str
    columns: List[Dict[str, Any]]
    indexes: List[Dict[str, Any]]
    row_count: int
    sample_data: pd.DataFrame

@dataclass
class SQLQuery:
    query: str
    explanation: str
    estimated_rows: int
    execution_plan: str

@dataclass
class DataQualityReport:
    total_rows: int
    null_counts: Dict[str, int]
    duplicate_count: int
    outlier_count: int
    completeness_score: float
    consistency_score: float

@dataclass
class AnalysisInsight:
    metric: str
    value: Any
    trend: str  # "increasing", "decreasing", "stable"
    change_percent: float
    description: str

@dataclass
class VisualizationRecommendation:
    chart_type: str  # "line", "bar", "pie", "scatter", "heatmap"
    x_axis: str
    y_axis: str
    color_by: Optional[str]
    title: str
    rationale: str

class DataSourceManager:
    """管理数据源连接和元数据"""
    
    def __init__(self):
        self.connections: Dict[str, Engine] = {}
        self.metadata_cache: Dict[str, Dict[str, TableMetadata]] = {}
    
    def create_connection(self, source: DataSource) -> Engine:
        """创建数据库连接"""
        if source.name in self.connections:
            return self.connections[source.name]
        
        connection_string = self._build_connection_string(source)
        engine = create_engine(connection_string, pool_pre_ping=True)
        self.connections[source.name] = engine
        return engine
    
    def _build_connection_string(self, source: DataSource) -> str:
        """Build the SQLAlchemy connection URL for a data source."""
        from urllib.parse import quote_plus

        # URL-encode credentials so characters such as "@" or "/" in a
        # password do not corrupt the connection URL
        user = quote_plus(source.username)
        password = quote_plus(source.password)
        location = f"{user}:{password}@{source.host}:{source.port}/{source.database}"
        if source.type == DataSourceType.MYSQL:
            return f"mysql+pymysql://{location}"
        elif source.type == DataSourceType.POSTGRESQL:
            return f"postgresql://{location}"
        elif source.type == DataSourceType.CLICKHOUSE:
            # Requires a ClickHouse dialect for SQLAlchemy to be installed
            return f"clickhouse://{location}"
        elif source.type == DataSourceType.SQLITE:
            # For SQLite, `host` holds the database file path (or ":memory:")
            return f"sqlite:///{source.host}"
        else:
            raise ValueError(f"Unsupported data source type: {source.type}")
    
    def get_table_metadata(self, source: DataSource, table_name: str) -> TableMetadata:
        """Fetch and cache the columns, indexes, row count, and sample rows of a table."""
        # The cache is nested as {source name: {table name: metadata}}, which
        # matches its declared type and the lookup in _extract_tables
        source_cache = self.metadata_cache.setdefault(source.name, {})
        if table_name in source_cache:
            return source_cache[table_name]
        
        engine = self.create_connection(source)
        inspector = inspect(engine)
        
        # Column information
        columns = []
        for column in inspector.get_columns(table_name):
            columns.append({
                "name": column["name"],
                "type": str(column["type"]),
                "nullable": column.get("nullable", True),
                "default": column.get("default"),
                "autoincrement": column.get("autoincrement", False)
            })
        
        # Index information
        indexes = []
        for index in inspector.get_indexes(table_name):
            indexes.append({
                "name": index["name"],
                "columns": index["column_names"],
                "unique": index.get("unique", False)
            })
        
        # Row count and sample rows. Quote the identifier rather than
        # interpolating the raw table name into the SQL text
        quoted_table = engine.dialect.identifier_preparer.quote(table_name)
        with engine.connect() as conn:
            result = conn.execute(text(f"SELECT COUNT(*) FROM {quoted_table}"))
            row_count = result.scalar()
            
            sample_query = text(f"SELECT * FROM {quoted_table} LIMIT 100")
            sample_data = pd.read_sql(sample_query, conn)
        
        metadata = TableMetadata(
            name=table_name,
            columns=columns,
            indexes=indexes,
            row_count=row_count,
            sample_data=sample_data
        )
        
        source_cache[table_name] = metadata
        return metadata
    
    def execute_query(self, source: DataSource, query: str) -> pd.DataFrame:
        """执行 SQL 查询"""
        engine = self.create_connection(source)
        with engine.connect() as conn:
            return pd.read_sql(text(query), conn)

class SQLGeneratorAgent:
    """SQL 生成 Agent"""
    
    def __init__(self, source_manager: DataSourceManager):
        self.source_manager = source_manager
        self.entity_mappings: Dict[str, Dict[str, str]] = {}
    
    def understand_requirement(self, requirement: str, source: DataSource) -> Dict[str, Any]:
        """理解用户需求并提取关键信息"""
        # 简化的需求理解,实际应用中应该使用 LLM
        entities = {
            "tables": self._extract_tables(requirement, source),
            "metrics": self._extract_metrics(requirement),
            "dimensions": self._extract_dimensions(requirement),
            "filters": self._extract_filters(requirement),
            "time_range": self._extract_time_range(requirement),
            "aggregation": self._extract_aggregation(requirement)
        }
        return entities
    
    def _extract_tables(self, requirement: str, source: DataSource) -> List[str]:
        """从需求中提取表名"""
        # 简化实现,实际应该从元数据中心搜索
        metadata = self.source_manager.metadata_cache.get(source.name, {})
        available_tables = list(metadata.keys())
        
        # 关键词匹配
        table_keywords = {
            "订单": ["orders", "order"],
            "用户": ["users", "user", "customers"],
            "商品": ["products", "items", "goods"],
            "销售": ["sales", "transactions"]
        }
        
        matched_tables = []
        for keyword, table_names in table_keywords.items():
            if keyword in requirement:
                for table_name in table_names:
                    if table_name in available_tables:
                        matched_tables.append(table_name)
                        break
        
        # 如果没有匹配到,使用第一个可用表
        if not matched_tables and available_tables:
            matched_tables = [available_tables[0]]
        
        return matched_tables
    
    def _extract_metrics(self, requirement: str) -> List[str]:
        """提取指标"""
        metric_keywords = {
            "销售额": ["amount", "price", "total"],
            "数量": ["quantity", "count", "num"],
            "利润": ["profit", "margin"],
            "成本": ["cost", "expense"],
            "转化率": ["conversion_rate", "conversion"]
        }
        
        metrics = []
        for metric_name, field_names in metric_keywords.items():
            if metric_name in requirement:
                metrics.append(metric_name)
        
        return metrics
    
    def _extract_dimensions(self, requirement: str) -> List[str]:
        """提取维度"""
        dimension_keywords = ["按", "分组", "类别", "类型", "地区", "时间", "日期"]
        dimensions = []
        
        for keyword in dimension_keywords:
            if keyword in requirement:
                dimensions.append(keyword)
        
        return dimensions
    
    def _extract_filters(self, requirement: str) -> List[Dict[str, Any]]:
        """提取过滤条件"""
        filters = []
        
        # 简化的过滤条件提取
        if "大于" in requirement:
            match = re.search(r'(\w+)大于(\d+)', requirement)
            if match:
                filters.append({
                    "field": match.group(1),
                    "operator": ">",
                    "value": match.group(2)
                })
        
        if "小于" in requirement:
            match = re.search(r'(\w+)小于(\d+)', requirement)
            if match:
                filters.append({
                    "field": match.group(1),
                    "operator": "<",
                    "value": match.group(2)
                })
        
        return filters
    
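    def _extract_time_range(self, requirement: str) -> Optional[Dict[str, str]]:
        """Extract a half-open [start, end) range as quoted ISO date literals."""
        # Bounds are computed in Python because date arithmetic syntax differs
        # between MySQL, PostgreSQL, ClickHouse, and SQLite, whereas a quoted
        # 'YYYY-MM-DD' literal is accepted by all of them
        today = datetime.now().date()
        week_start = today - timedelta(days=today.weekday())
        month_start = today.replace(day=1)
        next_month_start = (month_start + timedelta(days=32)).replace(day=1)
        time_patterns = {
            "今天": (today, today + timedelta(days=1)),
            "昨天": (today - timedelta(days=1), today),
            "本周": (week_start, week_start + timedelta(days=7)),
            "本月": (month_start, next_month_start),
            "最近7天": (today - timedelta(days=7), today),
            "最近30天": (today - timedelta(days=30), today)
        }
        
        for keyword, (start_date, end_date) in time_patterns.items():
            if keyword in requirement:
                return {
                    "start": f"'{start_date.isoformat()}'",
                    "end": f"'{end_date.isoformat()}'"
                }
        
        return None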
    
    def _extract_aggregation(self, requirement: str) -> str:
        """提取聚合类型"""
        if "总和" in requirement or "总计" in requirement:
            return "SUM"
        elif "平均" in requirement:
            return "AVG"
        elif "数量" in requirement or "次数" in requirement:
            return "COUNT"
        elif "最大" in requirement:
            return "MAX"
        elif "最小" in requirement:
            return "MIN"
        else:
            return "SUM"
    
    def generate_sql(self, requirement: str, source: DataSource) -> SQLQuery:
        """生成 SQL 查询"""
        entities = self.understand_requirement(requirement, source)
        
        if not entities["tables"]:
            raise ValueError("无法从需求中识别出相关表")
        
        table = entities["tables"][0]
        metadata = self.source_manager.get_table_metadata(source, table)
        
        # 构建 SELECT 子句
        select_clauses = []
        group_by_clauses = []
        
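        # Add dimension columns. "按" and "分组" only signal grouping intent,
        # so map the remaining keywords to candidate column names and skip
        # columns that are already selected
        dimension_columns = {
            "类别": ["category"],
            "类型": ["type"],
            "地区": ["region"],
            "时间": ["created_at", "date"],
            "日期": ["date", "created_at"]
        }
        available_columns = {col["name"] for col in metadata.columns}
        for dim in entities["dimensions"]:
            for col_name in dimension_columns.get(dim, []):
                if col_name in available_columns and col_name not in group_by_clauses:
                    select_clauses.append(col_name)
                    group_by_clauses.append(col_name)
                    break  # one column per dimension keyword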
        
        # 添加指标字段
        metric_field = "amount"  # 默认字段
        if entities["metrics"]:
            for metric in entities["metrics"]:
                matching_columns = [
                    col["name"] for col in metadata.columns
                    if metric in col["name"] or col["name"] in ["amount", "price", "total", "count", "quantity"]
                ]
                if matching_columns:
                    metric_field = matching_columns[0]
                    break
        
        agg_func = entities["aggregation"]
        select_clauses.append(f"{agg_func}({metric_field}) as {agg_func.lower()}_{metric_field}")
        
        # 构建 WHERE 子句
        where_clauses = []
        if entities["time_range"]:
            date_column = "created_at"  # 假设日期字段名
            where_clauses.append(f"{date_column} >= {entities['time_range']['start']}")
            where_clauses.append(f"{date_column} < {entities['time_range']['end']}")
        
        for filter_cond in entities["filters"]:
            where_clauses.append(f"{filter_cond['field']} {filter_cond['operator']} {filter_cond['value']}")
        
        # 构建 SQL
        query_parts = [
            f"SELECT {', '.join(select_clauses)}",
            f"FROM {table}"
        ]
        
        if where_clauses:
            query_parts.append(f"WHERE {' AND '.join(where_clauses)}")
        
        if group_by_clauses:
            query_parts.append(f"GROUP BY {', '.join(group_by_clauses)}")
            query_parts.append(f"ORDER BY {agg_func.lower()}_{metric_field} DESC")
        
        query = " ".join(query_parts)
        
        # 生成解释
        explanation = self._generate_explanation(entities, agg_func, metric_field)
        
        return SQLQuery(
            query=query,
            explanation=explanation,
            estimated_rows=metadata.row_count // 10,  # 简化估算
            execution_plan=""  # 实际应用中应该执行 EXPLAIN
        )
    
    def _generate_explanation(self, entities: Dict[str, Any], agg_func: str, metric_field: str) -> str:
        """生成 SQL 解释"""
        parts = []
        parts.append(f"查询 {entities['tables'][0]} 表")
        
        if entities["dimensions"]:
            parts.append(f"按 {', '.join(entities['dimensions'])} 分组")
        
        parts.append(f"计算 {metric_field} 的 {agg_func} 值")
        
        if entities["time_range"]:
            parts.append(f"时间范围: {entities['time_range']['start']} 至 {entities['time_range']['end']}")
        
        if entities["filters"]:
            filter_desc = ", ".join([f"{f['field']} {f['operator']} {f['value']}" for f in entities["filters"]])
            parts.append(f"过滤条件: {filter_desc}")
        
        return ",".join(parts) + "。"

class DataQualityAgent:
    """数据质量检查 Agent"""
    
    def check_quality(self, data: pd.DataFrame, metadata: TableMetadata) -> DataQualityReport:
        """检查数据质量"""
        total_rows = len(data)
        
        # 检查空值
        null_counts = data.isnull().sum().to_dict()
        
        # Count duplicate rows (cast NumPy integers to plain int for the report)
        duplicate_count = int(data.duplicated().sum())
        
        # Count outliers with the IQR rule (simplified)
        outlier_count = 0
        numeric_columns = data.select_dtypes(include=[np.number]).columns
        for col in numeric_columns:
            Q1 = data[col].quantile(0.25)
            Q3 = data[col].quantile(0.75)
            IQR = Q3 - Q1
            lower_bound = Q1 - 1.5 * IQR
            upper_bound = Q3 + 1.5 * IQR
            outliers = ((data[col] < lower_bound) | (data[col] > upper_bound)).sum()
            outlier_count += int(outliers)
        
        # 计算完整性评分
        total_cells = len(data.columns) * total_rows
        null_cells = sum(null_counts.values())
        completeness_score = 1.0 - (null_cells / total_cells) if total_cells > 0 else 1.0
        
        # 计算一致性评分(简化)
        consistency_score = 1.0 - (duplicate_count / total_rows) if total_rows > 0 else 1.0
        
        return DataQualityReport(
            total_rows=total_rows,
            null_counts=null_counts,
            duplicate_count=duplicate_count,
            outlier_count=outlier_count,
            completeness_score=completeness_score,
            consistency_score=consistency_score
        )
    
    def generate_quality_report(self, report: DataQualityReport) -> str:
        """生成质量报告"""
        parts = []
        parts.append(f"数据质量检查报告")
        parts.append(f"总行数: {report.total_rows}")
        parts.append(f"完整性评分: {report.completeness_score:.2%}")
        parts.append(f"一致性评分: {report.consistency_score:.2%}")
        parts.append(f"重复行数: {report.duplicate_count}")
        parts.append(f"异常值数量: {report.outlier_count}")
        
        if report.null_counts:
            parts.append("空值统计:")
            for col, count in report.null_counts.items():
                if count > 0:
                    parts.append(f"  {col}: {count}")
        
        return "\n".join(parts)

class AnalysisLogicAgent:
    """分析逻辑生成 Agent"""
    
    def generate_insights(self, data: pd.DataFrame, metadata: TableMetadata) -> List[AnalysisInsight]:
        """生成分析洞察"""
        insights = []
        
        numeric_columns = data.select_dtypes(include=[np.number]).columns
        
        for col in numeric_columns:
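            # Treat the last row as the current period; this assumes the rows
            # are ordered by time. Skip empty data and missing values
            if len(data) == 0 or pd.isna(data[col].iloc[-1]):
                continue
            current_value = data[col].iloc[-1]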
            
            # 计算趋势
            if len(data) >= 2:
                prev_value = data[col].iloc[-2]
                change_percent = ((current_value - prev_value) / prev_value * 100) if prev_value != 0 else 0
                
                if change_percent > 5:
                    trend = "increasing"
                elif change_percent < -5:
                    trend = "decreasing"
                else:
                    trend = "stable"
                
                insight = AnalysisInsight(
                    metric=col,
                    value=current_value,
                    trend=trend,
                    change_percent=change_percent,
                    description=f"{col} 当前值为 {current_value:.2f},较上期{'增长' if change_percent > 0 else '下降'} {abs(change_percent):.2f}%"
                )
                insights.append(insight)
        
        return insights
    
    def generate_summary(self, insights: List[AnalysisInsight]) -> str:
        """生成摘要"""
        if not insights:
            return "无显著洞察。"
        
        parts = ["关键洞察:"]
        for insight in insights:
            parts.append(f"- {insight.description}")
        
        return "\n".join(parts)

class VisualizationAgent:
    """可视化推荐 Agent"""
    
    def recommend_visualization(
        self,
        data: pd.DataFrame,
        requirement: str,
        metadata: TableMetadata
    ) -> List[VisualizationRecommendation]:
        """推荐可视化方案"""
        recommendations = []
        
        # 分析数据特征
        numeric_columns = data.select_dtypes(include=[np.number]).columns.tolist()
        categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
        time_columns = data.select_dtypes(include=['datetime64']).columns.tolist()
        
        # 场景1: 时间序列数据
        if time_columns and numeric_columns:
            time_col = time_columns[0]
            metric_col = numeric_columns[0]
            
            recommendations.append(VisualizationRecommendation(
                chart_type="line",
                x_axis=time_col,
                y_axis=metric_col,
                color_by=None,
                title=f"{metric_col} 趋势图",
                rationale="时间序列数据适合用折线图展示趋势变化"
            ))
        
        # 场景2: 分类对比
        if categorical_columns and numeric_columns:
            cat_col = categorical_columns[0]
            metric_col = numeric_columns[0]
            
            if len(data[cat_col].unique()) <= 10:  # 类别数量适中
                recommendations.append(VisualizationRecommendation(
                    chart_type="bar",
                    x_axis=cat_col,
                    y_axis=metric_col,
                    color_by=None,
                    title=f"{cat_col} vs {metric_col}",
                    rationale="分类数据适合用柱状图展示对比"
                ))
        
        # 场景3: 部分与整体
        if categorical_columns and len(data[categorical_columns[0]].unique()) <= 8:
            cat_col = categorical_columns[0]
            metric_col = numeric_columns[0] if numeric_columns else None
            
            if metric_col:
                recommendations.append(VisualizationRecommendation(
                    chart_type="pie",
                    x_axis=cat_col,
                    y_axis=metric_col,
                    color_by=None,
                    title=f"{cat_col} 占比分布",
                    rationale="少量分类适合用饼图展示占比"
                ))
        
        # 场景4: 相关性分析
        if len(numeric_columns) >= 2:
            recommendations.append(VisualizationRecommendation(
                chart_type="scatter",
                x_axis=numeric_columns[0],
                y_axis=numeric_columns[1],
                color_by=None,
                title=f"{numeric_columns[0]} vs {numeric_columns[1]}",
                rationale="两个数值变量适合用散点图探索相关性"
            ))
        
        return recommendations

class ReportGenerationAgent:
    """报告生成 Agent"""
    
    def __init__(self):
        self.source_manager = DataSourceManager()
        self.sql_agent = SQLGeneratorAgent(self.source_manager)
        self.quality_agent = DataQualityAgent()
        self.analysis_agent = AnalysisLogicAgent()
        self.viz_agent = VisualizationAgent()
    
    def generate_report(
        self,
        requirement: str,
        source: DataSource,
        include_visualization: bool = True
    ) -> Dict[str, Any]:
        """生成完整报告"""
        report = {
            "requirement": requirement,
            "timestamp": datetime.now().isoformat(),
            "steps": []
        }
        
        try:
            # 步骤1: 生成 SQL
            report["steps"].append({"step": 1, "action": "生成 SQL 查询"})
            sql_result = self.sql_agent.generate_sql(requirement, source)
            # Store a plain dict so export code can index it by key
            report["sql"] = dict(vars(sql_result))
            
            # 步骤2: 执行查询
            report["steps"].append({"step": 2, "action": "执行数据查询"})
            data = self.source_manager.execute_query(source, sql_result.query)
            report["data_summary"] = {
                "rows": len(data),
                "columns": len(data.columns),
                "column_names": data.columns.tolist()
            }
            
            # 步骤3: 数据质量检查
            report["steps"].append({"step": 3, "action": "数据质量检查"})
            table_name = self.sql_agent.understand_requirement(requirement, source)["tables"][0]
            metadata = self.source_manager.get_table_metadata(source, table_name)
            quality_report = self.quality_agent.check_quality(data, metadata)
            report["quality_report"] = self.quality_agent.generate_quality_report(quality_report)
            
            # 步骤4: 生成分析洞察
            report["steps"].append({"step": 4, "action": "生成分析洞察"})
            insights = self.analysis_agent.generate_insights(data, metadata)
            report["insights"] = [insight.__dict__ for insight in insights]
            report["summary"] = self.analysis_agent.generate_summary(insights)
            
            # 步骤5: 推荐可视化
            if include_visualization:
                report["steps"].append({"step": 5, "action": "推荐可视化方案"})
                viz_recommendations = self.viz_agent.recommend_visualization(data, requirement, metadata)
                report["visualizations"] = [viz.__dict__ for viz in viz_recommendations]
            
            report["status"] = "success"
            report["message"] = "报告生成成功"
            
        except Exception as e:
            report["status"] = "error"
            report["message"] = f"报告生成失败: {str(e)}"
            report["error"] = str(e)
        
        return report
    
    def export_report(self, report: Dict[str, Any], format: str = "json") -> str:
        """导出报告"""
        if format == "json":
            return json.dumps(report, ensure_ascii=False, indent=2, default=str)
        elif format == "markdown":
            return self._export_markdown(report)
        else:
            raise ValueError(f"不支持的格式: {format}")
    
    def _export_markdown(self, report: Dict[str, Any]) -> str:
        """导出为 Markdown 格式"""
        lines = []
        lines.append("# 自动化报表")
        lines.append(f"**需求**: {report['requirement']}")
        lines.append(f"**生成时间**: {report['timestamp']}")
        lines.append("")
        
        if report.get("sql"):
            lines.append("## SQL 查询")
            lines.append("```sql")
            lines.append(report["sql"]["query"])
            lines.append("```")
            lines.append(f"**说明**: {report['sql']['explanation']}")
            lines.append("")
        
        if report.get("data_summary"):
            lines.append("## 数据概览")
            lines.append(f"- 行数: {report['data_summary']['rows']}")
            lines.append(f"- 列数: {report['data_summary']['columns']}")
            lines.append(f"- 字段: {', '.join(report['data_summary']['column_names'])}")
            lines.append("")
        
        if report.get("quality_report"):
            lines.append("## 数据质量")
            lines.append("```")
            lines.append(report["quality_report"])
            lines.append("```")
            lines.append("")
        
        if report.get("insights"):
            lines.append("## 分析洞察")
            for insight in report["insights"]:
                lines.append(f"- {insight['description']}")
            lines.append("")
            lines.append(f"**摘要**: {report.get('summary', '')}")
            lines.append("")
        
        if report.get("visualizations"):
            lines.append("## 可视化推荐")
            for viz in report["visualizations"]:
                lines.append(f"### {viz['title']}")
                lines.append(f"- 图表类型: {viz['chart_type']}")
                lines.append(f"- X轴: {viz['x_axis']}")
                lines.append(f"- Y轴: {viz['y_axis']}")
                lines.append(f"- 说明: {viz['rationale']}")
                lines.append("")
        
        return "\n".join(lines)

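# Usage example
if __name__ == "__main__":
    import random
    import sqlite3
    import tempfile
    
    # Use a file-backed SQLite database. Each connection to ":memory:" opens
    # its own empty database, so tables created through sqlite3 would not be
    # visible to the SQLAlchemy engine
    db_path = os.path.join(tempfile.mkdtemp(), "example.db")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    
    # Create the sample table
    cursor.execute("""
    CREATE TABLE orders (
        id INTEGER PRIMARY KEY,
        user_id INTEGER,
        product_id INTEGER,
        amount REAL,
        quantity INTEGER,
        category TEXT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )
    """)
    
    # Insert sample rows spread over the last 30 days
    categories = ["电子产品", "服装", "食品", "家居"]
    for i in range(1000):
        created_at = datetime.now() - timedelta(days=random.randint(0, 30))
        cursor.execute("""
        INSERT INTO orders (user_id, product_id, amount, quantity, category, created_at)
        VALUES (?, ?, ?, ?, ?, ?)
        """, (
            random.randint(1, 100),
            random.randint(1, 50),
            random.uniform(10, 1000),
            random.randint(1, 5),
            random.choice(categories),
            created_at.strftime("%Y-%m-%d %H:%M:%S")
        ))
    
    conn.commit()
    conn.close()
    
    # Describe the data source; for SQLite, `host` carries the file path
    source = DataSource(
        name="example_db",
        type=DataSourceType.SQLITE,
        host=db_path,
        port=0,
        database="example",
        username="",
        password=""
    )
    
    # Initialize the report generation agent
    report_agent = ReportGenerationAgent()
    
    # Preload metadata so table matching can see the available tables
    report_agent.source_manager.metadata_cache[source.name] = {
        "orders": report_agent.source_manager.get_table_metadata(source, "orders")
    }
    
    # Generate the report. The requirement stays in Chinese because the
    # keyword matching above is Chinese-based; it reads
    # "total sales amount by category over the last 7 days"
    requirement = "按类别统计最近7天的销售额总和"
    report = report_agent.generate_report(requirement, source)
    
    # Print the result
    print("=== Report ===")
    print(report_agent.export_report(report, format="markdown"))
    
    # Print the SQL, which is only present when generation succeeded
    if report["status"] == "success":
        print("\n=== Generated SQL ===")
        print(report["sql"]["query"])
        print("\nExplanation:", report["sql"]["explanation"])
    else:
        print(report["message"])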

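This implementation demonstrates the core functions of an automated reporting agent:

  1. Data source management: supports several database types and retrieves table structure and metadata automatically
  2. SQL generation: builds a SQL query from a natural-language requirement, here through keyword matching as a simplified stand-in for an LLM
  3. Data quality checks: detects null values, duplicate rows, and outliers
  4. Insight generation: identifies trends and changes automatically
  5. Visualization recommendations: suggests chart types based on the characteristics of the data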
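
One caveat: the framework above interpolates filter fields and values directly into the SQL text. A safer pattern, sketched here under the assumption of SQLite's `?` placeholder style, is to whitelist column names and operators and pass values as bound parameters. `build_where` and `ALLOWED_OPERATORS` are hypothetical names; the filter dictionaries use the same `field`/`operator`/`value` shape as `_extract_filters`.

```python
import sqlite3
from typing import Any, Dict, List, Sequence, Tuple

ALLOWED_OPERATORS = {">", "<", ">=", "<=", "=", "!="}


def build_where(
    filters: Sequence[Dict[str, Any]], allowed_columns: Sequence[str]
) -> Tuple[str, List[Any]]:
    """Build a WHERE clause with '?' placeholders and a matching parameter list."""
    clauses: List[str] = []
    params: List[Any] = []
    for flt in filters:
        column, operator = flt["field"], flt["operator"]
        if column not in allowed_columns:
            raise ValueError(f"unknown column: {column}")
        if operator not in ALLOWED_OPERATORS:
            raise ValueError(f"unsupported operator: {operator}")
        # Identifiers come from the whitelist; values never enter the SQL text.
        clauses.append(f'"{column}" {operator} ?')
        params.append(flt["value"])
    where = (" WHERE " + " AND ".join(clauses)) if clauses else ""
    return where, params


# Demo data (hypothetical)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (amount REAL, quantity INTEGER)")
conn.executemany(
    "INSERT INTO orders VALUES (?, ?)", [(50.0, 1), (150.0, 2), (300.0, 3)]
)

where, params = build_where(
    [{"field": "amount", "operator": ">", "value": 100}],
    allowed_columns=["amount", "quantity"],
)
rows = conn.execute(f"SELECT amount FROM orders{where} ORDER BY amount", params).fetchall()
print(where, params, rows)
```

With SQLAlchemy the same idea would use named bind parameters in `text()`; the whitelist for identifiers is still needed because identifiers cannot be bound.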

Enhanced SQL Generation Strategy

For more complex scenarios, an LLM can be integrated to strengthen SQL generation:

import openai

class LLMEnhancedSQLGenerator:
    """使用 LLM 增强的 SQL 生成器"""
    
    def __init__(self, api_key: str):
        self.client = openai.OpenAI(api_key=api_key)
        self.prompt_template = """
你是一个专业的 SQL 生成专家。根据以下信息生成 SQL 查询:

数据库表结构:
{schema}

用户需求:
{requirement}

请生成一个高效的 SQL 查询,并解释查询逻辑。
只返回 SQL 和解释,不要有其他内容。

SQL:
"""
    
    def generate_sql_with_llm(
        self,
        requirement: str,
        table_schemas: Dict[str, List[Dict[str, Any]]]
    ) -> Tuple[str, str]:
        """使用 LLM 生成 SQL"""
        # 构建模式描述
        schema_desc = []
        for table_name, columns in table_schemas.items():
            col_desc = ", ".join([f"{col['name']} ({col['type']})" for col in columns])
            schema_desc.append(f"表 {table_name}: {col_desc}")
        
        schema_text = "\n".join(schema_desc)
        
        # 调用 LLM
        prompt = self.prompt_template.format(
            schema=schema_text,
            requirement=requirement
        )
        
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "你是一个专业的 SQL 生成专家。"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1
        )
        
        result = response.choices[0].message.content.strip()
        
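        # Parse the SQL and the explanation. Prefer a fenced code block if the
        # model returned one; otherwise collect every line from the first
        # SELECT/WITH up to the explanation marker, so that multi-line
        # statements are kept whole
        fence = "`" * 3
        fenced = re.search(
            fence + r"(?:sql)?\s*\n(.*?)" + fence, result, re.DOTALL | re.IGNORECASE
        )
        sql_lines = []
        explanation_lines = []
        capture_sql = False
        capture_explanation = False
        
        for line in result.split("\n"):
            stripped = line.strip()
            if capture_explanation:
                explanation_lines.append(line)
            elif "解释" in line or "说明" in line or stripped.lower().startswith("explanation"):
                capture_sql = False
                capture_explanation = True
            elif capture_sql or stripped.upper().startswith(("SELECT", "WITH")):
                capture_sql = True
                if not stripped.startswith(fence):
                    sql_lines.append(line)
        
        sql = fenced.group(1).strip() if fenced else "\n".join(sql_lines).strip()
        explanation = "\n".join(explanation_lines).strip()
        
        return sql, explanation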
    
    def validate_and_optimize_sql(
        self,
        sql: str,
        source: DataSource,
        source_manager: DataSourceManager
    ) -> Tuple[str, List[str]]:
        """验证并优化 SQL"""
        warnings = []
        
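        try:
            # Ask the database for the execution plan; this also surfaces syntax errors
            engine = source_manager.create_connection(source)
            with engine.connect() as conn:
                explain_result = conn.execute(text(f"EXPLAIN {sql}"))
                plan = explain_result.fetchall()
                
                # Inspect the plan. The node names checked below ("Seq Scan",
                # "Nested Loop", "Index Scan") are PostgreSQL's; other engines
                # word their plans differently
                plan_text = "\n".join([str(row) for row in plan])
                
                if "Seq Scan" in plan_text and "Index" not in plan_text:
                    warnings.append("Warning: the query may perform a full table scan; consider adding an index")
                
                if "Nested Loop" in plan_text:
                    warnings.append("Note: the plan contains a nested loop, which may affect performance")
                
                if plan_text.count("Index Scan") > 5:
                    warnings.append("Note: the plan uses many index scans; check whether it can be simplified")
                
        except Exception as e:
            raise ValueError(f"SQL validation failed: {e}") from e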
        
        return sql, warnings

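Best Practices and Common Pitfalls

Production Best Practices

1. Caching

Frequently queried reports should use multiple cache layers:

  • Metadata cache: table structures, index information, and other data that rarely changes
  • Query result cache: results for queries with identical parameters
  • SQL template cache: SQL for common query patterns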

import hashlib
import pickle
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import pandas as pd

class QueryCache:
    """查询结果缓存"""
    
    def __init__(self, cache_dir: str = ".cache"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
    
    def _get_cache_key(self, query: str, params: Dict = None) -> str:
        """生成缓存键"""
        key_str = query + str(sorted(params.items()) if params else "")
        return hashlib.md5(key_str.encode()).hexdigest()
    
    def get(self, query: str, params: Dict = None) -> Optional[pd.DataFrame]:
        """获取缓存"""
        cache_key = self._get_cache_key(query, params)
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        
        if cache_file.exists():
            with open(cache_file, 'rb') as f:
                cached_data = pickle.load(f)
                # 检查是否过期(24小时)
                if (datetime.now() - cached_data["timestamp"]).total_seconds() < 86400:
                    return cached_data["data"]
        
        return None
    
    def set(self, query: str, data: pd.DataFrame, params: Dict = None):
        """设置缓存"""
        cache_key = self._get_cache_key(query, params)
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        
        with open(cache_file, 'wb') as f:
            pickle.dump({
                "timestamp": datetime.now(),
                "data": data,
                "query": query,
                "params": params
            }, f)
    
    def clear(self):
        """清空缓存"""
        for cache_file in self.cache_dir.glob("*.pkl"):
            cache_file.unlink()
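
The SQL template cache in the list above is not covered by `QueryCache`, which keys on the exact query text. A minimal sketch of one way to derive a template key, assuming that two queries share a template when they differ only in literal values, letter case, and whitespace. `normalize_sql` and `template_key` are hypothetical helper names, and lowercasing the whole statement assumes identifiers are not case-sensitive:

```python
import hashlib
import re


def normalize_sql(sql: str) -> str:
    """Reduce a SQL string to a template: literals become '?', whitespace collapses."""
    # Replace single-quoted string literals ('' is an escaped quote inside a literal)
    template = re.sub(r"'(?:[^']|'')*'", "?", sql)
    # Replace standalone numeric literals, leaving digits inside identifiers (col1) alone
    template = re.sub(r"\b\d+(?:\.\d+)?\b", "?", template)
    # Collapse runs of whitespace and drop a trailing semicolon
    template = re.sub(r"\s+", " ", template).strip().rstrip(";").strip()
    return template.lower()


def template_key(sql: str) -> str:
    """Hash the normalized template so it can be used as a cache key."""
    return hashlib.sha256(normalize_sql(sql).encode("utf-8")).hexdigest()
```

With this key, "revenue for region X in year Y" style queries map to one cached template regardless of the concrete region or year, while structurally different queries get different keys.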

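2. Access Control

Strict access control is mandatory in production:

  • Database connection privileges: connect with a read-only account to rule out accidental writes
  • Sensitive field masking: mask sensitive values before they reach the report
  • Query complexity limits: stop overly complex queries from overloading the system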
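class PermissionManager:
    """Permission manager."""
    
    def __init__(self, rules: List[Dict[str, Any]]):
        self.rules = rules
    
    def check_table_access(self, user: str, table: str) -> bool:
        """Check whether the user may read the table."""
        for rule in self.rules:
            if rule["user"] == user and table in rule["allowed_tables"]:
                return True
        return False
    
    def check_query_complexity(self, query: str) -> Tuple[bool, str]:
        """Reject write statements and overly complex queries.

        This is a heuristic screen; the read-only database account remains
        the actual safeguard against writes.
        """
        upper_query = query.upper()
        
        # Match whole words only, so a column such as updated_at is not flagged
        dangerous_keywords = ["DROP", "DELETE", "TRUNCATE", "UPDATE", "INSERT"]
        for keyword in dangerous_keywords:
            if re.search(rf"\b{keyword}\b", upper_query):
                return False, f"Query contains a dangerous operation: {keyword}"
        
        # Count JOIN clauses
        join_count = len(re.findall(r"\bJOIN\b", upper_query))
        if join_count > 5:
            return False, f"Too many JOINs ({join_count}); performance may suffer"
        
        # Approximate subquery depth by the deepest parenthesis nesting
        # (nested function calls count as well)
        depth = max_depth = 0
        for char in query:
            if char == "(":
                depth += 1
                max_depth = max(max_depth, depth)
            elif char == ")":
                depth -= 1
        if max_depth > 3:
            return False, f"Subquery nesting too deep ({max_depth})"
        
        return True, ""
    
    def mask_sensitive_data(
        self,
        data: pd.DataFrame,
        table: str,
        sensitive_fields: List[str]
    ) -> pd.DataFrame:
        """Mask sensitive columns; `table` is currently unused."""
        def _mask(value: Any) -> Any:
            if pd.isna(value):
                return value
            text_value = str(value)
            # Values of five characters or fewer would be fully exposed
            # by the prefix and suffix, so mask them entirely
            if len(text_value) <= 5:
                return "****"
            # Simple masking: keep the first 3 and last 2 characters
            return text_value[:3] + "****" + text_value[-2:]
        
        masked_data = data.copy()
        for field in sensitive_fields:
            if field in masked_data.columns:
                masked_data[field] = masked_data[field].apply(_mask)
        
        return masked_data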

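3. Error Handling and Graceful Degradation

Thorough error handling keeps the system stable:

  • SQL syntax errors: return a readable error message with a suggested fix
  • Execution timeouts: set a reasonable timeout on every query
  • Unavailable data sources: provide a fallback path and a retry mechanism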
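from sqlalchemy.exc import OperationalError
from tenacity import retry, stop_after_attempt, wait_exponential
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RobustReportGenerator:
    """Report generator with retries and a degraded fallback."""
    
    def __init__(
        self,
        source_manager: DataSourceManager,
        report_agent: ReportGenerationAgent
    ):
        self.source_manager = source_manager
        self.report_agent = report_agent
        self.logger = logger
    
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True  # re-raise the original exception, not tenacity.RetryError
    )
    def execute_query_with_retry(
        self,
        source: DataSource,
        query: str,
        timeout: int = 30
    ) -> pd.DataFrame:
        """Run a query, retrying up to three times with exponential backoff."""
        try:
            engine = self.source_manager.create_connection(source)
            with engine.connect() as conn:
                # PostgreSQL-specific timeout in milliseconds; other databases
                # need their own mechanism
                conn.execute(text(f"SET LOCAL statement_timeout = {int(timeout) * 1000}"))
                return pd.read_sql(text(query), conn)
        except OperationalError as e:
            # PostgreSQL cancels the statement with a "statement timeout" message
            if "statement timeout" in str(e):
                raise TimeoutError(f"Query exceeded {timeout} seconds") from e
            self.logger.error(f"Query execution failed: {e}")
            raise
        except Exception as e:
            self.logger.error(f"Query execution failed: {e}")
            raise
    
    def generate_report_with_fallback(
        self,
        requirement: str,
        source: DataSource
    ) -> Dict[str, Any]:
        """Generate a report, degrading to a simpler one on timeout."""
        try:
            # First attempt: the full report
            return self.report_agent.generate_report(requirement, source)
        except TimeoutError:
            self.logger.warning("Query timed out; retrying with a simplified requirement")
            # Fallback: a simplified report
            simplified_requirement = self._simplify_requirement(requirement)
            return self.report_agent.generate_report(simplified_requirement, source)
        except Exception as e:
            self.logger.error(f"Report generation failed: {e}")
            return {
                "status": "error",
                "message": f"Report generation failed: {e}",
                "suggestion": "Simplify the request or contact an administrator"
            }
    
    def _simplify_requirement(self, requirement: str) -> str:
        """Drop everything after the first "并且" ("and also") connective."""
        # Crude heuristic for Chinese-language requirements; adapt the
        # pattern to the language your users write in
        simplified = re.sub(r'并且.*', '', requirement)
        return simplified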

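Common Pitfalls and Solutions

Pitfall 1: Generated SQL Performs Poorly

Problem: the SQL an Agent generates may not use any index, which leads to full table scans.

Solution:

  • Include index information in the metadata
  • Verify the SQL with an execution-plan analysis
  • Attach optimization suggestions to the generated SQL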
class SQLValidator:
    """SQL validator."""
    
    def __init__(self, source_manager: DataSourceManager):
        self.source_manager = source_manager
    
    def validate_sql_performance(
        self,
        sql: str,
        source: DataSource
    ) -> Dict[str, Any]:
        """Analyze the execution plan of a query (PostgreSQL plan format)."""
        engine = self.source_manager.create_connection(source)
        
        with engine.connect() as conn:
            # Caution: EXPLAIN ANALYZE actually executes the query to collect
            # real timings; use plain EXPLAIN if that cost is unacceptable
            explain_result = conn.execute(text(f"EXPLAIN ANALYZE {sql}"))
            plan = explain_result.fetchall()
            
            plan_text = "\n".join(str(row) for row in plan)
            
            # Collect potential performance problems
            issues = []
            suggestions = []
            
            if "Seq Scan" in plan_text:
                issues.append("Full table scan")
                suggestions.append("Consider adding an index or tightening the WHERE clause")
            
            # A nested loop is fine for small inputs, so treat this as a hint
            if "Nested Loop" in plan_text:
                issues.append("Nested loop join, potentially inefficient")
                suggestions.append("Consider a Hash Join or an index on the join key")
            
            if plan_text.count("Index Scan") > 3:
                issues.append("Multiple index scans")
                suggestions.append("Consider a composite index")
            
            return {
                "execution_plan": plan_text,
                "issues": issues,
                "suggestions": suggestions,
                "is_optimized": len(issues) == 0
            }

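Pitfall 2: Data Quality Problems Go Unnoticed

Problem: a report built on flawed data leads to wrong decisions.

Solution:

  • Run comprehensive data quality checks
  • Flag and handle anomalous records
  • Ship a data quality report alongside the results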
class AdvancedDataQualityChecker:
    """Advanced data quality checker."""
    
    def check_data_completeness(
        self,
        data: pd.DataFrame,
        required_columns: List[str],
        rules: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check the share of non-null values in each required column."""
        total_rows = len(data)
        report = {
            "total_rows": total_rows,
            "completeness": {},
            "issues": []
        }
        
        for column in required_columns:
            if column not in data.columns:
                report["issues"].append(f"Missing required column: {column}")
                continue
            
            # Cast to int so the report stays JSON-serializable
            null_count = int(data[column].isnull().sum())
            # An empty frame counts as 0% complete instead of dividing by zero
            completeness = 1.0 - (null_count / total_rows) if total_rows else 0.0
            threshold = rules.get(column, {}).get("completeness_threshold", 0.95)
            
            report["completeness"][column] = {
                "completeness": completeness,
                "null_count": null_count,
                "threshold": threshold
            }
            
            if completeness < threshold:
                report["issues"].append(
                    f"Column {column} is insufficiently complete: {completeness:.2%} "
                    f"(threshold: {threshold:.2%})"
                )
        
        return report
    
    def detect_data_drift(
        self,
        current_data: pd.DataFrame,
        historical_stats: Dict[str, Dict[str, float]]
    ) -> Dict[str, Any]:
        """Detect drift by comparing column means with historical means."""
        drift_report = {
            "drift_detected": False,
            "drifted_columns": [],
            "details": {}
        }
        
        numeric_columns = current_data.select_dtypes(include=[np.number]).columns
        
        for column in numeric_columns:
            if column not in historical_stats:
                continue
            
            current_mean = float(current_data[column].mean())
            historical_mean = historical_stats[column].get("mean", 0)
            
            # Relative change of the mean; the epsilon guards against a zero baseline
            drift_ratio = abs(current_mean - historical_mean) / (abs(historical_mean) + 1e-6)
            
            drift_report["details"][column] = {
                "current_mean": current_mean,
                "historical_mean": historical_mean,
                "drift_ratio": drift_ratio
            }
            
            if drift_ratio > 0.2:  # mean moved by more than 20%
                drift_report["drift_detected"] = True
                drift_report["drifted_columns"].append(column)
        
        return drift_report
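
`detect_data_drift` takes `historical_stats` as given. A minimal sketch of how such a baseline might be produced from an earlier sample, plus a complementary check that measures the shift in baseline standard deviations rather than as a relative change. `build_baseline_stats` and `mean_shift_in_std` are hypothetical helpers, and the `"std"` entry is an addition that `detect_data_drift` itself does not read:

```python
from typing import Dict

import numpy as np
import pandas as pd


def build_baseline_stats(history: pd.DataFrame) -> Dict[str, Dict[str, float]]:
    """Summarize each numeric column of a historical sample."""
    stats: Dict[str, Dict[str, float]] = {}
    for column in history.select_dtypes(include=[np.number]).columns:
        series = history[column].dropna()
        if series.empty:
            continue  # nothing to summarize for an all-null column
        stats[column] = {
            "mean": float(series.mean()),
            "std": float(series.std(ddof=0)),  # population standard deviation
        }
    return stats


def mean_shift_in_std(
    current: pd.DataFrame,
    baseline: Dict[str, Dict[str, float]],
    column: str,
) -> float:
    """Distance between current and baseline mean, in baseline standard deviations."""
    base = baseline[column]
    shift = abs(float(current[column].mean()) - base["mean"])
    if base["std"] == 0:
        # A constant baseline: any change at all is an unbounded shift
        return 0.0 if shift == 0 else float("inf")
    return shift / base["std"]
```

Measuring the shift in standard deviations avoids a weakness of the relative-change ratio, which becomes very large for columns whose historical mean is close to zero.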

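Pitfall 3: Poor Visualization Choices

Problem: the chart type does not suit the data, for example a pie chart for a time series.

Solution:

  • Recommend visualizations from the data's characteristics and the business scenario
  • Offer several visualization options for the user to choose from
  • Explain why each visualization was recommended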
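
The three points above can be sketched as a small rule-based recommender that returns a primary chart, a reason, and alternatives. This is an illustration under simple assumptions (column dtypes alone drive the choice), not the article's visualization Agent. `recommend_chart` and the `max_bar_categories` threshold are hypothetical:

```python
from typing import Any, Dict

import pandas as pd
from pandas.api import types as ptypes


def recommend_chart(data: pd.DataFrame, max_bar_categories: int = 6) -> Dict[str, Any]:
    """Suggest a chart type from column dtypes, with a reason and alternatives."""
    # Booleans pass is_numeric_dtype, so exclude them explicitly
    numeric = [
        c for c in data.columns
        if ptypes.is_numeric_dtype(data[c]) and not ptypes.is_bool_dtype(data[c])
    ]
    temporal = [c for c in data.columns if ptypes.is_datetime64_any_dtype(data[c])]
    categorical = [c for c in data.columns if c not in numeric and c not in temporal]

    if temporal and numeric:
        return {"chart": "line",
                "reason": "a numeric measure over a time column shows a trend",
                "alternatives": ["area", "bar"]}
    if categorical and numeric:
        n_categories = data[categorical[0]].nunique()
        if n_categories <= max_bar_categories:
            return {"chart": "bar",
                    "reason": f"{n_categories} categories compare well side by side",
                    "alternatives": ["pie"]}
        return {"chart": "horizontal_bar",
                "reason": f"{n_categories} categories need room for labels",
                "alternatives": ["table"]}
    if len(numeric) >= 2:
        return {"chart": "scatter",
                "reason": "two numeric columns can reveal a correlation",
                "alternatives": ["heatmap"]}
    if len(numeric) == 1:
        return {"chart": "histogram",
                "reason": "a single numeric column is best shown as a distribution",
                "alternatives": ["box"]}
    return {"chart": "table",
            "reason": "no numeric measure to plot",
            "alternatives": []}
```

Returning the reason and the alternatives together with the chart type lets the report show why a chart was chosen and lets the user switch to another option.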

Performance Considerations

Query Performance

1. SQL Optimization

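class SQLOptimizer:
    """SQL optimizer."""
    
    def __init__(self, source_manager: DataSourceManager):
        self.source_manager = source_manager
    
    def optimize_query(
        self,
        sql: str,
        source: DataSource
    ) -> Tuple[str, List[str]]:
        """Apply rewrites and report only those that changed the SQL."""
        optimizations = []
        # Drop a trailing semicolon so a LIMIT clause can be appended safely
        optimized_sql = sql.strip().rstrip(";").rstrip()
        
        # Optimization 1: cap the result set with LIMIT
        if not re.search(r"\bLIMIT\b", optimized_sql, re.IGNORECASE):
            optimized_sql += " LIMIT 10000"
            optimizations.append("Added LIMIT to cap the result set")
        
        # Optimization 2: reorder JOINs
        if re.search(r"\bJOIN\b", optimized_sql, re.IGNORECASE):
            reordered_sql = self._optimize_join_order(optimized_sql, source)
            if reordered_sql != optimized_sql:
                optimizations.append("Reordered JOINs")
            optimized_sql = reordered_sql
        
        # Optimization 3: add index hints
        hinted_sql = self._add_index_hints(optimized_sql, source)
        if hinted_sql != optimized_sql:
            optimizations.append("Added index hints")
        optimized_sql = hinted_sql
        
        return optimized_sql, optimizations
    
    def _optimize_join_order(self, sql: str, source: DataSource) -> str:
        """Reorder JOINs (placeholder that returns the SQL unchanged)."""
        # A real implementation needs table statistics
        return sql
    
    def _add_index_hints(self, sql: str, source: DataSource) -> str:
        """Add index hints (placeholder that returns the SQL unchanged)."""
        # A real implementation needs the tables' index information
        return sql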

2. Parallel Query Execution

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)

class ParallelQueryExecutor:
    """Runs several queries concurrently on a thread pool."""
    
    def __init__(self, source_manager: DataSourceManager, max_workers: int = 4):
        # Reuse DataSourceManager so connection settings live in one place
        self.source_manager = source_manager
        self.max_workers = max_workers
    
    def execute_queries_in_parallel(
        self,
        queries: List[Tuple[DataSource, str]]
    ) -> List[Optional[pd.DataFrame]]:
        """Run the queries in parallel; a failed query yields None at its index."""
        results: List[Optional[pd.DataFrame]] = [None] * len(queries)
        
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Map each future to its position so results keep the input order
            futures = {
                executor.submit(self._execute_single_query, source, query): idx
                for idx, (source, query) in enumerate(queries)
            }
            
            for future in as_completed(futures):
                idx = futures[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    logger.error(f"Query {idx} failed: {e}")
                    results[idx] = None
        
        return results
    
    def _execute_single_query(
        self,
        source: DataSource,
        query: str
    ) -> pd.DataFrame:
        """Run one query on its own connection."""
        engine = self.source_manager.create_connection(source)
        with engine.connect() as conn:
            return pd.read_sql(text(query), conn)

Cache Strategy

1. Smart Cache Invalidation

class IntelligentCache:
    """In-memory cache with TTL expiry and usage-based eviction."""
    
    def __init__(self):
        self.cache = {}
        self.access_count = {}
        self.last_access = {}
    
    def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None if it is missing or expired."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        
        now = datetime.now()
        if (now - entry["created_at"]).total_seconds() > entry["ttl"]:
            self._remove(key)
            return None
        
        self.access_count[key] = self.access_count.get(key, 0) + 1
        self.last_access[key] = now
        # Return the stored value, not the bookkeeping wrapper
        return entry["value"]
    
    def set(self, key: str, value: Any, ttl: int = 3600):
        """Store a value with a time-to-live in seconds."""
        self.cache[key] = {
            "value": value,
            "created_at": datetime.now(),
            "ttl": ttl
        }
        self.access_count[key] = 1
        self.last_access[key] = datetime.now()
    
    def _remove(self, key: str):
        """Delete an entry together with its bookkeeping."""
        del self.cache[key]
        del self.access_count[key]
        del self.last_access[key]
    
    def cleanup_expired(self):
        """Remove every expired entry."""
        now = datetime.now()
        expired_keys = [
            key for key, data in self.cache.items()
            if (now - data["created_at"]).total_seconds() > data["ttl"]
        ]
        
        for key in expired_keys:
            self._remove(key)
    
    def evict_lru(self, count: int = 10):
        """Evict the least-used entries.

        Despite the name, entries are ordered by access count first and by
        last access time second, which is closer to LFU than to strict LRU.
        """
        sorted_keys = sorted(
            self.access_count.keys(),
            key=lambda k: (self.access_count[k], self.last_access[k])
        )
        
        for key in sorted_keys[:count]:
            self._remove(key)
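
`evict_lru` above ranks entries by access count before recency. For comparison, here is a minimal sketch of strict least-recently-used eviction combined with a TTL, built on `collections.OrderedDict`. `LRUCacheWithTTL` and its injectable `clock` parameter are hypothetical names introduced for this illustration:

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Optional, Tuple


class LRUCacheWithTTL:
    """Bounded cache that evicts the least recently used entry first."""

    def __init__(self, max_size: int = 128, clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self._clock = clock  # injectable so expiry can be tested without sleeping
        self._entries: "OrderedDict[str, Tuple[Any, float]]" = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if self._clock() >= expires_at:
            del self._entries[key]  # expired entries are dropped lazily on read
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return value

    def set(self, key: str, value: Any, ttl: float = 3600.0) -> None:
        self._entries[key] = (value, self._clock() + ttl)
        self._entries.move_to_end(key)
        # The front of the OrderedDict is always the least recently used entry
        while len(self._entries) > self.max_size:
            self._entries.popitem(last=False)

    def __len__(self) -> int:
        return len(self._entries)
```

Because `set` enforces `max_size` itself, the cache cannot grow without bound even if no cleanup routine is ever called.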

2. Precomputing Popular Reports

class ReportPrecomputer:
    """Precomputes reports for popular queries."""
    
    def __init__(self, report_generator: ReportGenerationAgent):
        self.report_generator = report_generator
        self.precomputed_reports = {}
    
    def precompute_popular_reports(
        self,
        popular_queries: List[str],
        source: DataSource
    ):
        """Generate and store a report for each popular query."""
        for query in popular_queries:
            try:
                report = self.report_generator.generate_report(query, source)
                self.precomputed_reports[query] = {
                    "report": report,
                    "computed_at": datetime.now(),
                    "ttl": 3600  # valid for 1 hour
                }
            except Exception as e:
                # One failing query must not stop the remaining ones
                print(f"Precomputation failed: {query}, error: {e}")
    
    def get_precomputed_report(self, query: str) -> Optional[Dict[str, Any]]:
        """Return a precomputed report if it is still fresh."""
        if query in self.precomputed_reports:
            cached = self.precomputed_reports[query]
            age = (datetime.now() - cached["computed_at"]).total_seconds()
            
            if age < cached["ttl"]:
                return cached["report"]
            else:
                # Expired: drop the entry
                del self.precomputed_reports[query]
        
        return None

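The automated reporting Agent is a key building block of data intelligence: it combines the language understanding of LLMs with traditional data analysis tooling and substantially lowers the barrier to data insight. As the technology matures, data analysis can be expected to become more intelligent, more efficient, and more widely accessible.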