diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 4407c5515..cb719b2c9 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -77,7 +77,7 @@ Monitor, ) from .remote_executors import DockerExecutor, E2BExecutor, WasmExecutor -from .tools import BaseTool, Tool, validate_tool_arguments +from .tools import BaseTool, Tool, ToolRegistry, validate_tool_arguments from .utils import ( AgentError, AgentExecutionError, @@ -361,7 +361,7 @@ def _setup_tools(self, tools, add_base_tools): assert all(isinstance(tool, BaseTool) for tool in tools), ( "All elements must be instance of BaseTool (or a subclass)" ) - self.tools = {tool.name: tool for tool in tools} + self.tools = ToolRegistry(tools) if add_base_tools: self.tools.update( { diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 95b6b45b5..b1acc8266 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -103,6 +103,21 @@ def __call__(self, *args, **kwargs) -> Any: pass +class ToolRegistry(dict): + """Registry for tools that provides dict-like access.""" + + def __init__(self, tools: list[BaseTool]): + super().__init__({tool.name: tool for tool in tools}) + + def __setitem__(self, key: str, value: BaseTool): + """Add or update a tool.""" + super().__setitem__(key, value) + + def __repr__(self) -> str: + """String representation of the registry.""" + return f"ToolRegistry({list(self.keys())})" + + class Tool(BaseTool): """ A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the