Agent上下文管理与Token成本优化

深入探讨Agent系统的上下文管理策略和Token成本优化技术,包括上下文压缩、长文本处理和成本监控系统

概述与动机

在大语言模型驱动的Agent系统中,上下文管理是影响系统性能、质量和成本的核心因素之一。上下文不仅包含用户当前的输入,还包括历史对话、系统指令、领域知识等多个维度的信息。随着对话轮次的增加和应用场景的复杂化,上下文长度呈线性甚至指数级增长,这不仅增加了Token使用成本,还可能超出模型的上下文窗口限制,导致系统功能失效。

有效的上下文管理需要在信息完整性、性能和成本之间找到最佳平衡点。一方面,保留足够的上下文信息对于维持对话连贯性和理解用户意图至关重要;另一方面,过长的上下文会显著增加推理延迟和成本,甚至影响模型的理解质量。特别是在多轮对话、代码审查、文档分析等复杂场景中,上下文管理变得更加挑战。

Token成本优化是大语言模型应用的经济基础。API调用成本与Token使用量直接相关,优化Token使用不仅能降低运营成本,还能提升系统性能和用户体验。对于处理大量数据的高并发场景,即使是10-20%的Token节省也能带来显著的成本效益。

本文将深入探讨Agent系统的上下文管理和Token成本优化技术,从理论基础到具体实现,提供完整的解决方案和最佳实践,帮助读者构建高效、经济的Agent系统。

核心概念与架构设计

上下文管理架构

上下文管理是一个多层次的系统,需要在不同层面进行协调优化。下图展示了完整的上下文管理架构:

Rendering diagram...

上下文压缩策略

上下文压缩是在保持关键信息的前提下,减少上下文Token数量的核心技术。主要包括:

语义压缩:通过理解上下文的语义,识别和保留关键信息,压缩次要信息。语义压缩可以保持较高的信息密度,但需要复杂的语义理解能力。常用的方法包括关键词提取、句子简化、段落摘要等。

结构化压缩:利用上下文的结构特征进行压缩。例如,对于对话历史,可以保留最近的完整对话,将较早的对话压缩为摘要;对于代码上下文,可以保留核心函数和类定义,压缩注释和示例代码。

层次化压缩:根据信息的重要性进行分级压缩。核心指令和关键信息保持完整,次要信息进行部分压缩,冗余信息完全删除。层次化压缩可以最大化压缩效果,同时保证重要信息的完整性。

增量压缩:基于上下文的变化进行增量压缩。只压缩新添加的内容,保持已有内容的压缩状态。增量压缩可以减少重复压缩的计算开销。

上下文窗口管理策略

上下文窗口管理是指在有限的空间内有效管理和使用上下文的技术:

滑动窗口策略:保留最近N轮的完整对话,更早的对话被移除或压缩。滑动窗口简单有效,但可能丢失早期的重要信息。可以通过重要性评分优化窗口内容的选择。

动态窗口策略:根据对话内容和复杂度动态调整窗口大小。对于简单对话使用小窗口,复杂对话使用大窗口。动态策略可以更好地平衡性能和信息完整性。

分层窗口策略:将上下文分为不同层级,最近的信息保留完整,较早的信息进行不同程度的压缩。分层策略可以在有限空间内保留更多历史信息。

重要性驱动策略:基于信息重要性动态管理窗口内容。保留重要性高的信息,压缩或移除重要性低的信息。重要性可以通过语义相关性、用户反馈、任务相关性等多个维度评估。

成本监控系统

成本监控系统是优化Token使用的关键工具,主要功能包括:

实时监控:实时跟踪Token使用情况,包括输入Token、输出Token、总Token数等指标。监控数据可以用于及时发现问题并进行调整。

成本预算:为不同的用户、任务、时间周期设置成本预算,防止成本超支。预算控制可以在多个层次实现,包括请求级别、会话级别、用户级别等。

异常检测:检测异常的Token使用模式,如突然增加的Token消耗、异常的请求频率等。异常检测可以帮助发现系统问题或恶意使用。

成本分析:分析Token使用的模式和趋势,识别成本优化的机会。分析可以按照不同维度进行,如用户、任务、时间段、模型类型等。

关键技术实现

智能上下文管理器实现

下面是一个完整的智能上下文管理器实现,支持多种压缩策略和窗口管理:

import heapq
import time
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, field
from enum import Enum
import re
import tiktoken
from collections import deque
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

class ContextPriority(Enum):
    """上下文优先级"""
    CRITICAL = "critical"      # 关键信息,必须保留
    HIGH = "high"              # 重要信息,尽量保留
    MEDIUM = "medium"          # 一般信息,可以压缩
    LOW = "low"                # 次要信息,可以删除

@dataclass(order=True)
class ContextItem:
    """上下文项"""
    priority: int
    content: str
    timestamp: float = field(default_factory=time.time)
    importance_score: float = 0.0
    metadata: Dict[str, Any] = field(default_factory=dict)
    compressed: bool = False
    original_length: int = 0

class ContextCompression:
    """上下文压缩工具"""
    
    def __init__(self, model_encoding: str = "cl100k_base"):
        """
        初始化上下文压缩器
        
        Args:
            model_encoding: 模型编码器
        """
        self.encoding = tiktoken.get_encoding(model_encoding)
        self.tfidf_vectorizer = TfidfVectorizer(max_features=100)
    
    def count_tokens(self, text: str) -> int:
        """计算Token数量"""
        return len(self.encoding.encode(text))
    
    def extract_keywords(self, text: str, top_k: int = 5) -> List[str]:
        """
        提取关键词
        
        Args:
            text: 输入文本
            top_k: 返回的关键词数量
            
        Returns:
            关键词列表
        """
        try:
            # 使用TF-IDF提取关键词
            tfidf_matrix = self.tfidf_vectorizer.fit_transform([text])
            feature_names = self.tfidf_vectorizer.get_feature_names_out()
            
            # 获取TF-IDF分数最高的词
            scores = tfidf_matrix.toarray()[0]
            top_indices = np.argsort(scores)[-top_k:][::-1]
            
            keywords = [feature_names[i] for i in top_indices if scores[i] > 0]
            return keywords
        except Exception as e:
            # 如果提取失败,返回简单的关键词
            words = re.findall(r'\b\w+\b', text.lower())
            word_freq = {}
            for word in words:
                if len(word) > 2:  # 过滤短词
                    word_freq[word] = word_freq.get(word, 0) + 1
            
            sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
            return [word for word, count in sorted_words[:top_k]]
    
    def summarize_text(self, text: str, max_tokens: int = 100) -> str:
        """
        摘要文本
        
        Args:
            text: 输入文本
            max_tokens: 最大Token数
            
        Returns:
            摘要文本
        """
        # 简单的摘要策略:提取关键句子
        sentences = re.split(r'[。!?.!?]', text)
        sentences = [s.strip() for s in sentences if s.strip()]
        
        if not sentences:
            return text[:50]  # 返回前50个字符
        
        # 计算每个句子的分数
        sentence_scores = []
        for sentence in sentences:
            # 分数基于长度和关键词密度
            keywords = self.extract_keywords(sentence, top_k=3)
            score = len(keywords) + len(sentence.split())
            sentence_scores.append((score, sentence))
        
        # 选择得分最高的句子
        sentence_scores.sort(reverse=True)
        selected_sentences = [sentence for _, sentence in sentence_scores[:3]]
        
        summary = '。'.join(selected_sentences)
        
        # 如果摘要仍然过长,截断
        while self.count_tokens(summary) > max_tokens and len(summary) > 10:
            summary = summary[:-10]
        
        return summary
    
    def compress_code(self, code: str, max_tokens: int = 200) -> str:
        """
        压缩代码
        
        Args:
            code: 代码文本
            max_tokens: 最大Token数
            
        Returns:
            压缩后的代码
        """
        # 保留核心结构:函数定义、类定义、关键变量
        lines = code.split('\n')
        compressed_lines = []
        
        # 识别重要的代码行
        for line in lines:
            stripped = line.strip()
            # 保留定义语句
            if any(keyword in stripped for keyword in ['def ', 'class ', 'import ', 'from ']):
                compressed_lines.append(line)
            # 保留关键语句(简单启发式)
            elif stripped and not stripped.startswith('#') and not stripped.startswith('//'):
                # 保留变量赋值和返回语句
                if '=' in stripped or 'return' in stripped or 'if ' in stripped:
                    compressed_lines.append(line)
        
        compressed_code = '\n'.join(compressed_lines)
        
        # 如果仍然过长,截断
        while self.count_tokens(compressed_code) > max_tokens and len(compressed_code) > 20:
            compressed_code = compressed_code[:-20]
        
        return compressed_code
    
    def compress_dialog(self, dialog_history: List[Dict[str, str]], max_tokens: int = 150) -> str:
        """
        压缩对话历史
        
        Args:
            dialog_history: 对话历史列表
            max_tokens: 最大Token数
            
        Returns:
            压缩后的对话摘要
        """
        if not dialog_history:
            return ""
        
        # 提取对话中的关键信息
        key_points = []
        
        for turn in dialog_history:
            role = turn.get('role', 'user')
            content = turn.get('content', '')
            
            if role == 'user':
                # 用户问题:提取关键词
                keywords = self.extract_keywords(content, top_k=3)
                if keywords:
                    key_points.append(f"用户询问: {', '.join(keywords)}")
            elif role == 'assistant':
                # 助手回答:提取关键信息
                keywords = self.extract_keywords(content, top_k=3)
                if keywords:
                    key_points.append(f"助手回答: {', '.join(keywords)}")
        
        summary = '; '.join(key_points)
        
        # 如果摘要过长,截断
        while self.count_tokens(summary) > max_tokens and len(summary) > 10:
            summary = summary[:-10]
        
        return summary

