
import asyncio
import time
import weakref
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional
@dataclass
class ToolTask:
    """Description of one tool invocation queued for the scheduler.

    The first four fields are required. The remaining fields have defaults
    so a task can be built from just its identity and parameters; positional
    construction with all seven values keeps working unchanged.
    """

    # Identifier of this particular invocation.
    task_id: str
    # Name of the tool to invoke.
    tool_name: str
    # Identifier of the agent that requested the call.
    agent_id: str
    # Keyword arguments to hand to the tool.
    parameters: dict
    # Names of tools (for the same agent) whose results must exist first.
    # A factory is used so instances never share one list object.
    dependencies: List[str] = field(default_factory=list)
    # Optional callable associated with the task; None when not supplied.
    callback: Optional[Callable] = None
    # Creation timestamp in seconds since the epoch.
    created_time: float = field(default_factory=time.time)
class LockFreeToolScheduler:
    """Asyncio scheduler that runs tool calls on a thread pool.

    Submitted tasks go onto ``task_queue``. A background coroutine takes
    them off the queue and starts each one as its own asyncio task once all
    of its dependencies have a result in ``completed_results``. Concurrency
    per tool is capped by an ``asyncio.Semaphore``. All bookkeeping is
    touched only from the event loop, which is why no locks are used.

    Results and failures are keyed by ``"{agent_id}_{tool_name}"``, so a
    later run of the same tool for the same agent replaces the earlier one.

    The instance may be created before an event loop is running; in that
    case the scheduling loop is started by the first ``submit_tool_call``.
    NOTE(review): on Python < 3.10 ``asyncio.Queue`` binds to the loop that
    is current at construction time, so there the instance should be
    created inside the loop that will use it.
    """

    # Maximum concurrent executions per tool; unknown tools use "default".
    _TOOL_CONCURRENCY_LIMITS = {
        "database_tool": 3,
        "api_tool": 5,
        "file_tool": 2,
        "default": 1,
    }

    # Tool name -> name of the method that implements it.
    _TOOL_HANDLERS = {
        "database_tool": "_execute_database_tool",
        "api_tool": "_execute_api_tool",
        "file_tool": "_execute_file_tool",
    }

    def __init__(self, max_workers: int = 10):
        """Create the scheduler.

        Args:
            max_workers: Size of the thread pool that runs tool functions.
        """
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.task_queue = asyncio.Queue()
        # task_id -> asyncio.Task for executions currently in flight.
        self.running_tasks = {}
        # "{agent_id}_{tool_name}" -> value returned by the tool.
        self.completed_results = {}
        # "{agent_id}_{tool_name}" -> exception from the latest failed run.
        self.failed_results = {}
        self.tool_pools = {}
        self.tool_semaphores = {}
        # Tasks parked until some other task finishes.
        self._deferred_tasks = []
        self._submitted_count = 0
        self._scheduler_task = None
        # Starts the loop now if possible, otherwise on first submit.
        self._ensure_scheduler_started()

    @staticmethod
    def _result_key(agent_id: str, tool_name: str) -> str:
        """Return the key under which a tool's result or failure is stored."""
        return f"{agent_id}_{tool_name}"

    def _ensure_scheduler_started(self) -> None:
        """Start the scheduling loop if a loop is running and it is not active."""
        current = self._scheduler_task
        if current is not None and not current.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop yet; the first submit_tool_call retries.
            return
        # Keep the reference: the event loop only holds tasks weakly.
        self._scheduler_task = loop.create_task(self._scheduling_loop())

    async def submit_tool_call(
        self,
        tool_name: str,
        agent_id: str,
        dependencies: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """Queue a tool call and return its task id.

        Args:
            tool_name: Name of the tool to run.
            agent_id: Identifier of the requesting agent.
            dependencies: Names of tools, for the same agent, whose results
                must be available first. Those results are merged into the
                parameters passed to the tool, keyed by tool name.
            **kwargs: Parameters forwarded to the tool.

        Returns:
            The id assigned to the queued task.
        """
        self._ensure_scheduler_started()
        now = time.time()
        # The counter keeps ids unique for submissions in the same millisecond.
        self._submitted_count += 1
        task_id = (
            f"{agent_id}_{tool_name}_{int(now * 1000)}_{self._submitted_count}"
        )
        task = ToolTask(
            task_id=task_id,
            tool_name=tool_name,
            agent_id=agent_id,
            parameters=kwargs,
            dependencies=list(dependencies) if dependencies else [],
            callback=None,
            created_time=now,
        )
        await self.task_queue.put(task)
        return task_id

    async def shutdown(self) -> None:
        """Stop the scheduling loop, cancel in-flight tasks, release the pool.

        Tool functions already running in a worker thread cannot be
        interrupted; they run to completion and their results are dropped.
        """
        pending = list(self.running_tasks.values())
        if self._scheduler_task is not None:
            pending.append(self._scheduler_task)
        for job in pending:
            job.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        self._scheduler_task = None
        self.executor.shutdown(wait=False)

    async def _scheduling_loop(self):
        """Main loop: take tasks off the queue and start, park, or fail them."""
        while True:
            try:
                task = await self.task_queue.get()
                failed_dep = self._first_failed_dependency(task)
                if failed_dep is not None:
                    # The dependency cannot produce a result; fail now
                    # instead of waiting forever.
                    self._record_failure(
                        task, RuntimeError(f"依赖任务失败: {failed_dep}")
                    )
                    # Tasks waiting on this one must see the failure.
                    self._requeue_deferred()
                elif self._dependencies_satisfied(task):
                    self._start_task(task)
                else:
                    # Re-examined when any task finishes.
                    self._deferred_tasks.append(task)
            except asyncio.CancelledError:
                # Needed on Python 3.7, where this is an Exception subclass.
                raise
            except Exception as e:
                print(f"调度循环异常: {e}")

    def _dependencies_satisfied(self, task: "ToolTask") -> bool:
        """Return True when every dependency of ``task`` has a stored result."""
        return all(
            self._result_key(task.agent_id, dep) in self.completed_results
            for dep in task.dependencies
        )

    def _first_failed_dependency(self, task: "ToolTask") -> Optional[str]:
        """Return the key of a dependency that failed and has no result."""
        for dep in task.dependencies:
            key = self._result_key(task.agent_id, dep)
            if key in self.failed_results and key not in self.completed_results:
                return key
        return None

    def _start_task(self, task: "ToolTask") -> None:
        """Run ``task`` as its own asyncio task so tasks can overlap."""
        # A fresh run supersedes an earlier failure of the same tool.
        stale_key = self._result_key(task.agent_id, task.tool_name)
        self.failed_results.pop(stale_key, None)
        runner = asyncio.create_task(self._execute_task(task))
        self.running_tasks[task.task_id] = runner
        runner.add_done_callback(
            lambda _done, task_id=task.task_id: self._on_task_finished(task_id)
        )

    def _on_task_finished(self, task_id: str) -> None:
        """Forget a finished task and give parked tasks another look."""
        self.running_tasks.pop(task_id, None)
        self._requeue_deferred()

    def _requeue_deferred(self) -> None:
        """Move every parked task back onto the queue."""
        parked, self._deferred_tasks = self._deferred_tasks, []
        for task in parked:
            # The queue is unbounded, so put_nowait cannot raise QueueFull.
            self.task_queue.put_nowait(task)

    def _record_failure(self, task: "ToolTask", error: BaseException) -> None:
        """Store and report the failure of ``task``."""
        key = self._result_key(task.agent_id, task.tool_name)
        self.failed_results[key] = error
        print(f"任务执行失败: {task.task_id}, 错误: {error}")

    async def _execute_task(self, task: "ToolTask"):
        """Run one task in the thread pool and store its outcome.

        Dependency results are merged over the task's own parameters, keyed
        by tool name. Errors raised by the tool are recorded in
        ``failed_results`` and printed rather than propagated.
        """
        async with self._get_tool_semaphore(task.tool_name):
            try:
                dep_results = {
                    dep: self.completed_results[
                        self._result_key(task.agent_id, dep)
                    ]
                    for dep in task.dependencies
                }
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    self.executor,
                    self._call_tool_function,
                    task.tool_name,
                    {**task.parameters, **dep_results},
                )
            except asyncio.CancelledError:
                raise
            except Exception as e:
                self._record_failure(task, e)
            else:
                key = self._result_key(task.agent_id, task.tool_name)
                self.completed_results[key] = result
                self.failed_results.pop(key, None)
                print(f"任务完成: {task.task_id}")

    def _get_tool_semaphore(self, tool_name: str) -> asyncio.Semaphore:
        """Return the semaphore for ``tool_name``, creating it on first use."""
        semaphore = self.tool_semaphores.get(tool_name)
        if semaphore is None:
            limits = self._TOOL_CONCURRENCY_LIMITS
            semaphore = asyncio.Semaphore(
                limits.get(tool_name, limits["default"])
            )
            self.tool_semaphores[tool_name] = semaphore
        return semaphore

    def _call_tool_function(self, tool_name: str, parameters: dict) -> Any:
        """Dispatch to the implementation of ``tool_name``.

        Raises:
            ValueError: If ``tool_name`` is not a known tool.
        """
        handler_name = self._TOOL_HANDLERS.get(tool_name)
        if handler_name is None:
            raise ValueError(f"未知工具: {tool_name}")
        return getattr(self, handler_name)(parameters)

    def _execute_database_tool(self, parameters: dict) -> dict:
        """Database tool: sleeps 1 s, then echoes the parameters."""
        time.sleep(1)
        return {"result": "database_query_result", "data": parameters}

    def _execute_api_tool(self, parameters: dict) -> dict:
        """API tool: sleeps 0.5 s, then echoes the parameters."""
        time.sleep(0.5)
        return {"result": "api_call_result", "data": parameters}

    def _execute_file_tool(self, parameters: dict) -> dict:
        """File tool: sleeps 0.3 s, then echoes the parameters."""
        time.sleep(0.3)
        return {"result": "file_operation_result", "data": parameters}
|