1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
from __future__ import annotations

import json
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
@dataclass
class Tool:
    """A named capability the orchestrator can invoke.

    The handler is called with a single dict of arguments.
    """

    name: str  # key used when registering and looking the tool up
    description: str  # short human-readable summary of the tool
    schema: Dict[str, Any]  # JSON-Schema-style description of the arguments
    handler: Any  # callable taking one dict of arguments
class ToolRegistry:
    """In-memory mapping from tool name to ``Tool`` that can export schemas."""

    def __init__(self):
        self._tools: Dict[str, Tool] = {}

    def register(self, tool: Tool):
        """Add *tool* under ``tool.name``, replacing any tool already using that name."""
        self._tools.update({tool.name: tool})

    def get(self, name: str) -> Tool:
        """Return the tool registered as *name*; raises KeyError if there is none."""
        return self._tools[name]

    def list_schemas(self) -> List[Dict[str, Any]]:
        """Describe every registered tool, in the insertion order of their names."""
        schemas: List[Dict[str, Any]] = []
        for registered in self._tools.values():
            schemas.append(
                {
                    "name": registered.name,
                    "description": registered.description,
                    "parameters": registered.schema,
                }
            )
        return schemas
def get_pr_info(pr_id: str) -> Dict[str, Any]:
    """Return canned metadata for pull request *pr_id*.

    Stub: only ``id`` reflects the input; title, author and risk are fixed.
    """
    info: Dict[str, Any] = dict(
        id=pr_id,
        title="Refactor payment module",
        author="zhangsan",
        risk="medium",
    )
    return info
def create_calendar_event(title: str, when: str, attendees: List[str]) -> Dict[str, Any]:
    """Pretend to create a calendar event and echo its details back.

    Stub: the event id is always "evt_1001" and nothing is persisted. The
    returned "attendees" value is the very list object passed in.
    """
    event: Dict[str, Any] = {"event_id": "evt_1001"}
    event.update(title=title, time=when, attendees=attendees)
    return event
def send_group_message(channel: str, content: str) -> Dict[str, Any]:
    """Pretend to post *content* to *channel*.

    Stub: nothing is sent; the result only echoes the channel and reports
    the message length in characters.
    """
    message_length = len(content)
    return dict(ok=True, channel=channel, length=message_length)
# Module-level registry holding the three demo tools. Each handler unpacks
# the argument dict into keyword arguments for the matching stub function;
# the lambdas resolve the stub by its global name when called.
registry = ToolRegistry()
for _tool in (
    Tool(
        name="get_pr_info",
        description="获取 PR 基本信息",
        schema={
            "type": "object",
            "properties": {"pr_id": {"type": "string"}},
            "required": ["pr_id"],
        },
        handler=lambda args: get_pr_info(**args),
    ),
    Tool(
        name="create_calendar_event",
        description="在日历创建会议",
        schema={
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "when": {"type": "string"},
                "attendees": {"type": "array", "items": {"type": "string"}},
            },
            "required": ["title", "when", "attendees"],
        },
        handler=lambda args: create_calendar_event(**args),
    ),
    Tool(
        name="send_group_message",
        description="在指定群聊发送文本",
        schema={
            "type": "object",
            "properties": {
                "channel": {"type": "string"},
                "content": {"type": "string"},
            },
            "required": ["channel", "content"],
        },
        handler=lambda args: send_group_message(**args),
    ),
):
    registry.register(_tool)
del _tool  # don't leave the loop variable behind as a module global
@dataclass
class Memory:
    """Agent memory: a short-term key/value scratchpad plus a step log."""

    short_term: Dict[str, Any] = field(default_factory=dict)  # written by remember(), read by recall()
    long_term: Dict[str, Any] = field(default_factory=dict)  # not touched by any method of this class
    work_log: List[Dict[str, Any]] = field(default_factory=list)  # entries appended by append_log(), in order

    def remember(self, key: str, value: Any):
        """Store *value* under *key* in short-term memory, overwriting any old value."""
        self.short_term.update({key: value})

    def recall(self, key: str, default=None):
        """Return the short-term value stored under *key*, or *default* if absent."""
        return self.short_term[key] if key in self.short_term else default

    def append_log(self, step: Dict[str, Any]):
        """Add *step* to the end of the work log."""
        self.work_log.extend([step])