class ContextManager:
    """智能上下文管理器"""
    
    def __init__(
        self,
        max_context_tokens: int = 4000,
        window_strategy: str = "sliding",
        window_size: int = 5,
        importance_threshold: float = 0.5
    ):
        """
        初始化上下文管理器
        
        Args:
            max_context_tokens: 最大上下文Token数
            window_strategy: 窗口策略
            window_size: 窗口大小
            importance_threshold: 重要性阈值
        """
        self.max_context_tokens = max_context_tokens
        self.window_strategy = window_strategy
        self.window_size = window_size
        self.importance_threshold = importance_threshold
        
        self.compression = ContextCompression()
        
        # 上下文存储
        self.context_queue: deque = deque(maxlen=window_size * 2)
        self.priority_queue: List[ContextItem] = []
        
        # 统计信息
        self.stats = {
            'total_tokens': 0,
            'compressed_tokens': 0,
            'compression_ratio': 0.0,
            'total_items': 0,
            'compressed_items': 0
        }
    
    def _calculate_importance(self, content: str, metadata: Dict[str, Any]) -> float:
        """
        计算内容重要性
        
        Args:
            content: 内容文本
            metadata: 元数据
            
        Returns:
            重要性分数
        """
        # 基础分数基于内容长度和关键词密度
        keywords = self.compression.extract_keywords(content, top_k=5)
        keyword_density = len(keywords) / max(len(content.split()), 1)
        
        importance = keyword_density * 0.3
        
        # 根据元数据调整重要性
        if metadata.get('is_instruction', False):
            importance += 0.5  # 指令更重要
        if metadata.get('is_critical', False):
            importance += 0.3  # 关键信息更重要
        if metadata.get('user_feedback') == 'positive':
            importance += 0.2  # 用户正面反馈增加重要性
        
        # 时间衰减
        timestamp = metadata.get('timestamp', time.time())
        age = time.time() - timestamp
        time_decay = max(0, 1 - age / 3600)  # 1小时内不衰减
        importance *= time_decay
        
        return min(importance, 1.0)
    
    def _assign_priority(self, importance: float) -> ContextPriority:
        """根据重要性分配优先级"""
        if importance >= 0.8:
            return ContextPriority.CRITICAL
        elif importance >= 0.6:
            return ContextPriority.HIGH
        elif importance >= 0.4:
            return ContextPriority.MEDIUM
        else:
            return ContextPriority.LOW
    
    def add_context(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        添加上下文
        
        Args:
            content: 上下文内容
            metadata: 元数据
            
        Returns:
            当前上下文Token数
        """
        if metadata is None:
            metadata = {}
        
        metadata['timestamp'] = time.time()
        
        # 计算重要性
        importance = self._calculate_importance(content, metadata)
        priority = self._assign_priority(importance)
        
        # 创建上下文项
        priority_value = {
            ContextPriority.CRITICAL: 0,
            ContextPriority.HIGH: 1,
            ContextPriority.MEDIUM: 2,
            ContextPriority.LOW: 3
        }[priority]
        
        context_item = ContextItem(
            priority=priority_value,
            content=content,
            timestamp=metadata.get('timestamp'),
            importance_score=importance,
            metadata=metadata,
            original_length=self.compression.count_tokens(content)
        )
        
        # 添加到优先级队列
        heapq.heappush(self.priority_queue, context_item)
        
        # 添加到滑动窗口
        self.context_queue.append(context_item)
        
        # 更新统计信息
        self.stats['total_items'] += 1
        
        # 执行窗口管理
        self._manage_window()
        
        return self.get_total_tokens()
    
    def _manage_window(self):
        """管理上下文窗口"""
        total_tokens = self.get_total_tokens()
        
        # 如果超过限制,压缩或移除低优先级内容
        while total_tokens > self.max_context_tokens and self.priority_queue:
            # 获取最低优先级的项
            item = heapq.heappop(self.priority_queue)
            
            if item.compressed:
                # 已压缩的项直接删除
                if item in self.context_queue:
                    self.context_queue.remove(item)
            elif item.priority == 3:  # LOW priority
                # 低优先级项直接删除
                if item in self.context_queue:
                    self.context_queue.remove(item)
                    total_tokens -= item.original_length
            else:
                # 高优先级项进行压缩
                compressed_content = self._compress_item(item)
                if compressed_content:
                    compressed_length = self.compression.count_tokens(compressed_content)
                    saved_tokens = item.original_length - compressed_length
                    
                    # 更新项
                    item.content = compressed_content
                    item.compressed = True
                    item.original_length = compressed_length
                    
                    # 更新统计
                    self.stats['compressed_items'] += 1
                    self.stats['compressed_tokens'] += saved_tokens
                    
                    total_tokens -= saved_tokens
                    
                    # 如果压缩后仍然有用,放回队列
                    if item.priority < 3:  # 不是LOW优先级
                        heapq.heappush(self.priority_queue, item)
    
    def _compress_item(self, item: ContextItem) -> Optional[str]:
        """
        压缩单个上下文项
        
        Args:
            item: 上下文项
            
        Returns:
            压缩后的内容或None
        """
        content_type = item.metadata.get('type', 'text')
        
        if content_type == 'code':
            return self.compression.compress_code(item.content)
        elif content_type == 'dialog':
            return self.compression.compress_dialog(item.metadata.get('dialog_history', []))
        elif content_type == 'instruction':
            # 指令通常很短,不需要压缩
            return item.content
        else:
            # 默认使用摘要压缩
            return self.compression.summarize_text(item.content, max_tokens=50)
    
    def get_context(
        self,
        max_tokens: Optional[int] = None,
        include_metadata: bool = False
    ) -> str:
        """
        获取当前上下文
        
        Args:
            max_tokens: 最大Token数
            include_metadata: 是否包含元数据
            
        Returns:
            上下文文本
        """
        if max_tokens is None:
            max_tokens = self.max_context_tokens
        
        # 按优先级排序
        sorted_items = sorted(self.priority_queue, key=lambda x: x.priority)
        
        context_parts = []
        current_tokens = 0
        
        for item in sorted_items:
            item_tokens = self.compression.count_tokens(item.content)
            
            if current_tokens + item_tokens <= max_tokens:
                if include_metadata:
                    metadata_str = f"[{item.metadata.get('type', 'context')}] "
                    context_parts.append(metadata_str + item.content)
                else:
                    context_parts.append(item.content)
                current_tokens += item_tokens
            else:
                break
        
        context = '\n\n'.join(context_parts)
        return context
    
    def get_total_tokens(self) -> int:
        """获取当前上下文总Token数"""
        return sum(
            self.compression.count_tokens(item.content)
            for item in self.priority_queue
        )
    
    def get_stats(self) -> Dict[str, Any]:
        """获取统计信息"""
        total_tokens = self.get_total_tokens()
        if self.stats['total_tokens'] > 0:
            self.stats['compression_ratio'] = (
                self.stats['compressed_tokens'] / self.stats['total_tokens']
            )
        
        return {
            **self.stats,
            'current_tokens': total_tokens,
            'context_queue_size': len(self.context_queue),
            'priority_queue_size': len(self.priority_queue),
            'compression_ratio': self.stats['compression_ratio']
        }
    
    def clear(self):
        """清空上下文"""
        self.context_queue.clear()
        self.priority_queue.clear()
        self.stats = {
            'total_tokens': 0,
            'compressed_tokens': 0,
            'compression_ratio': 0.0,
            'total_items': 0,
            'compressed_items': 0
        }

# 使用示例
def demonstrate_context_manager():
    """演示上下文管理器的使用"""
    print("=== 智能上下文管理器演示 ===\n")
    
    # 初始化上下文管理器
    manager = ContextManager(
        max_context_tokens=1000,
        window_strategy="sliding",
        window_size=3
    )
    
    # 添加不同类型的上下文
    print("添加上下文项:")
    
    # 1. 系统指令
    instruction = "你是一个专业的代码审查助手,请帮助用户检查代码质量、发现潜在问题并提供改进建议。"
    manager.add_context(
        instruction,
        metadata={'type': 'instruction', 'is_critical': True}
    )
    print(f"1. 添加系统指令 ({manager.compression.count_tokens(instruction)} tokens)")
    
    # 2. 用户问题
    question = "请帮我审查这段Python代码,看看有没有潜在的性能问题和安全漏洞。"
    manager.add_context(
        question,
        metadata={'type': 'text', 'user_feedback': 'positive'}
    )
    print(f"2. 添加用户问题 ({manager.compression.count_tokens(question)} tokens)")
    
    # 3. 代码上下文
    code = """
def process_user_data(user_id):
    # 查询用户数据
    user = database.query(f"SELECT * FROM users WHERE id = {user_id}")
    
    # 处理用户数据
    if user:
        result = {
            'name': user['name'],
            'email': user['email'],
            'created_at': user['created_at']
        }
        return result
    else:
        return None
"""
    manager.add_context(
        code,
        metadata={'type': 'code', 'is_critical': True}
    )
    print(f"3. 添加代码上下文 ({manager.compression.count_tokens(code)} tokens)")
    
    # 4. 长文档上下文
    document = "这是一个关于Web应用安全的长文档,包含了很多关于SQL注入、XSS攻击、CSRF防护等安全主题的详细说明。在实际开发中,我们需要特别注意这些安全问题,并采用相应的防护措施。例如,对于用户输入要进行严格的验证和过滤,避免SQL注入攻击。对于输出到页面的内容要进行适当的编码,防止XSS攻击。对于跨域请求要使用CSRF令牌进行保护..."
    manager.add_context(
        document,
        metadata={'type': 'text'}
    )
    print(f"4. 添加文档上下文 ({manager.compression.count_tokens(document)} tokens)")
    
    # 5. 对话历史
    dialog_history = [
        {'role': 'user', 'content': '什么是SQL注入?'},
        {'role': 'assistant', 'content': 'SQL注入是一种常见的Web安全漏洞,攻击者通过在输入字段中插入恶意SQL代码来操纵数据库查询。'},
        {'role': 'user', 'content': '如何防止SQL注入?'},
        {'role': 'assistant', 'content': '防止SQL注入的主要方法包括:1. 使用参数化查询或预编译语句;2. 对用户输入进行严格验证和过滤;3. 使用ORM框架;4. 最小权限原则。'}
    ]
    manager.add_context(
        "对话历史",
        metadata={'type': 'dialog', 'dialog_history': dialog_history}
    )
    print(f"5. 添加对话历史 (approx 200 tokens)")
    
    # 显示当前状态
    print(f"\n当前上下文状态:")
    print(f"总Token数: {manager.get_total_tokens()}")
    print(f"上下文项数量: {len(manager.priority_queue)}")
    
    # 显示统计信息
    print(f"\n管理器统计信息:")
    stats = manager.get_stats()
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"{key}: {value:.4f}")
        else:
            print(f"{key}: {value}")
    
    # 获取优化后的上下文
    print(f"\n优化后的上下文 ({manager.get_total_tokens()} tokens):")
    print("=" * 80)
    print(manager.get_context())
    print("=" * 80)

if __name__ == "__main__":
    demonstrate_context_manager()

成本监控系统实现

下面是一个完整的成本监控系统实现,支持实时监控、预算控制和异常检测:

import time
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from collections import defaultdict
from enum import Enum
import sqlite3
from datetime import datetime, timedelta
import json

class CostAlertLevel(Enum):
    """成本告警级别"""
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"

@dataclass
class CostRecord:
    """成本记录"""
    timestamp: float
    user_id: str
    task_id: str
    model_name: str
    input_tokens: int
    output_tokens: int
    total_tokens: int
    cost: float
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class BudgetRule:
    """预算规则"""
    name: str
    budget_limit: float
    period: str  # hourly, daily, weekly, monthly
    scope: str   # user, task, global
    scope_id: Optional[str] = None
    alert_threshold: float = 0.8  # 告警阈值(百分比)

class CostMonitoringSystem:
    """成本监控系统"""
    
    def __init__(
        self,
        db_path: str = "cost_monitoring.db",
        model_costs: Optional[Dict[str, float]] = None
    ):
        """
        初始化成本监控系统
        
        Args:
            db_path: 数据库路径
            model_costs: 模型Token成本配置(每1000 tokens的价格)
        """
        self.db_path = db_path
        self.model_costs = model_costs or {
            'gpt-4': {'input': 0.03, 'output': 0.06},
            'gpt-3.5-turbo': {'input': 0.0015, 'output': 0.002},
            'claude-3-opus': {'input': 0.015, 'output': 0.075},
            'claude-3-sonnet': {'input': 0.003, 'output': 0.015}
        }
        
        # 预算规则
        self.budget_rules: List[BudgetRule] = []
        
        # 告警回调
        self.alert_callbacks: Dict[CostAlertLevel, List[Callable]] = {
            CostAlertLevel.INFO: [],
            CostAlertLevel.WARNING: [],
            CostAlertLevel.CRITICAL: []
        }
        
        # 初始化数据库
        self._init_database()
    
    def _init_database(self):
        """初始化数据库"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        # 创建成本记录表
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS cost_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp REAL NOT NULL,
                user_id TEXT NOT NULL,
                task_id TEXT NOT NULL,
                model_name TEXT NOT NULL,
                input_tokens INTEGER NOT NULL,
                output_tokens INTEGER NOT NULL,
                total_tokens INTEGER NOT NULL,
                cost REAL NOT NULL,
                metadata TEXT
            )
        """)
        
        # 创建预算规则表
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS budget_rules (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                budget_limit REAL NOT NULL,
                period TEXT NOT NULL,
                scope TEXT NOT NULL,
                scope_id TEXT,
                alert_threshold REAL DEFAULT 0.8,
                created_at REAL NOT NULL
            )
        """)
        
        # 创建索引
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_timestamp 
            ON cost_records(timestamp)
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_user_id 
            ON cost_records(user_id)
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_task_id 
            ON cost_records(task_id)
        """)
        
        conn.commit()
        conn.close()
    
    def calculate_cost(
        self,
        model_name: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """
        计算成本
        
        Args:
            model_name: 模型名称
            input_tokens: 输入Token数
            output_tokens: 输出Token数
            
        Returns:
            成本(美元)
        """
        if model_name not in self.model_costs:
            raise ValueError(f"Unknown model: {model_name}")
        
        model_cost = self.model_costs[model_name]
        input_cost = (input_tokens / 1000) * model_cost['input']
        output_cost = (output_tokens / 1000) * model_cost['output']
        
        return input_cost + output_cost
    
    def record_cost(
        self,
        user_id: str,
        task_id: str,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        metadata: Optional[Dict[str, Any]] = None
    ) -> CostRecord:
        """
        记录成本
        
        Args:
            user_id: 用户ID
            task_id: 任务ID
            model_name: 模型名称
            input_tokens: 输入Token数
            output_tokens: 输出Token数
            metadata: 元数据
            
        Returns:
            成本记录
        """
        # 计算成本
        cost = self.calculate_cost(model_name, input_tokens, output_tokens)
        total_tokens = input_tokens + output_tokens
        
        # 创建成本记录
        record = CostRecord(
            timestamp=time.time(),
            user_id=user_id,
            task_id=task_id,
            model_name=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            cost=cost,
            metadata=metadata or {}
        )
        
        # 存储到数据库
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO cost_records 
            (timestamp, user_id, task_id, model_name, input_tokens, output_tokens, total_tokens, cost, metadata)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            record.timestamp,
            record.user_id,
            record.task_id,
            record.model_name,
            record.input_tokens,
            record.output_tokens,
            record.total_tokens,
            record.cost,
            json.dumps(record.metadata)
        ))
        conn.commit()
        conn.close()
        
        # 检查预算
        self._check_budgets(record)
        
        return record
    
    def add_budget_rule(
        self,
        name: str,
        budget_limit: float,
        period: str,
        scope: str,
        scope_id: Optional[str] = None,
        alert_threshold: float = 0.8
    ):
        """
        添加预算规则
        
        Args:
            name: 规则名称
            budget_limit: 预算限制
            period: 周期
            scope: 作用域
            scope_id: 作用域ID
            alert_threshold: 告警阈值
        """
        rule = BudgetRule(
            name=name,
            budget_limit=budget_limit,
            period=period,
            scope=scope,
            scope_id=scope_id,
            alert_threshold=alert_threshold
        )
        
        self.budget_rules.append(rule)
        
        # 存储到数据库
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO budget_rules 
            (name, budget_limit, period, scope, scope_id, alert_threshold, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            rule.name,
            rule.budget_limit,
            rule.period,
            rule.scope,
            rule.scope_id,
            rule.alert_threshold,
            time.time()
        ))
        conn.commit()
        conn.close()
    
    def _check_budgets(self, record: CostRecord):
        """
        检查预算
        
        Args:
            record: 成本记录
        """
        for rule in self.budget_rules:
            # 检查作用域
            if rule.scope == 'user' and rule.scope_id != record.user_id:
                continue
            if rule.scope == 'task' and rule.scope_id != record.task_id:
                continue
            
            # 计算周期内的总成本
            period_cost = self._get_period_cost(rule, record)
            
            # 检查是否超过告警阈值
            usage_ratio = period_cost / rule.budget_limit
            
            if usage_ratio >= 1.0:
                # 超过预算,发送严重告警
                self._send_alert(
                    CostAlertLevel.CRITICAL,
                    f"预算超支: {rule.name} 使用了 ${period_cost:.2f} / ${rule.budget_limit:.2f}",
                    rule,
                    record
                )
            elif usage_ratio >= rule.alert_threshold:
                # 超过告警阈值,发送警告
                self._send_alert(
                    CostAlertLevel.WARNING,
                    f"预算警告: {rule.name} 已使用 {usage_ratio:.1%}",
                    rule,
                    record
                )
    
    def _get_period_cost(self, rule: BudgetRule, record: CostRecord) -> float:
        """
        获取周期内的成本
        
        Args:
            rule: 预算规则
            record: 成本记录
            
        Returns:
            周期成本
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        # 计算时间范围
        current_time = time.time()
        time_delta = {
            'hourly': timedelta(hours=1),
            'daily': timedelta(days=1),
            'weekly': timedelta(weeks=1),
            'monthly': timedelta(days=30)
        }.get(rule.period, timedelta(days=1))
        
        start_time = current_time - time_delta.total_seconds()
        
        # 构建查询条件
        query = """
            SELECT SUM(cost) FROM cost_records
            WHERE timestamp >= ? AND timestamp <= ?
        """
        params = [start_time, current_time]
        
        if rule.scope == 'user' and rule.scope_id:
            query += " AND user_id = ?"
            params.append(rule.scope_id)
        elif rule.scope == 'task' and rule.scope_id:
            query += " AND task_id = ?"
            params.append(rule.scope_id)
        
        cursor.execute(query, params)
        result = cursor.fetchone()
        
        conn.close()
        
        return result[0] if result[0] else 0.0
    
    def _send_alert(
        self,
        level: CostAlertLevel,
        message: str,
        rule: BudgetRule,
        record: CostRecord
    ):
        """
        发送告警
        
        Args:
            level: 告警级别
            message: 告警消息
            rule: 预算规则
            record: 成本记录
        """
        print(f"[{level.value.upper()}] {message}")
        
        # 调用注册的回调函数
        for callback in self.alert_callbacks[level]:
            try:
                callback(message, rule, record)
            except Exception as e:
                print(f"告警回调执行失败: {e}")
    
    def register_alert_callback(
        self,
        level: CostAlertLevel,
        callback: Callable
    ):
        """
        注册告警回调
        
        Args:
            level: 告警级别
            callback: 回调函数
        """
        self.alert_callbacks[level].append(callback)
    
    def get_cost_stats(
        self,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        user_id: Optional[str] = None,
        task_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        获取成本统计
        
        Args:
            start_time: 开始时间
            end_time: 结束时间
            user_id: 用户ID
            task_id: 任务ID
            
        Returns:
            统计信息
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        # 设置默认时间范围
        if end_time is None:
            end_time = time.time()
        if start_time is None:
            start_time = end_time - 86400  # 默认24小时
        
        # 构建查询
        query = """
            SELECT 
                COUNT(*) as total_requests,
                SUM(total_tokens) as total_tokens,
                SUM(input_tokens) as input_tokens,
                SUM(output_tokens) as output_tokens,
                SUM(cost) as total_cost,
                AVG(cost) as avg_cost_per_request,
                AVG(total_tokens) as avg_tokens_per_request
            FROM cost_records
            WHERE timestamp >= ? AND timestamp <= ?
        """
        params = [start_time, end_time]
        
        if user_id:
            query += " AND user_id = ?"
            params.append(user_id)
        
        if task_id:
            query += " AND task_id = ?"
            params.append(task_id)
        
        cursor.execute(query, params)
        result = cursor.fetchone()
        
        # 按模型统计
        model_query = """
            SELECT model_name, COUNT(*) as requests, SUM(cost) as cost
            FROM cost_records
            WHERE timestamp >= ? AND timestamp <= ?
            GROUP BY model_name
            ORDER BY cost DESC
        """
        cursor.execute(model_query, [start_time, end_time])
        model_stats = cursor.fetchall()
        
        conn.close()
        
        return {
            'total_requests': result[0] or 0,
            'total_tokens': result[1] or 0,
            'input_tokens': result[2] or 0,
            'output_tokens': result[3] or 0,
            'total_cost': result[4] or 0.0,
            'avg_cost_per_request': result[5] or 0.0,
            'avg_tokens_per_request': result[6] or 0,
            'model_stats': [
                {
                    'model': row[0],
                    'requests': row[1],
                    'cost': row[2]
                }
                for row in model_stats
            ]
        }
    
    def detect_anomalies(
        self,
        time_window: int = 3600,
        threshold_multiplier: float = 3.0
    ) -> List[Dict[str, Any]]:
        """
        检测异常
        
        Args:
            time_window: 时间窗口(秒)
            threshold_multiplier: 阈值倍数
            
        Returns:
            异常列表
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        current_time = time.time()
        start_time = current_time - time_window
        
        # 获取当前时间窗口的统计数据
        cursor.execute("""
            SELECT 
                COUNT(*) as requests,
                SUM(total_tokens) as tokens,
                SUM(cost) as cost
            FROM cost_records
            WHERE timestamp >= ? AND timestamp <= ?
        """, [start_time, current_time])
        
        current_stats = cursor.fetchone()
        current_requests = current_stats[0] or 0
        current_tokens = current_stats[1] or 0
        current_cost = current_stats[2] or 0.0
        
        # 获取历史统计数据
        cursor.execute("""
            SELECT 
                AVG(requests) as avg_requests,
                STDDEV(requests) as std_requests,
                AVG(tokens) as avg_tokens,
                STDDEV(tokens) as std_tokens,
                AVG(cost) as avg_cost,
                STDDEV(cost) as std_cost
            FROM (
                SELECT 
                    COUNT(*) as requests,
                    SUM(total_tokens) as tokens,
                    SUM(cost) as cost
                FROM cost_records
                WHERE timestamp >= ? AND timestamp < ?
                GROUP BY 
                    CAST(timestamp / 300 AS INTEGER)  -- 5分钟分组
            )
        """, [start_time - time_window * 10, start_time])
        
        historical_stats = cursor.fetchone()
        avg_requests = historical_stats[0] or 0
        std_requests = historical_stats[1] or 0
        avg_tokens = historical_stats[2] or 0
        std_tokens = historical_stats[3] or 0
        avg_cost = historical_stats[4] or 0.0
        std_cost = historical_stats[5] or 0.0
        
        conn.close()
        
        # 检测异常
        anomalies = []
        
        if std_requests > 0:
            request_z_score = (current_requests - avg_requests) / std_requests
            if abs(request_z_score) > threshold_multiplier:
                anomalies.append({
                    'type': 'request_volume',
                    'current_value': current_requests,
                    'historical_avg': avg_requests,
                    'z_score': request_z_score,
                    'severity': 'high' if abs(request_z_score) > threshold_multiplier * 1.5 else 'medium'
                })
        
        if std_cost > 0:
            cost_z_score = (current_cost - avg_cost) / std_cost
            if abs(cost_z_score) > threshold_multiplier:
                anomalies.append({
                    'type': 'cost_surge',
                    'current_value': current_cost,
                    'historical_avg': avg_cost,
                    'z_score': cost_z_score,
                    'severity': 'high' if abs(cost_z_score) > threshold_multiplier * 1.5 else 'medium'
                })
        
        return anomalies

# 使用示例
def demonstrate_cost_monitoring():
    """演示成本监控系统的使用"""
    print("=== 成本监控系统演示 ===\n")
    
    # 初始化成本监控系统
    cost_monitor = CostMonitoringSystem(db_path=":memory:")  # 使用内存数据库
    
    # 添加预算规则
    cost_monitor.add_budget_rule(
        name="用户user001日预算",
        budget_limit=10.0,
        period="daily",
        scope="user",
        scope_id="user001",
        alert_threshold=0.8
    )
    
    cost_monitor.add_budget_rule(
        name="全局小时预算",
        budget_limit=5.0,
        period="hourly",
        scope="global"
    )
    
    # 注册告警回调
    def alert_callback(message, rule, record):
        print(f"  告警: {message}")
        print(f"  规则: {rule.name}")
        print(f"  记录: 用户={record.user_id}, 任务={record.task_id}, 成本=${record.cost:.4f}")
    
    cost_monitor.register_alert_callback(CostAlertLevel.WARNING, alert_callback)
    cost_monitor.register_alert_callback(CostAlertLevel.CRITICAL, alert_callback)
    
    print("模拟API调用记录成本:")
    print("-" * 80)
    
    # 模拟一些API调用
    test_calls = [
        ("user001", "task001", "gpt-4", 500, 300),
        ("user001", "task001", "gpt-4", 800, 500),
        ("user001", "task002", "gpt-3.5-turbo", 300, 200),
        ("user002", "task003", "gpt-3.5-turbo", 400, 300),
        ("user001", "task001", "gpt-4", 1200, 800),  # 这会触发预算告警
        ("user002", "task004", "claude-3-sonnet", 600, 400),
        ("user001", "task002", "gpt-4", 1500, 1000),  # 这会触发预算超支
    ]
    
    for i, (user_id, task_id, model, input_tokens, output_tokens) in enumerate(test_calls, 1):
        print(f"\n{i.} {user_id} 调用 {model}:")
        print(f"   输入Token: {input_tokens}, 输出Token: {output_tokens}")
        
        record = cost_monitor.record_cost(
            user_id=user_id,
            task_id=task_id,
            model_name=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            metadata={'operation': f'test_call_{i}'}
        )
        
        print(f"   成本: ${record.cost:.4f}, 总Token: {record.total_tokens}")
    
    # 显示成本统计
    print(f"\n{'='*80}")
    print("成本统计信息:")
    print("="*80)
    
    stats = cost_monitor.get_cost_stats()
    print(f"总请求数: {stats['total_requests']}")
    print(f"总Token数: {stats['total_tokens']}")
    print(f"输入Token: {stats['input_tokens']}")
    print(f"输出Token: {stats['output_tokens']}")
    print(f"总成本: ${stats['total_cost']:.4f}")
    print(f"平均成本/请求: ${stats['avg_cost_per_request']:.4f}")
    print(f"平均Token/请求: {stats['avg_tokens_per_request']:.1f}")
    
    print(f"\n按模型统计:")
    for model_stat in stats['model_stats']:
        print(f"  {model_stat['model']}: {model_stat['requests']} 次请求, ${model_stat['cost']:.4f}")
    
    # 检测异常
    print(f"\n{'='*80}")
    print("异常检测结果:")
    print("="*80)
    
    anomalies = cost_monitor.detect_anomalies()
    if anomalies:
        for anomaly in anomalies:
            print(f"异常类型: {anomaly['type']}")
            print(f"  当前值: {anomaly['current_value']:.2f}")
            print(f"  历史平均值: {anomaly['historical_avg']:.2f}")
            print(f"  Z分数: {anomaly['z_score']:.2f}")
            print(f"  严重程度: {anomaly['severity']}")
    else:
        print("未检测到异常")

if __name__ == "__main__":
    demonstrate_cost_monitoring()

长文本处理工具实现

下面是一个长文本处理工具的实现,支持智能分段、语义分析和增量处理:

import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
import tiktoken

@dataclass
class TextSegment:
    """文本分段"""
    content: str
    start_pos: int
    end_pos: int
    tokens: int
    importance: float = 0.0
    cluster_id: int = -1

class LongTextProcessor:
    """长文本处理器"""
    
    def __init__(
        self,
        max_segment_tokens: int = 1000,
        overlap_ratio: float = 0.1,
        encoding: str = "cl100k_base"
    ):
        """
        初始化长文本处理器
        
        Args:
            max_segment_tokens: 最大分段Token数
            overlap_ratio: 重叠比例
            encoding: 编码器
        """
        self.max_segment_tokens = max_segment_tokens
        self.overlap_ratio = overlap_ratio
        self.encoding = tiktoken.get_encoding(encoding)
        self.tfidf_vectorizer = TfidfVectorizer(max_features=50)
    
    def count_tokens(self, text: str) -> int:
        """计算Token数量"""
        return len(self.encoding.encode(text))
    
    def split_text(
        self,
        text: str,
        preserve_sentences: bool = True
    ) -> List[TextSegment]:
        """
        分割文本
        
        Args:
            text: 输入文本
            preserve_sentences: 是否保持句子完整性
            
        Returns:
            文本分段列表
        """
        segments = []
        position = 0
        text_length = len(text)
        
        while position < text_length:
            # 计算目标结束位置
            target_end = position + self.max_segment_tokens * 4  # 粗略估计
            
            if target_end >= text_length:
                # 到达文本末尾
                segment = self._create_segment(text, position, text_length)
                if segment.tokens > 0:
                    segments.append(segment)
                break
            
            # 寻找最佳分割点
            if preserve_sentences:
                split_pos = self._find_sentence_boundary(
                    text, position, target_end, text_length
                )
            else:
                split_pos = min(target_end, text_length)
            
            # 创建分段
            segment = self._create_segment(text, position, split_pos)
            if segment.tokens > 0:
                segments.append(segment)
            
            # 计算下一个分段的开始位置(考虑重叠)
            overlap_size = int(segment.tokens * self.overlap_ratio)
            position = split_pos - overlap_size * 4  # 粗略转换
            
            # 确保向前推进
            if position <= segment.start_pos:
                position = split_pos
        
        return segments
    
    def _find_sentence_boundary(
        self,
        text: str,
        start: int,
        target: int,
        max_length: int
    ) -> int:
        """
        寻找句子边界
        
        Args:
            text: 文本
            start: 开始位置
            target: 目标位置
            max_length: 最大长度
            
        Returns:
            最佳分割位置
        """
        # 在目标位置附近寻找句子结束标记
        search_range = 200  # 搜索范围
        
        search_start = max(start, target - search_range)
        search_end = min(max_length, target + search_range)
        
        # 优先寻找最近的句子结束标记
        sentence_ends = []
        for i in range(search_start, search_end):
            char = text[i]
            if char in ['。', '!', '?', '.', '!', '?']:
                sentence_ends.append(i)
        
        if sentence_ends:
            # 返回最接近目标位置的句子结束标记
            best_pos = min(sentence_ends, key=lambda x: abs(x - target))
            return best_pos + 1  # 包括结束标记
        
        # 如果没有找到句子结束标记,寻找段落边界
        paragraph_ends = []
        for i in range(search_start, search_end):
            if text[i] == '\n' and i + 1 < max_length and text[i + 1] == '\n':
                paragraph_ends.append(i)
        
        if paragraph_ends:
            best_pos = min(paragraph_ends, key=lambda x: abs(x - target))
            return best_pos + 2  # 跳过换行符
        
        # 最后的选择:在目标位置分割
        return min(target, max_length)
    
    def _create_segment(
        self,
        text: str,
        start: int,
        end: int
    ) -> TextSegment:
        """
        创建文本分段
        
        Args:
            text: 原始文本
            start: 开始位置
            end: 结束位置
            
        Returns:
            文本分段
        """
        content = text[start:end].strip()
        tokens = self.count_tokens(content)
        
        return TextSegment(
            content=content,
            start_pos=start,
            end_pos=end,
            tokens=tokens
        )
    
    def analyze_importance(self, segments: List[TextSegment]) -> List[TextSegment]:
        """
        分析分段重要性
        
        Args:
            segments: 文本分段列表
            
        Returns:
            更新后的分段列表
        """
        if not segments:
            return segments
        
        # 计算TF-IDF特征
        texts = [seg.content for seg in segments]
        tfidf_matrix = self.tfidf_vectorizer.fit_transform(texts)
        
        # 计算每个分段的重要性分数
        for i, segment in enumerate(segments):
            # TF-IDF分数
            tfidf_score = np.mean(tfidf_matrix[i].toarray())
            
            # 长度因子(适中的长度更重要)
            length_factor = 1.0
            if segment.tokens < 100:  # 太短
                length_factor = 0.5
            elif segment.tokens > self.max_segment_tokens * 1.5:  # 太长
                length_factor = 0.7
            
            # 特殊内容加分
            special_content_bonus = 0
            content_lower = segment.content.lower()
            if any(keyword in content_lower for keyword in ['重要', '关键', '必须', 'critical', 'important']):
                special_content_bonus += 0.2
            if any(keyword in content_lower for keyword in ['结论', '总结', 'conclusion', 'summary']):
                special_content_bonus += 0.1
            
            # 综合重要性分数
            segment.importance = (tfidf_score * 0.5 + length_factor * 0.3 + special_content_bonus * 0.2)
        
        # 归一化重要性分数
        if segments:
            max_importance = max(seg.importance for seg in segments)
            if max_importance > 0:
                for segment in segments:
                    segment.importance = segment.importance / max_importance
        
        return segments
    
    def cluster_segments(
        self,
        segments: List[TextSegment],
        n_clusters: int = 3
    ) -> List[TextSegment]:
        """
        对分段进行聚类
        
        Args:
            segments: 文本分段列表
            n_clusters: 聚类数量
            
        Returns:
            更新后的分段列表
        """
        if len(segments) < n_clusters:
            # 分段数量不足,每个分段一个簇
            for i, segment in enumerate(segments):
                segment.cluster_id = i
            return segments
        
        # 计算TF-IDF特征
        texts = [seg.content for seg in segments]
        tfidf_matrix = self.tfidf_vectorizer.fit_transform(texts)
        
        # K-means聚类
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_ids = kmeans.fit_predict(tfidf_matrix.toarray())
        
        # 分配簇ID
        for segment, cluster_id in zip(segments, cluster_ids):
            segment.cluster_id = int(cluster_id)
        
        return segments
    
    def select_important_segments(
        self,
        segments: List[TextSegment],
        max_total_tokens: int,
        diversity: bool = True
    ) -> List[TextSegment]:
        """
        选择重要的分段
        
        Args:
            segments: 文本分段列表
            max_total_tokens: 最大总Token数
            diversity: 是否考虑多样性
            
        Returns:
            选择的分段列表
        """
        if not segments:
            return []
        
        # 按重要性排序
        sorted_segments = sorted(segments, key=lambda x: x.importance, reverse=True)
        
        selected = []
        current_tokens = 0
        
        if diversity:
            # 多样性选择:从每个簇中选择最重要的分段
            clusters = {}
            for segment in sorted_segments:
                if segment.cluster_id not in clusters:
                    clusters[segment.cluster_id] = []
                clusters[segment.cluster_id].append(segment)
            
            # 轮流选择每个簇中的分段
            cluster_ids = sorted(clusters.keys())
            cluster_indices = {cid: 0 for cid in cluster_ids}
            
            while True:
                added_any = False
                for cluster_id in cluster_ids:
                    cluster = clusters[cluster_id]
                    idx = cluster_indices[cluster_id]
                    
                    if idx < len(cluster):
                        segment = cluster[idx]
                        if current_tokens + segment.tokens <= max_total_tokens:
                            selected.append(segment)
                            current_tokens += segment.tokens
                            cluster_indices[cluster_id] += 1
                            added_any = True
                        else:
                            cluster_indices[cluster_id] = len(cluster)  # 标记为已处理
                
                if not added_any:
                    break
        else:
            # 简单选择重要性最高的分段
            for segment in sorted_segments:
                if current_tokens + segment.tokens <= max_total_tokens:
                    selected.append(segment)
                    current_tokens += segment.tokens
                else:
                    break
        
        # 按原始位置排序
        selected.sort(key=lambda x: x.start_pos)
        
        return selected
    
    def create_summary(
        self,
        segments: List[TextSegment],
        max_summary_tokens: int = 300
    ) -> str:
        """
        创建摘要
        
        Args:
            segments: 文本分段列表
            max_summary_tokens: 最大摘要Token数
            
        Returns:
            摘要文本
        """
        if not segments:
            return ""
        
        # 选择重要分段
        important_segments = self.select_important_segments(
            segments,
            max_total_tokens=max_summary_tokens,
            diversity=True
        )
        
        # 提取关键句子
        summary_parts = []
        for segment in important_segments:
            # 提取前3句话
            sentences = re.split(r'[。!?.!?]', segment.content)
            sentences = [s.strip() for s in sentences if s.strip()]
            
            for sentence in sentences[:3]:
                if self.count_tokens(' '.join(summary_parts) + sentence) < max_summary_tokens:
                    summary_parts.append(sentence)
                else:
                    break
        
        summary = '。'.join(summary_parts)
        if summary and not summary.endswith('。'):
            summary += '。'
        
        return summary
    
    def process_long_text(
        self,
        text: str,
        max_context_tokens: int = 2000
    ) -> Dict[str, Any]:
        """
        处理长文本
        
        Args:
            text: 输入文本
            max_context_tokens: 最大上下文Token数
            
        Returns:
            处理结果
        """
        # 分割文本
        segments = self.split_text(text, preserve_sentences=True)
        
        # 分析重要性
        segments = self.analyze_importance(segments)
        
        # 聚类
        segments = self.cluster_segments(segments, n_clusters=3)
        
        # 选择重要分段
        selected_segments = self.select_important_segments(
            segments,
            max_total_tokens=max_context_tokens,
            diversity=True
        )
        
        # 创建摘要
        summary = self.create_summary(segments, max_summary_tokens=200)
        
        return {
            'original_text': text,
            'original_tokens': self.count_tokens(text),
            'total_segments': len(segments),
            'selected_segments': selected_segments,
            'selected_tokens': sum(seg.tokens for seg in selected_segments),
            'compression_ratio': 1 - (sum(seg.tokens for seg in selected_segments) / max(self.count_tokens(text), 1)),
            'summary': summary,
            'summary_tokens': self.count_tokens(summary),
            'segments_analysis': [
                {
                    'content_preview': seg.content[:100] + '...',
                    'tokens': seg.tokens,
                    'importance': seg.importance,
                    'cluster_id': seg.cluster_id
                }
                for seg in segments
            ]
        }

# 使用示例
def demonstrate_long_text_processing():
    """演示长文本处理器的使用"""
    print("=== 长文本处理器演示 ===\n")
    
    processor = LongTextProcessor(
        max_segment_tokens=500,
        overlap_ratio=0.15
    )
    
    # 创建测试长文本
    long_text = """
    人工智能技术正在快速发展,对各个行业都产生了深远的影响。在医疗领域,AI技术被用于疾病诊断、药物研发和个性化治疗。通过分析大量的医疗数据,AI系统能够发现人类难以察觉的模式,从而提供更准确的诊断结果。
    
    在教育领域,AI正在改变传统的教学模式。智能辅导系统能够根据学生的学习进度和理解能力,提供个性化的学习建议。虚拟现实技术结合AI,创造出沉浸式的学习环境,大大提高了学习效果。
    
    金融行业是AI应用的重要领域。风险评估、欺诈检测、算法交易等都可以通过AI技术得到显著提升。AI系统能够实时分析大量交易数据,及时发现异常行为,保护金融安全。
    
    制造业也在积极采用AI技术。智能制造系统能够优化生产流程,提高生产效率,降低成本。预测性维护可以提前发现设备故障,减少停机时间。质量控制系统通过计算机视觉技术,自动检测产品缺陷。
    
    交通运输行业因为AI技术而变得更加智能。自动驾驶汽车正在逐步成为现实,这将彻底改变我们的出行方式。智能交通管理系统能够优化交通流量,减少拥堵,提高道路使用效率。
    
    然而,AI技术的发展也带来了一些挑战。隐私保护、算法公平性、就业影响等问题需要我们认真对待。我们需要在推动技术发展的同时,确保技术造福于全人类。
    
    总之,人工智能技术正在深刻改变我们的世界。我们需要积极拥抱这些变化,同时保持理性和审慎的态度。通过合理规划和政策引导,我们可以确保AI技术的健康发展,创造更美好的未来。
    """
    
    print(f"原始文本: {processor.count_tokens(long_text)} tokens")
    print(f"文本长度: {len(long_text)} 字符\n")
    
    # 处理长文本
    result = processor.process_long_text(long_text, max_context_tokens=800)
    
    print("处理结果:")
    print(f"总分段数: {result['total_segments']}")
    print(f"选择分段数: {len(result['selected_segments'])}")
    print(f"选择Token数: {result['selected_tokens']}")
    print(f"压缩率: {result['compression_ratio']:.2%}")
    print(f"摘要Token数: {result['summary_tokens']}\n")
    
    print("摘要:")
    print("=" * 80)
    print(result['summary'])
    print("=" * 80)
    
    print(f"\n选择的分段:")
    for i, segment in enumerate(result['selected_segments'], 1):
        print(f"{i}. [{segment.cluster_id}] 重要性={segment.importance:.2f} ({segment.tokens} tokens)")
        print(f"   {segment.content[:150]}...")
    
    print(f"\n所有分段分析:")
    for analysis in result['segments_analysis']:
        print(f"簇{analysis['cluster_id']}: {analysis['tokens']} tokens, "
              f"重要性={analysis['importance']:.2f}")
        print(f"  {analysis['content_preview']}")

if __name__ == "__main__":
    demonstrate_long_text_processing()

最佳实践与常见陷阱

上下文管理最佳实践

分层上下文策略:将上下文分为不同层次,最近的信息保持完整,较早的信息进行不同程度的压缩。系统指令和关键配置信息优先级最高,应该始终保持完整。用户输入和对话历史可以采用滑动窗口策略。

智能优先级管理:基于信息的重要性动态管理上下文。重要信息包括:系统指令、用户明确标记的重要信息、包含关键决策的信息、高频访问的信息等。重要性评估可以结合语义分析、用户反馈、时间衰减等多个维度。

上下文质量监控:定期监控上下文管理的质量,包括压缩后的信息完整性、用户满意度、任务完成度等指标。建立反馈机制,根据实际效果调整压缩策略和参数。

渐进式压缩:采用渐进式压缩策略,随着上下文的增长逐步增加压缩力度。早期可以保留更多细节,后期可以更大力度地压缩。渐进式压缩可以平衡信息完整性和性能要求。

成本优化最佳实践

成本预算分层:在不同层次设置成本预算,包括请求级别、会话级别、用户级别和系统级别。分层预算可以更精确地控制成本,及时发现和阻止异常消耗。

智能模型选择:根据任务复杂度和性能要求选择合适的模型。简单任务使用成本较低的小型模型,复杂任务使用功能强大的大型模型。模型选择可以基于任务类型、输入复杂度、输出要求等多个因素。

成本效益分析:定期进行成本效益分析,评估不同功能、不同用户、不同时间段的成本和收益。基于分析结果优化资源分配和成本控制策略。

异常检测和预防:建立成本异常检测机制,及时发现异常的Token使用模式。对于检测到的异常,可以采取限流、告警、拒绝服务等措施,防止成本失控。

长文本处理最佳实践

智能分段策略:在分段时保持语义完整性,尽量在句子、段落等自然边界处分割。避免在关键信息中间分割,确保每个分段都是语义完整的单元。

多样性保证:在选择重要分段时考虑多样性,避免选择内容相似的分段。多样性选择可以提供更全面的信息覆盖,提高理解质量。

增量处理:对于动态变化的文本,采用增量处理策略,只处理新增或变化的部分。增量处理可以减少重复计算,提高处理效率。

质量验证:建立长文本处理的质量验证机制,定期检查摘要和压缩结果的准确性。可以采用人工审核、自动评分、用户反馈等方式验证质量。

常见陷阱及解决方案

过度压缩导致信息丢失:过度压缩可能丢失关键信息,影响任务完成质量。解决方案包括:设置合理的压缩阈值、保留关键信息标识符、建立压缩质量监控、提供人工干预机制。

上下文不一致:压缩过程中可能导致上下文信息不一致,影响模型理解。解决方案包括:保持压缩前后的逻辑一致性、验证压缩结果的连贯性、使用增量压缩保持上下文连续性。

成本监控延迟:成本监控的延迟可能导致成本超支。解决方案包括:实时成本跟踪、预扣费机制、及时告警和限流、预算主动控制。

长文本处理性能问题:长文本处理可能消耗大量计算资源,影响系统性能。解决方案包括:并行处理、缓存中间结果、优化算法复杂度、合理设置处理优先级。

性能优化考虑

性能基准测试

建立完善的性能基准测试体系,评估上下文管理和成本优化的效果:

import time
from typing import Dict, Any
import statistics

class ContextPerformanceBenchmark:
    """上下文性能基准测试"""
    
    def __init__(self):
        self.results = []
    
    def benchmark_compression(
        self,
        processor,
        test_texts: list,
        iterations: int = 5
    ) -> Dict[str, Any]:
        """
        测试压缩性能
        
        Args:
            processor: 处理器实例
            test_texts: 测试文本列表
            iterations: 迭代次数
            
        Returns:
            测试结果
        """
        compression_times = []
        compression_ratios = []
        
        for text in test_texts:
            for _ in range(iterations):
                start_time = time.time()
                
                # 执行压缩
                original_tokens = processor.count_tokens(text)
                result = processor.process_long_text(text)
                compressed_tokens = result['selected_tokens']
                
                end_time = time.time()
                
                compression_times.append((end_time - start_time) * 1000)
                compression_ratio = 1 - (compressed_tokens / max(original_tokens, 1))
                compression_ratios.append(compression_ratio)
        
        return {
            'avg_compression_time_ms': statistics.mean(compression_times),
            'p50_compression_time_ms': statistics.median(compression_times),
            'p95_compression_time_ms': statistics.quantiles(compression_times, n=20)[18] if len(compression_times) > 20 else 0,
            'avg_compression_ratio': statistics.mean(compression_ratios),
            'min_compression_ratio': min(compression_ratios),
            'max_compression_ratio': max(compression_ratios),
            'samples_tested': len(compression_times)
        }
    
    def benchmark_cost_tracking(
        self,
        cost_monitor,
        test_records: list,
        iterations: int = 10
    ) -> Dict[str, Any]:
        """
        测试成本跟踪性能
        
        Args:
            cost_monitor: 成本监控器实例
            test_records: 测试记录列表
            iterations: 迭代次数
            
        Returns:
            测试结果
        """
        tracking_times = []
        
        for record in test_records:
            for _ in range(iterations):
                start_time = time.time()
                
                # 记录成本
                cost_monitor.record_cost(**record)
                
                end_time = time.time()
                
                tracking_times.append((end_time - start_time) * 1000)
        
        return {
            'avg_tracking_time_ms': statistics.mean(tracking_times),
            'p50_tracking_time_ms': statistics.median(tracking_times),
            'p95_tracking_time_ms': statistics.quantiles(tracking_times, n=20)[18] if len(tracking_times) > 20 else 0,
            'throughput_records_per_sec': 1000 / statistics.mean(tracking_times),
            'samples_tested': len(tracking_times)
        }
    
    def generate_report(self, benchmarks: Dict[str, Any]) -> str:
        """
        生成性能报告
        
        Args:
            benchmarks: 基准测试结果
            
        Returns:
            报告文本
        """
        report = []
        report.append("=" * 80)
        report.append("上下文管理与成本优化性能报告")
        report.append("=" * 80)
        
        if 'compression' in benchmarks:
            comp = benchmarks['compression']
            report.append("\n压缩性能:")
            report.append(f"  平均压缩时间: {comp['avg_compression_time_ms']:.2f}ms")
            report.append(f"  P95压缩时间: {comp['p95_compression_time_ms']:.2f}ms")
            report.append(f"  平均压缩率: {comp['avg_compression_ratio']:.2%}")
            report.append(f"  压缩率范围: {comp['min_compression_ratio']:.2%} - {comp['max_compression_ratio']:.2%}")
        
        if 'cost_tracking' in benchmarks:
            cost = benchmarks['cost_tracking']
            report.append("\n成本跟踪性能:")
            report.append(f"  平均跟踪时间: {cost['avg_tracking_time_ms']:.2f}ms")
            report.append(f"  吞吐量: {cost['throughput_records_per_sec']:.0f} 记录/秒")
        
        report.append("\n" + "=" * 80)
        
        return '\n'.join(report)

# 使用示例
def run_performance_benchmarks():
    """运行性能基准测试"""
    print("=== 性能基准测试 ===\n")
    
    # 初始化组件
    processor = LongTextProcessor(max_segment_tokens=500)
    cost_monitor = CostMonitoringSystem(db_path=":memory:")
    
    # 准备测试数据
    test_texts = [
        "这是一个测试文本。" * 100,
        "人工智能正在改变世界。" * 200,
        "Python是一种流行的编程语言。" * 150
    ]
    
    test_records = [
        {"user_id": "user1", "task_id": "task1", "model_name": "gpt-3.5-turbo", "input_tokens": 500, "output_tokens": 300},
        {"user_id": "user1", "task_id": "task2", "model_name": "gpt-4", "input_tokens": 800, "output_tokens": 500},
        {"user_id": "user2", "task_id": "task3", "model_name": "gpt-3.5-turbo", "input_tokens": 300, "output_tokens": 200},
    ]
    
    # 运行基准测试
    benchmark = ContextPerformanceBenchmark()
    
    compression_results = benchmark.benchmark_compression(processor, test_texts)
    cost_results = benchmark.benchmark_cost_tracking(cost_monitor, test_records)
    
    benchmarks = {
        'compression': compression_results,
        'cost_tracking': cost_results
    }
    
    # 生成报告
    report = benchmark.generate_report(benchmarks)
    print(report)

if __name__ == "__main__":
    run_performance_benchmarks()

监控和调优策略

实时性能监控:建立实时监控系统,跟踪上下文管理和成本优化的关键指标。使用仪表盘可视化显示压缩率、处理时间、成本趋势等数据。

自适应参数调优:基于实际性能数据自动调整优化参数。例如,根据压缩质量反馈动态调整压缩阈值,根据成本趋势调整预算限制。

A/B测试框架:对不同的优化策略进行A/B测试,量化比较效果。测试指标包括压缩率、处理时间、成本节省、用户满意度等。

容量规划:基于历史数据和预测模型进行容量规划,确保系统资源能够满足未来的增长需求。容量规划要考虑峰值负载、增长趋势、成本约束等因素。

成本控制策略

多维度成本控制:从多个维度控制成本,包括时间维度(小时、日、月)、用户维度(个人用户、企业用户)、功能维度(不同功能模块)、模型维度(不同模型类型)等。

预测性成本控制:使用机器学习模型预测未来的成本趋势,提前采取控制措施。预测模型可以考虑历史数据、季节因素、业务增长等多个因素。

分级服务质量:根据成本预算提供分级的服务质量。对于预算有限的场景,可以降低服务级别;对于预算充足的场景,可以提供更高质量的服务。

成本优化激励:建立成本优化激励机制,鼓励用户和开发者采用更高效的使用方式。激励机制可以包括费用折扣、服务升级、功能解锁等。

参考资源

官方文档

技术论文和文章

  • "Efficient Transformer Implementation: A Survey": 关于高效Transformer实现的综述论文,涵盖了上下文优化的多种技术
  • "Context Compression for Large Language Models": 大语言模型上下文压缩的研究论文,讨论了各种压缩算法和评估方法
  • "Token Cost Optimization Strategies": Token成本优化策略的技术文章,提供了实用的优化建议和案例研究

开源工具和库

实战案例

  • Enterprise Context Management: 企业级上下文管理的白皮书,分享了大规模应用的实践经验
  • Cost Optimization Case Studies: 成本优化的案例研究,展示了不同行业的优化策略和效果
  • Long Document Processing: 长文档处理的实践指南,包含了各种场景下的处理方案

通过本文的学习,读者应该能够掌握Agent系统的上下文管理和Token成本优化的核心技术,能够在实际项目中构建高效、经济的上下文管理系统。下一篇文章将深入探讨Agent系统监控与调试的技术细节。