Skip to content

Commit a17bcbb

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Add a tool confirmation flow that can guard tool execution with explicit confirmation and custom input
The existing `LongRunningTool` does not define a programmatic way to provide & validate structured input, also it relies on LLM to reason and parse the user's response. For a quick start, annotate the function with `FunctionTool(my_function, require_confirmation=True)`. A more advanced flow is shown in the `human_tool_confirmation` sample. The new flow is similar to the existing Auth flow: - User request a tool confirmation by calling `tool_context.request_confirmation()` in the tool or `before_tool_callback`, or just using the `require_confirmation` shortcut in FunctionTool. - User can provide custom validation logic before tool call proceeds. - ADK creates corresponding RequestConfirmation FunctionCall Event to ask user for confirmation - User needs to provide the expected tool confirmation to a RequestConfirmation FunctionResponse Event. - ADK then checks the response and continues the tool call. PiperOrigin-RevId: 801019917
1 parent 3ed9097 commit a17bcbb

File tree

15 files changed

+889
-14
lines changed

15 files changed

+889
-14
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.adk import Agent
16+
from google.adk.tools.function_tool import FunctionTool
17+
from google.adk.tools.tool_confirmation import ToolConfirmation
18+
from google.adk.tools.tool_context import ToolContext
19+
from google.genai import types
20+
21+
22+
def reimburse(amount: int, tool_context: ToolContext) -> str:
23+
"""Reimburse the employee for the given amount."""
24+
return {'status': 'ok'}
25+
26+
27+
def request_time_off(days: int, tool_context: ToolContext):
28+
"""Request day off for the employee."""
29+
if days <= 0:
30+
return {'status': 'Invalid days to request.'}
31+
32+
if days <= 2:
33+
return {
34+
'status': 'ok',
35+
'approved_days': days,
36+
}
37+
38+
tool_confirmation = tool_context.tool_confirmation
39+
if not tool_confirmation:
40+
tool_context.request_confirmation(
41+
hint=(
42+
'Please approve or reject the tool call request_time_off() by'
43+
' responding with a FunctionResponse with an expected'
44+
' ToolConfirmation payload.'
45+
),
46+
payload={
47+
'approved_days': 0,
48+
},
49+
)
50+
return {'status': 'Manager approval is required.'}
51+
52+
approved_days = tool_confirmation.payload['approved_days']
53+
approved_days = min(approved_days, days)
54+
if approved_days == 0:
55+
return {'status': 'The time off request is rejected.', 'approved_days': 0}
56+
return {
57+
'status': 'ok',
58+
'approved_days': approved_days,
59+
}
60+
61+
62+
root_agent = Agent(
63+
model='gemini-2.5-flash',
64+
name='time_off_agent',
65+
instruction="""
66+
You are a helpful assistant that can help employees with reimbursement and time off requests.
67+
- Use the `reimburse` tool for reimbursement requests.
68+
- Use the `request_time_off` tool for time off requests.
69+
- Prioritize using tools to fulfill the user's request.
70+
- Always respond to the user with the tool results.
71+
""",
72+
tools=[
73+
# Set require_confirmation to True to require user confirmation for the
74+
# tool call. This is an easier way to get user confirmation if the tool
75+
# just need a boolean confirmation.
76+
FunctionTool(reimburse, require_confirmation=True),
77+
request_time_off,
78+
],
79+
generate_content_config=types.GenerateContentConfig(temperature=0.1),
80+
)

src/google/adk/events/event_actions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Any
1718
from typing import Optional
1819

1920
from pydantic import alias_generators
@@ -22,6 +23,7 @@
2223
from pydantic import Field
2324

2425
from ..auth.auth_tool import AuthConfig
26+
from ..tools.tool_confirmation import ToolConfirmation
2527

2628

