18
18
from .util ._types import MaybeAwaitable
19
19
20
20
if TYPE_CHECKING :
21
- from .agent import Agent
21
+ from .agent import Agent , AgentBase
22
22
23
23
24
24
# The handoff input type is the type of data passed when the agent is called via a handoff.
25
25
THandoffInput = TypeVar ("THandoffInput" , default = Any )
26
26
27
+ # The agent type that the handoff returns
28
+ TAgent = TypeVar ("TAgent" , bound = "AgentBase[Any]" , default = "Agent[Any]" )
29
+
27
30
OnHandoffWithInput = Callable [[RunContextWrapper [Any ], THandoffInput ], Any ]
28
31
OnHandoffWithoutInput = Callable [[RunContextWrapper [Any ]], Any ]
29
32
@@ -52,7 +55,7 @@ class HandoffInputData:
52
55
53
56
54
57
@dataclass
55
- class Handoff (Generic [TContext ]):
58
+ class Handoff (Generic [TContext , TAgent ]):
56
59
"""A handoff is when an agent delegates a task to another agent.
57
60
For example, in a customer support scenario you might have a "triage agent" that determines
58
61
which agent should handle the user's request, and sub-agents that specialize in different
@@ -69,7 +72,7 @@ class Handoff(Generic[TContext]):
69
72
"""The JSON schema for the handoff input. Can be empty if the handoff does not take an input.
70
73
"""
71
74
72
- on_invoke_handoff : Callable [[RunContextWrapper [Any ], str ], Awaitable [Agent [ TContext ] ]]
75
+ on_invoke_handoff : Callable [[RunContextWrapper [Any ], str ], Awaitable [TAgent ]]
73
76
"""The function that invokes the handoff. The parameters passed are:
74
77
1. The handoff run context
75
78
2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty.
@@ -100,20 +103,22 @@ class Handoff(Generic[TContext]):
100
103
True, as it increases the likelihood of correct JSON input.
101
104
"""
102
105
103
- is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True
106
+ is_enabled : bool | Callable [[RunContextWrapper [Any ], AgentBase [Any ]], MaybeAwaitable [bool ]] = (
107
+ True
108
+ )
104
109
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
105
110
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
106
111
a handoff based on your context/state."""
107
112
108
- def get_transfer_message (self , agent : Agent [Any ]) -> str :
113
+ def get_transfer_message (self , agent : AgentBase [Any ]) -> str :
109
114
return json .dumps ({"assistant" : agent .name })
110
115
111
116
@classmethod
112
- def default_tool_name (cls , agent : Agent [Any ]) -> str :
117
+ def default_tool_name (cls , agent : AgentBase [Any ]) -> str :
113
118
return _transforms .transform_string_function_style (f"transfer_to_{ agent .name } " )
114
119
115
120
@classmethod
116
- def default_tool_description (cls , agent : Agent [Any ]) -> str :
121
+ def default_tool_description (cls , agent : AgentBase [Any ]) -> str :
117
122
return (
118
123
f"Handoff to the { agent .name } agent to handle the request. "
119
124
f"{ agent .handoff_description or '' } "
@@ -128,7 +133,7 @@ def handoff(
128
133
tool_description_override : str | None = None ,
129
134
input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
130
135
is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
131
- ) -> Handoff [TContext ]: ...
136
+ ) -> Handoff [TContext , Agent [ TContext ] ]: ...
132
137
133
138
134
139
@overload
@@ -141,7 +146,7 @@ def handoff(
141
146
tool_name_override : str | None = None ,
142
147
input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
143
148
is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
144
- ) -> Handoff [TContext ]: ...
149
+ ) -> Handoff [TContext , Agent [ TContext ] ]: ...
145
150
146
151
147
152
@overload
@@ -153,7 +158,7 @@ def handoff(
153
158
tool_name_override : str | None = None ,
154
159
input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
155
160
is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
156
- ) -> Handoff [TContext ]: ...
161
+ ) -> Handoff [TContext , Agent [ TContext ] ]: ...
157
162
158
163
159
164
def handoff (
@@ -163,8 +168,9 @@ def handoff(
163
168
on_handoff : OnHandoffWithInput [THandoffInput ] | OnHandoffWithoutInput | None = None ,
164
169
input_type : type [THandoffInput ] | None = None ,
165
170
input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
166
- is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
167
- ) -> Handoff [TContext ]:
171
+ is_enabled : bool
172
+ | Callable [[RunContextWrapper [Any ], Agent [TContext ]], MaybeAwaitable [bool ]] = True ,
173
+ ) -> Handoff [TContext , Agent [TContext ]]:
168
174
"""Create a handoff from an agent.
169
175
170
176
Args:
@@ -202,7 +208,7 @@ def handoff(
202
208
203
209
async def _invoke_handoff (
204
210
ctx : RunContextWrapper [Any ], input_json : str | None = None
205
- ) -> Agent [Any ]:
211
+ ) -> Agent [TContext ]:
206
212
if input_type is not None and type_adapter is not None :
207
213
if input_json is None :
208
214
_error_tracing .attach_error_to_current_span (
@@ -239,12 +245,24 @@ async def _invoke_handoff(
239
245
# If there is a need, we can make this configurable in the future
240
246
input_json_schema = ensure_strict_json_schema (input_json_schema )
241
247
248
+ async def _is_enabled (ctx : RunContextWrapper [Any ], agent_base : AgentBase [Any ]) -> bool :
249
+ from .agent import Agent
250
+
251
+ assert callable (is_enabled ), "is_enabled must be non-null here"
252
+ assert isinstance (agent_base , Agent ), "Can't handoff to a non-Agent"
253
+ result = is_enabled (ctx , agent_base )
254
+
255
+ if inspect .isawaitable (result ):
256
+ return await result
257
+
258
+ return result
259
+
242
260
return Handoff (
243
261
tool_name = tool_name ,
244
262
tool_description = tool_description ,
245
263
input_json_schema = input_json_schema ,
246
264
on_invoke_handoff = _invoke_handoff ,
247
265
input_filter = input_filter ,
248
266
agent_name = agent .name ,
249
- is_enabled = is_enabled ,
267
+ is_enabled = _is_enabled if callable ( is_enabled ) else is_enabled ,
250
268
)
0 commit comments