1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
| from dataclasses import dataclass, field from typing import Any, Dict, Callable, Optional, List import time, json, random import asyncio
@dataclass
class ToolSpec:
    """Describe one callable tool: its name, payload schema, runner and retry policy."""

    # Identifier of the tool.
    name: str
    # Payload schema; `validate` reads its "required" and "properties" keys.
    schema: Dict[str, Any]
    # Synchronous callable that receives the validated payload dict.
    runner: Callable[[Dict[str, Any]], Any]
    # Timeout in seconds applied to each individual attempt.
    timeout_s: float = 8.0
    # Number of additional attempts made after the first one fails.
    retry: int = 2
class SchemaError(ValueError):
    """Raised when a payload is missing a field that the schema requires."""
def validate(schema: Dict[str, Any], payload: Dict[str, Any]) -> Dict[str, Any]:
    """Check required fields and fill in schema defaults.

    Args:
        schema: Dict whose optional "required" entry lists mandatory keys and
            whose optional "properties" entry maps keys to per-field specs.
        payload: Dict to validate. It is updated in place with any missing
            defaults and is also the return value.

    Returns:
        The same ``payload`` object, with defaults filled in.

    Raises:
        SchemaError: If a key listed under "required" is absent from ``payload``.
    """
    import copy  # function-scope import: the file's import block does not provide it

    for key in schema.get("required", []):
        if key not in payload:
            raise SchemaError(f"missing field: {key}")

    for key, spec in schema.get("properties", {}).items():
        if key not in payload and "default" in spec:
            # Copy the default so a mutable value (list/dict) is never shared
            # between the schema and a payload, or between two payloads.
            payload[key] = copy.deepcopy(spec["default"])
    return payload
async def run_with_retry(tool: ToolSpec, payload: Dict[str, Any]) -> Any:
    """Validate ``payload`` and run ``tool.runner`` in a thread, retrying on failure.

    Each attempt is bounded by ``tool.timeout_s``. Between attempts the
    coroutine sleeps for a backoff that starts at 0.5 s, doubles each time up
    to 3.0 s, and carries up to 0.2 s of random jitter.

    Args:
        tool: Tool specification providing schema, runner, timeout and retries.
        payload: Input for the runner. A shallow copy is validated, so the
            caller's dict does not receive the filled-in defaults.

    Returns:
        Whatever ``tool.runner`` returns on the first successful attempt.

    Raises:
        SchemaError: If the payload fails validation (this is not retried).
        Exception: The error from the final attempt, including a timeout.
    """
    payload = validate(tool.schema, dict(payload))
    # A negative retry count would skip the loop entirely and fall through to
    # an implicit `return None`; always make at least one attempt.
    retries = max(tool.retry, 0)
    delay = 0.5
    for attempt in range(retries + 1):
        try:
            # NOTE: on timeout the await is cancelled, but the worker thread
            # started by to_thread cannot be interrupted and keeps running.
            return await asyncio.wait_for(
                asyncio.to_thread(tool.runner, payload),
                timeout=tool.timeout_s,
            )
        except Exception:
            if attempt >= retries:
                raise
            await asyncio.sleep(delay + random.random() * 0.2)
            delay = min(delay * 2, 3.0)
@dataclass
class Budget:
    """Token and wall-clock allowance for one run.

    Attributes:
        tokens: Token allowance (not consulted by ``left_ms``).
        ms: Time allowance in milliseconds.
        start: ``time.time()`` value at which the budget began; defaults to
            the moment the instance is created.
    """

    tokens: int
    ms: int
    start: float = field(default_factory=lambda: time.time())

    def left_ms(self) -> int:
        """Return the remaining milliseconds, truncated to int (negative once exceeded)."""
        elapsed_ms = (time.time() - self.start) * 1000
        return int(self.ms - elapsed_ms)
@dataclass
class TraceEvent:
    """One entry recorded by ``Tracer``."""

    # Event name, e.g. "plan" or "tool.call".
    name: str
    # Timestamp of the event; Tracer.log fills it with time.time().
    at: float
    # Keyword metadata passed to Tracer.log for this event.
    meta: Dict[str, Any]
class Tracer:
    """Collect trace events in memory, in the order they are logged."""

    def __init__(self):
        self.events: List[TraceEvent] = []

    def log(self, name: str, **meta):
        """Append an event called ``name``, stamped with ``time.time()``.

        Extra keyword arguments become the event's metadata. A metadata key
        called ``name`` cannot be used: it collides with the first parameter
        and Python raises ``TypeError``.
        """
        event = TraceEvent(name, time.time(), meta)
        self.events.append(event)

    def dump(self) -> List[Dict[str, Any]]:
        """Return every event as a plain dict with "name", "at" and "meta" keys."""
        return [
            {"name": event.name, "at": event.at, "meta": event.meta}
            for event in self.events
        ]