class Orchestrator:
    """Turn a user intent into a plan of tool calls and execute it in order.

    Args:
        tools: registry whose ``get(name)`` resolves each planned step's tool.
        memory: supplies ``recall(key, default)`` for filling in missing
            arguments, and ``append_log``/``work_log`` for recording steps.
        max_steps: maximum number of planned steps executed per ``run``.
    """

    def __init__(self, tools: ToolRegistry, memory: Memory, max_steps: int = 5):
        self.tools = tools
        self.memory = memory
        self.max_steps = max_steps

    def plan(self, user_intent: str) -> Dict[str, Any]:
        """Return a rule-based plan ``{"steps": [{"tool": ..., "args": {...}}, ...]}``.

        An intent mentioning "PR" or "评审" (review) yields three steps: fetch
        the PR info, create a review meeting, and announce it in a group chat.
        The PR id is the first ``PR-<digits>`` token found in the intent, or
        "PR-1234" when there is none. Meeting time and attendees are fixed
        placeholders. Any other intent yields an empty step list.
        """
        if "PR" not in user_intent and "评审" not in user_intent:
            return {"steps": []}
        found = re.search(r"PR-[0-9]+", user_intent)
        pr_id = found.group(0) if found else "PR-1234"
        return {"steps": [
            {"tool": "get_pr_info", "args": {"pr_id": pr_id}},
            {"tool": "create_calendar_event", "args": {
                "title": f"{pr_id} 评审会",
                "when": "tomorrow 2pm",
                "attendees": ["zhangsan", "lisi"]}},
            {"tool": "send_group_message", "args": {
                "channel": "team-dev",
                "content": f"已创建评审会,议题:{pr_id}"}},
        ]}

    def _missing_args(self, tool: Tool, args: Dict[str, Any]) -> List[str]:
        """Names in ``tool.schema["required"]`` that are absent, None or "" in *args*."""
        return [name for name in tool.schema.get("required", [])
                if args.get(name) in (None, "")]

    def validate_args(self, tool: Tool, args: Dict[str, Any]) -> Tuple[bool, str]:
        """Check that *args* provides every required parameter of *tool*.

        Returns:
            ``(True, "ok")`` when nothing is missing; otherwise ``False`` plus a
            message naming the first missing parameter in schema order.
        """
        missing = self._missing_args(tool, args)
        if missing:
            return False, f"缺少必要参数: {missing[0]}"
        return True, "ok"

    def run(self, user_intent: str) -> Dict[str, Any]:
        """Plan *user_intent* and execute at most ``max_steps`` steps in order.

        A required argument that is missing from a step is looked up in memory
        under its parameter name. If memory has nothing for it, the placeholder
        "todo" is used. Exceptions from the registry lookup or from a tool
        handler are not caught and propagate to the caller.

        Returns:
            ``{"summary": "done", "results": [...], "trace": work_log}``. Here
            ``results`` holds each handler's output for this run, and ``trace``
            is the memory's whole work log, not only this run's entries.
        """
        plan = self.plan(user_intent)
        results: List[Any] = []
        for i, step in enumerate(plan.get("steps", [])):
            if i >= self.max_steps:
                break
            tool = self.tools.get(step["tool"])
            # Work on a copy so that filling in arguments never mutates the plan.
            args = dict(step["args"])
            # Fill every missing required argument, not just the first one,
            # recalling it from memory by its parameter name.
            for name in self._missing_args(tool, args):
                args[name] = self.memory.recall(name, "todo")
            out = tool.handler(args)
            self.memory.append_log({"step": i + 1, "tool": tool.name, "args": args, "out": out})
            results.append(out)
        return {"summary": "done", "results": results, "trace": self.memory.work_log}
if __name__ == "__main__":
    # Demo: plan and run a sample review request with a fresh memory, then
    # pretty-print the result without escaping non-ASCII characters.
    agent = Orchestrator(registry, Memory(), max_steps=5)
    out = agent.run("明天下午和张三评审PR-1234,顺便把会议纪要发群里")
    print(json.dumps(out, ensure_ascii=False, indent=2))
|