1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
import asyncio
import time
import weakref
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional
@dataclass
class ToolTask:
    """One queued tool invocation.

    Fields are listed in constructor order. ``callback`` defaults to ``None``
    and ``created_time`` defaults to the current ``time.time()``, so callers
    that pass every field (positionally or by keyword) behave as before.
    """

    task_id: str  # identifier returned to the submitter
    tool_name: str  # name of the tool function to run
    agent_id: str  # submitting agent
    parameters: dict  # keyword parameters forwarded to the tool
    dependencies: List[str]  # tool names whose results (same agent) must exist first
    # Optional hook. The scheduler in this file always passes None and never
    # invokes it, hence Optional rather than a bare Callable.
    callback: Optional[Callable] = None
    created_time: float = field(default_factory=time.time)  # seconds since the epoch
class LockFreeToolScheduler:
    """Asyncio tool scheduler with dependency ordering and per-tool limits.

    Submitted calls go onto an ``asyncio.Queue``. A background scheduling loop
    takes them off one at a time. A call whose dependencies have no stored
    result yet is put back on the queue after a short delay. A ready call is
    started as its own asyncio task, so several calls can be in flight at once.
    Tool functions are blocking, so they run in a ``ThreadPoolExecutor``. A
    per-tool ``asyncio.Semaphore`` caps how many calls of each tool run
    concurrently.

    "Lock-free" means no ``threading`` locks are used: the bookkeeping dicts are
    only read and written on the event-loop thread. Worker threads just return
    values through ``run_in_executor``.

    Results are stored under ``"<agent_id>_<tool_name>"``. A later call of the
    same tool by the same agent therefore overwrites the earlier result, and a
    dependency is satisfied by any completed call of that tool for the agent.
    """

    # Maximum concurrent calls per tool; tools not listed use "default".
    _TOOL_LIMITS = {
        "database_tool": 3,
        "api_tool": 5,
        "file_tool": 2,
        "default": 1,
    }
    # Seconds before a task with unmet dependencies is put back on the queue.
    _DEPENDENCY_RETRY_DELAY = 0.1

    def __init__(self, max_workers: int = 10):
        """Create the scheduler.

        Args:
            max_workers: Size of the thread pool that runs tool functions.

        If an event loop is running (the scheduler is created inside a
        coroutine), the scheduling loop starts immediately. Otherwise it is
        started by the first ``submit_tool_call``.
        """
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.task_queue = asyncio.Queue()
        self.running_tasks = {}      # task_id -> asyncio.Task executing it
        self.completed_results = {}  # "<agent_id>_<tool_name>" -> latest result
        self.tool_pools = {}         # not used by this class; kept for compatibility
        self.tool_semaphores = {}    # tool_name -> asyncio.Semaphore
        self._task_seq = 0           # makes task ids unique within this instance
        self._scheduler_task = None
        self._closed = False
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: create_task would raise here, so defer startup.
            return
        self._ensure_scheduler()

    def _ensure_scheduler(self) -> None:
        """Start the scheduling loop once and keep a strong reference to it.

        The event loop keeps only weak references to tasks, so an unreferenced
        background task can be garbage-collected while still pending.
        """
        if self._scheduler_task is None:
            self._scheduler_task = asyncio.create_task(self._scheduling_loop())

    @staticmethod
    def _result_key(agent_id: str, tool_name: str) -> str:
        """Return the ``completed_results`` key for an agent's tool result."""
        return f"{agent_id}_{tool_name}"

    async def submit_tool_call(self, tool_name: str, agent_id: str,
                               dependencies: Optional[List[str]] = None,
                               **kwargs) -> str:
        """Queue a tool call and return its task id.

        Args:
            tool_name: Name of the tool to run (see ``_call_tool_function``).
            agent_id: Submitting agent; dependency results are looked up for
                this agent.
            dependencies: Tool names whose results (for the same agent) must
                exist before this call runs. Each result is passed to the tool
                as an extra parameter named after the dependency tool.
            **kwargs: Parameters forwarded to the tool.

        Returns:
            A task id that is unique within this scheduler instance.

        Raises:
            RuntimeError: If ``shutdown`` has already been called.
        """
        if self._closed:
            raise RuntimeError("scheduler has been shut down")
        self._ensure_scheduler()
        now = time.time()
        # The millisecond timestamp alone collides when the same agent submits
        # the same tool twice within one millisecond; the sequence number
        # keeps ids distinct.
        self._task_seq += 1
        task_id = f"{agent_id}_{tool_name}_{int(now * 1000)}_{self._task_seq}"
        task = ToolTask(
            task_id=task_id,
            tool_name=tool_name,
            agent_id=agent_id,
            parameters=kwargs,
            dependencies=list(dependencies or []),
            callback=None,
            created_time=now,
        )
        await self.task_queue.put(task)
        return task_id

    async def _scheduling_loop(self):
        """Dispatch queued tasks until cancelled.

        A ready task is started as a separate asyncio task, so a slow tool call
        does not hold up the rest of the queue. A task with unmet dependencies
        is re-queued after ``_DEPENDENCY_RETRY_DELAY`` via ``call_later``
        instead of sleeping inside this loop.

        NOTE: a failed call stores no result. A task that depends on it (or on
        a tool that is never submitted) is therefore re-queued indefinitely.
        """
        loop = asyncio.get_running_loop()
        while True:
            try:
                task = await self.task_queue.get()
                if self._dependencies_satisfied(task):
                    self._dispatch(task)
                else:
                    loop.call_later(self._DEPENDENCY_RETRY_DELAY,
                                    self.task_queue.put_nowait, task)
            except asyncio.CancelledError:
                # Let cancellation stop the loop. On Python 3.7 CancelledError
                # subclasses Exception and would otherwise be swallowed below.
                raise
            except Exception as e:
                print(f"调度循环异常: {e}")

    def _dispatch(self, task: "ToolTask") -> None:
        """Start *task* concurrently and track it in ``running_tasks`` until done."""
        runner = asyncio.create_task(self._execute_task(task))
        self.running_tasks[task.task_id] = runner
        runner.add_done_callback(
            lambda _done, tid=task.task_id: self.running_tasks.pop(tid, None))

    def _dependencies_satisfied(self, task: "ToolTask") -> bool:
        """Return True if every dependency of *task* has a stored result."""
        return all(
            self._result_key(task.agent_id, dep) in self.completed_results
            for dep in task.dependencies
        )

    async def _execute_task(self, task: "ToolTask"):
        """Run one task in the thread pool under its tool's semaphore.

        Dependency results are merged into the tool parameters, keyed by
        dependency tool name; they override same-named user parameters. On
        success the result is stored under the task's result key. Any
        exception is reported with ``print`` and not re-raised.
        """
        semaphore = self._get_tool_semaphore(task.tool_name)
        async with semaphore:
            try:
                dep_results = {
                    dep: self.completed_results[self._result_key(task.agent_id, dep)]
                    for dep in task.dependencies
                }
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    self.executor,
                    self._call_tool_function,
                    task.tool_name,
                    {**task.parameters, **dep_results},
                )
                self.completed_results[self._result_key(task.agent_id, task.tool_name)] = result
                print(f"任务完成: {task.task_id}")
            except Exception as e:
                print(f"任务执行失败: {task.task_id}, 错误: {e}")

    async def shutdown(self, wait: bool = True) -> None:
        """Stop the scheduling loop and release the thread pool.

        Tasks still waiting in the queue are dropped. When *wait* is true,
        in-flight tool calls are awaited before the executor is shut down.
        """
        self._closed = True
        if self._scheduler_task is not None and not self._scheduler_task.done():
            self._scheduler_task.cancel()
            try:
                await self._scheduler_task
            except asyncio.CancelledError:
                pass
        if wait and self.running_tasks:
            await asyncio.gather(*self.running_tasks.values(), return_exceptions=True)
        self.executor.shutdown(wait=wait)

    def _get_tool_semaphore(self, tool_name: str) -> asyncio.Semaphore:
        """Return the semaphore for *tool_name*, creating it on first use."""
        semaphore = self.tool_semaphores.get(tool_name)
        if semaphore is None:
            limit = self._TOOL_LIMITS.get(tool_name, self._TOOL_LIMITS["default"])
            semaphore = asyncio.Semaphore(limit)
            self.tool_semaphores[tool_name] = semaphore
        return semaphore

    def _call_tool_function(self, tool_name: str, parameters: dict) -> Any:
        """Run the tool registered as *tool_name* (called on a worker thread).

        Raises:
            ValueError: If *tool_name* is not a known tool.
        """
        # Built per call from bound attributes, so instance-level overrides of
        # the _execute_* methods are honored.
        tool_registry = {
            "database_tool": self._execute_database_tool,
            "api_tool": self._execute_api_tool,
            "file_tool": self._execute_file_tool,
        }
        tool_func = tool_registry.get(tool_name)
        if tool_func:
            return tool_func(parameters)
        raise ValueError(f"未知工具: {tool_name}")

    def _execute_database_tool(self, parameters: dict) -> dict:
        """Placeholder database tool: sleeps 1 s, then echoes the parameters."""
        time.sleep(1)
        return {"result": "database_query_result", "data": parameters}

    def _execute_api_tool(self, parameters: dict) -> dict:
        """Placeholder API tool: sleeps 0.5 s, then echoes the parameters."""
        time.sleep(0.5)
        return {"result": "api_call_result", "data": parameters}

    def _execute_file_tool(self, parameters: dict) -> dict:
        """Placeholder file tool: sleeps 0.3 s, then echoes the parameters."""
        time.sleep(0.3)
        return {"result": "file_operation_result", "data": parameters}
|