class Planner:
    """Keyword-based planner that decides whether a query needs the search tool."""

    def decide(self, query: str) -> Dict[str, Any]:
        """Return a plan dict with the tools to call and the expected citation count.

        "search" is planned when the query contains any trigger keyword;
        ``k_refs`` is always 2.
        """
        triggers = ("规范", "流程", "价格", "说明")
        for word in triggers:
            if word in query:
                return {"tools": ["search"], "k_refs": 2}
        return {"tools": [], "k_refs": 2}
class DummyLLM:
    """Stand-in language model that ignores its input and returns a canned reply."""

    def generate(self, prompt: str, max_tokens: int = 500) -> str:
        """Return the fixed reply; ``prompt`` and ``max_tokens`` are not used."""
        canned_reply = "答复:请参考[1][2],并已创建日程。"
        return canned_reply
class Agent:
    """Plan tool calls for a query, gather evidence, and ask the LLM for an answer."""

    def __init__(self, tools: Dict[str, ToolSpec], llm: DummyLLM):
        """Store the tool registry and LLM; create a planner and a tracer.

        Args:
            tools: Mapping from tool name (as emitted by the planner) to its spec.
            llm: Object whose ``generate(prompt)`` returns the answer text.
        """
        self.tools = tools
        self.llm = llm
        self.planner = Planner()
        self.tracer = Tracer()

    async def run(self, query: str) -> Dict[str, Any]:
        """Answer ``query`` and return the answer together with the trace.

        Planned tools are called in order until fewer than 400 ms of the
        3000 ms budget remain. Up to six evidence titles are numbered into the
        prompt. If the answer contains fewer ``[n]`` citations than the plan's
        ``k_refs``, the LLM is asked once more with the prompt cut to 600
        characters.

        Returns:
            Dict with "answer" (str) and "trace" (list of event dicts). The
            tracer lives on the instance, so the trace also contains events
            from earlier ``run`` calls on the same agent.

        Raises:
            KeyError: If the plan names a tool missing from ``self.tools``.
            Exception: Whatever ``run_with_retry`` raises after its last attempt.
        """
        import re  # function-scope import; replaces the inline __import__('re')

        plan = self.planner.decide(query)
        budget = Budget(tokens=3000, ms=3000)
        self.tracer.log("plan", plan=plan)

        evidences = []
        for tool_name in plan["tools"]:
            if budget.left_ms() < 400:
                break
            # The metadata key must not be `name`: Tracer.log's first
            # parameter is called `name`, so `log("tool.call", name=...)`
            # raises TypeError (multiple values for argument 'name').
            self.tracer.log("tool.call", tool=tool_name)
            res = await run_with_retry(self.tools[tool_name], {"q": query, "k": 4})
            evidences.extend(res)
            self.tracer.log("tool.ok", tool=tool_name, size=len(res))

        ctx = "\n".join(
            f"[{i}] {e['title']}" for i, e in enumerate(evidences[:6], start=1)
        )
        prompt = f"基于证据回答并在结尾引用:[示例]\n{ctx}\n问题:{query}"
        ans = self.llm.generate(prompt)
        # Only the number of citations matters here, not their values.
        citations = re.findall(r"\[(\d+)\]", ans)
        if len(citations) < plan["k_refs"]:
            ans = self.llm.generate(prompt[:600])
        return {"answer": ans, "trace": self.tracer.dump()}
def search_runner(payload: Dict[str, Any]):
    """Return ``k`` fake search hits for the query in ``payload["q"]``.

    ``k`` is read from ``payload["k"]`` and defaults to 4. Each hit is a dict
    with a "title" and a "url", numbered from 1.
    """
    query = payload["q"]
    count = payload.get("k", 4)
    hits = []
    for idx in range(count):
        rank = idx + 1
        hits.append({"title": f"{query}-证据-{rank}", "url": f"https://kb/{rank}"})
    return hits
# Tool registration for the fake search backend.
search_tool = ToolSpec(
    name="search",
    schema={
        "type": "object",
        "properties": {
            # Query text; listed under "required" below.
            "q": {"type": "string"},
            # Number of hits to return; filled with 4 when omitted.
            "k": {"type": "integer", "default": 4}
        },
        "required": ["q"]
    },
    runner=search_runner,
    # 2.5 s per attempt, with one retry after the first failure.
    timeout_s=2.5,
    retry=1
)
async def demo():
    """Run the agent once on a sample query and print the result as indented JSON."""
    registry = {"search": search_tool}
    agent = Agent(tools=registry, llm=DummyLLM())
    result = await agent.run("发布流程规范与审批")
    rendered = json.dumps(result, ensure_ascii=False, indent=2)
    print(rendered)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(demo())
|