2729
class EventActions(BaseModel):
@@ -64,3 +66,9 @@ class EventActions(BaseModel):
6466
identify the function call.
6567
- Values: The requested auth config.
6668
"""
69+
70+
requested_tool_confirmations: dict[str, ToolConfirmation] = Field(
71+
default_factory=dict
72+
)
73+
"""A dict of tool confirmation requested by this event, keyed by
74+
function call id."""

src/google/adk/flows/llm_flows/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
from . import functions
1919
from . import identity
2020
from . import instructions
21+
from . import request_confirmation

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,12 @@ async def _postprocess_handle_function_calls_async(
638638
if auth_event:
639639
yield auth_event
640640

641+
tool_confirmation_event = functions.generate_request_confirmation_event(
642+
invocation_context, function_call_event, function_response_event
643+
)
644+
if tool_confirmation_event:
645+
yield tool_confirmation_event
646+
641647
# Always yield the function response event first
642648
yield function_response_event
643649

src/google/adk/flows/llm_flows/contents.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ...models.llm_request import LlmRequest
2828
from ._base_llm_processor import BaseLlmRequestProcessor
2929
from .functions import remove_client_function_call_id
30+
from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
3031
from .functions import REQUEST_EUC_FUNCTION_CALL_NAME
3132

3233

@@ -238,6 +239,9 @@ def _get_contents(
238239
if _is_auth_event(event):
239240
# Skip auth events.
240241
continue
242+
if _is_request_confirmation_event(event):
243+
# Skip request confirmation events.
244+
continue
241245
filtered_events.append(
242246
_convert_foreign_event(event)
243247
if _is_other_agent_reply(agent_name, event)
@@ -431,18 +435,23 @@ def _is_event_belongs_to_branch(
431435
return invocation_branch.startswith(event.branch)
432436

433437

434-
def _is_auth_event(event: Event) -> bool:
435-
if not event.content.parts:
438+
def _is_function_call_event(event: Event, function_name: str) -> bool:
439+
"""Checks if an event is a function call/response for a given function name."""
440+
if not event.content or not event.content.parts:
436441
return False
437442
for part in event.content.parts:
438-
if (
439-
part.function_call
440-
and part.function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
441-
):
443+
if part.function_call and part.function_call.name == function_name:
442444
return True
443-
if (
444-
part.function_response
445-
and part.function_response.name == REQUEST_EUC_FUNCTION_CALL_NAME
446-
):
445+
if part.function_response and part.function_response.name == function_name:
447446
return True
448447
return False
448+
449+
450+
def _is_auth_event(event: Event) -> bool:
451+
"""Checks if the event is an authentication event."""
452+
return _is_function_call_event(event, REQUEST_EUC_FUNCTION_CALL_NAME)
453+
454+
455+
def _is_request_confirmation_event(event: Event) -> bool:
456+
"""Checks if the event is a request confirmation event."""
457+
return _is_function_call_event(event, REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)

src/google/adk/flows/llm_flows/functions.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from ...telemetry import trace_tool_call
4040
from ...telemetry import tracer
4141
from ...tools.base_tool import BaseTool
42+
from ...tools.tool_confirmation import ToolConfirmation
4243
from ...tools.tool_context import ToolContext
4344
from ...utils.context_utils import Aclosing
4445

@@ -47,6 +48,7 @@
4748

4849
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
4950
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
51+
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation'
5052

5153
logger = logging.getLogger('google_adk.' + __name__)
5254

@@ -130,11 +132,76 @@ def generate_auth_event(
130132
)
131133

132134

135+
def generate_request_confirmation_event(
136+
invocation_context: InvocationContext,
137+
function_call_event: Event,
138+
function_response_event: Event,
139+
) -> Optional[Event]:
140+
"""Generates a request confirmation event from a function response event."""
141+
if not function_response_event.actions.requested_tool_confirmations:
142+
return None
143+
parts = []
144+
long_running_tool_ids = set()
145+
function_calls = function_call_event.get_function_calls()
146+
for (
147+
function_call_id,
148+
tool_confirmation,
149+
) in function_response_event.actions.requested_tool_confirmations.items():
150+
original_function_call = next(
151+
(fc for fc in function_calls if fc.id == function_call_id), None
152+
)
153+
if not original_function_call:
154+
continue
155+
request_confirmation_function_call = types.FunctionCall(
156+
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
157+
args={
158+
'originalFunctionCall': original_function_call.model_dump(
159+
exclude_none=True, by_alias=True
160+
),
161+
'toolConfirmation': tool_confirmation.model_dump(
162+
by_alias=True, exclude_none=True
163+
),
164+
},
165+
)
166+
request_confirmation_function_call.id = generate_client_function_call_id()
167+
long_running_tool_ids.add(request_confirmation_function_call.id)
168+
parts.append(types.Part(function_call=request_confirmation_function_call))
169+
170+
return Event(
171+
invocation_id=invocation_context.invocation_id,
172+
author=invocation_context.agent.name,
173+
branch=invocation_context.branch,
174+
content=types.Content(
175+
parts=parts, role=function_response_event.content.role
176+
),
177+
long_running_tool_ids=long_running_tool_ids,
178+
)
179+
180+
133181
async def handle_function_calls_async(
134182
invocation_context: InvocationContext,
135183
function_call_event: Event,
136184
tools_dict: dict[str, BaseTool],
137185
filters: Optional[set[str]] = None,
186+
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
187+
) -> Optional[Event]:
188+
"""Calls the functions and returns the function response event."""
189+
function_calls = function_call_event.get_function_calls()
190+
return await handle_function_call_list_async(
191+
invocation_context,
192+
function_calls,
193+
tools_dict,
194+
filters,
195+
tool_confirmation_dict,
196+
)
197+
198+
199+
async def handle_function_call_list_async(
200+
invocation_context: InvocationContext,
201+
function_calls: list[types.FunctionCall],
202+
tools_dict: dict[str, BaseTool],
203+
filters: Optional[set[str]] = None,
204+
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
138205
) -> Optional[Event]:
139206
"""Calls the functions and returns the function response event."""
140207
from ...agents.llm_agent import LlmAgent
@@ -143,8 +210,6 @@ async def handle_function_calls_async(
143210
if not isinstance(agent, LlmAgent):
144211
return None
145212

146-
function_calls = function_call_event.get_function_calls()
147-
148213
# Filter function calls
149214
filtered_calls = [
150215
fc for fc in function_calls if not filters or fc.id in filters
@@ -161,6 +226,9 @@ async def handle_function_calls_async(
161226
function_call,
162227
tools_dict,
163228
agent,
229+
tool_confirmation_dict[function_call.id]
230+
if tool_confirmation_dict
231+
else None,
164232
)
165233
)
166234
for function_call in filtered_calls
@@ -198,12 +266,14 @@ async def _execute_single_function_call_async(
198266
function_call: types.FunctionCall,
199267
tools_dict: dict[str, BaseTool],
200268
agent: LlmAgent,
269+
tool_confirmation: Optional[ToolConfirmation] = None,
201270
) -> Optional[Event]:
202271
"""Execute a single function call with thread safety for state modifications."""
203272
tool, tool_context = _get_tool_and_context(
204273
invocation_context,
205274
function_call,
206275
tools_dict,
276+
tool_confirmation,
207277
)
208278

209279
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
@@ -567,6 +637,7 @@ def _get_tool_and_context(
567637
invocation_context: InvocationContext,
568638
function_call: types.FunctionCall,
569639
tools_dict: dict[str, BaseTool],
640+
tool_confirmation: Optional[ToolConfirmation] = None,
570641
):
571642
if function_call.name not in tools_dict:
572643
raise ValueError(
@@ -576,6 +647,7 @@ def _get_tool_and_context(
576647
tool_context = ToolContext(
577648
invocation_context=invocation_context,
578649
function_call_id=function_call.id,
650+
tool_confirmation=tool_confirmation,
579651
)
580652

581653
tool = tools_dict[function_call.name]

0 commit comments

Comments
 (0)