1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
import hashlib
import time
from collections import OrderedDict
from typing import Dict, List, Optional

import psutil
import tiktoken
class OptimizedContextManager:
    """Token-bounded, LRU-evicting conversation context manager.

    Responsibilities visible in this class:
      * keeps per-session message histories and their running token totals;
      * compresses old messages into a single summary entry once a session
        exceeds ``max_context_tokens``;
      * evicts least-recently-used sessions beyond ``max_sessions`` and under
        system memory pressure (via ``psutil``);
      * memoizes token counts (tiktoken ``cl100k_base``) keyed by MD5 digest.
    """

    def __init__(self, max_context_tokens: int = 8192, max_sessions: int = 1000):
        self.max_context_tokens = max_context_tokens
        self.max_sessions = max_sessions
        # session_id -> {"messages": [...], "total_tokens": int, "last_access": float}
        # OrderedDict insertion order doubles as LRU order (move_to_end on access).
        self.conversations = OrderedDict()
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.memory_threshold = 0.8       # fraction of system RAM that triggers cleanup
        self.cleanup_batch_size = 50      # max sessions dropped per emergency cleanup
        self.token_cache = {}             # md5(text) -> token count (bounded, no eviction)
        self.max_cache_size = 10000
        self.stats = {
            "total_messages": 0,
            "context_compressions": 0,
            "memory_cleanups": 0,
        }

    def add_message(self, session_id: str, role: str, content: str) -> bool:
        """Append a message to *session_id*'s history.

        Creates the session on first use, refreshes its LRU position,
        compresses the context when the session's token total exceeds the
        budget, and evicts LRU sessions past ``max_sessions``.
        Always returns ``True``.
        """
        if self._is_memory_pressure():
            self._emergency_cleanup()

        content_tokens = self._count_tokens_cached(content)

        if session_id not in self.conversations:
            self.conversations[session_id] = {
                "messages": [],
                "total_tokens": 0,
                "last_access": time.time(),
            }

        conversation = self.conversations[session_id]
        conversation["last_access"] = time.time()
        self.conversations.move_to_end(session_id)  # mark most-recently-used

        message = {
            "role": role,
            "content": content,
            "timestamp": time.time(),
            "tokens": content_tokens,
        }
        conversation["messages"].append(message)
        conversation["total_tokens"] += content_tokens
        self.stats["total_messages"] += 1

        if conversation["total_tokens"] > self.max_context_tokens:
            self._compress_context(session_id)
        if len(self.conversations) > self.max_sessions:
            self._cleanup_lru_sessions()
        return True

    def get_context(self, session_id: str, max_tokens: Optional[int] = None) -> List[Dict]:
        """Return the session's messages, optionally truncated to *max_tokens*.

        Unknown sessions yield ``[]``. The untruncated path returns a shallow
        copy so callers cannot mutate the stored list.
        """
        if session_id not in self.conversations:
            return []

        conversation = self.conversations[session_id]
        conversation["last_access"] = time.time()
        self.conversations.move_to_end(session_id)

        messages = conversation["messages"]
        if max_tokens and conversation["total_tokens"] > max_tokens:
            return self._truncate_to_token_limit(messages, max_tokens)
        return messages.copy()

    def _compress_context(self, session_id: str):
        """Collapse all but the newest messages into one summary message."""
        conversation = self.conversations[session_id]
        messages = conversation["messages"]
        if len(messages) <= 2:
            return

        # Keep at least 6 recent messages, or the newest quarter if larger.
        recent_count = max(6, len(messages) // 4)
        recent_messages = messages[-recent_count:]
        older_messages = messages[:-recent_count]

        if older_messages:
            compressed_summary = self._create_context_summary(older_messages)
            new_messages = [compressed_summary] + recent_messages
            new_total_tokens = sum(msg["tokens"] for msg in new_messages)
            conversation["messages"] = new_messages
            conversation["total_tokens"] = new_total_tokens
            self.stats["context_compressions"] += 1
            print(f"🗜️ 上下文压缩: 会话 {session_id}, "
                  f"{len(older_messages)}条消息 -> 1条摘要")

    def _create_context_summary(self, messages: List[Dict]) -> Dict:
        """Build a single ``system`` message summarizing *messages*."""
        summary_content = f"[上下文摘要: {len(messages)}条消息]"
        user_messages = [msg for msg in messages if msg["role"] == "user"]
        if user_messages:
            # Only the latest user message's first 100 chars make the summary.
            recent_content = user_messages[-1]["content"][:100] + "..."
            summary_content += f" 最近话题: {recent_content}"
        return {
            "role": "system",
            "content": summary_content,
            "timestamp": time.time(),
            "tokens": self._count_tokens_cached(summary_content),
            "is_summary": True,
        }

    def _truncate_to_token_limit(self, messages: List[Dict], max_tokens: int) -> List[Dict]:
        """Select the newest messages whose token sum fits within *max_tokens*.

        Walks the history newest-first and stops at the first message that
        would overflow the budget. Never returns an empty list for non-empty
        input: when even the newest message is too large, it is returned alone.
        """
        selected_messages = []
        current_tokens = 0
        for message in reversed(messages):
            if current_tokens + message["tokens"] <= max_tokens:
                selected_messages.insert(0, message)
                current_tokens += message["tokens"]
            else:
                break
        # Parentheses make the original `or`/ternary precedence explicit:
        # fall back to the single newest message when nothing fits.
        return selected_messages or (messages[-1:] if messages else [])

    def _count_tokens_cached(self, text: str) -> int:
        """Token count for *text*, memoized by MD5 digest.

        Once the cache reaches ``max_cache_size`` new entries are simply not
        stored (no eviction) — the cache stays bounded.
        """
        text_hash = hashlib.md5(text.encode()).hexdigest()
        if text_hash in self.token_cache:
            return self.token_cache[text_hash]
        tokens = len(self.tokenizer.encode(text))
        if len(self.token_cache) < self.max_cache_size:
            self.token_cache[text_hash] = tokens
        return tokens

    def _is_memory_pressure(self) -> bool:
        """True when system memory usage exceeds the configured threshold."""
        memory = psutil.virtual_memory()
        # memory.percent is 0-100; memory_threshold is a fraction.
        return memory.percent > self.memory_threshold * 100

    def _emergency_cleanup(self):
        """Drop the oldest quarter of sessions (capped per batch) and flush an oversized token cache."""
        print("🚨 内存压力过高,执行紧急清理")
        cleanup_count = min(self.cleanup_batch_size,
                            len(self.conversations) // 4)
        for _ in range(cleanup_count):
            if self.conversations:
                removed_session_id = next(iter(self.conversations))
                self.conversations.pop(removed_session_id)
        if len(self.token_cache) > self.max_cache_size // 2:
            self.token_cache.clear()
        self.stats["memory_cleanups"] += 1

    def _cleanup_lru_sessions(self):
        """Evict least-recently-used sessions until within ``max_sessions``."""
        while len(self.conversations) > self.max_sessions:
            oldest_session_id = next(iter(self.conversations))
            self.conversations.pop(oldest_session_id)

    def get_global_stats(self) -> Dict:
        """Return aggregate statistics across all live sessions."""
        total_tokens = sum(conv["total_tokens"]
                           for conv in self.conversations.values())
        return {
            **self.stats,
            "active_sessions": len(self.conversations),
            "total_context_tokens": total_tokens,
            "avg_tokens_per_session": total_tokens / max(len(self.conversations), 1),
            "cache_size": len(self.token_cache),
            # Rough heuristic: ~4 bytes per token.
            "estimated_memory_mb": total_tokens * 4 / 1024 / 1024,
        }
|