Skip to content

Commit a2e89a2

Browse files
hangfeicopybara-github
authored andcommitted
feat: Upgrade ADK stack to use App instead in addition to root_agent
The convention: - If some fields(like plugin) are defined both at root_agent and app, then a error will be raised. - app code should be located within agent.py. - an instance named app should be created PiperOrigin-RevId: 801084463
1 parent 98b0426 commit a2e89a2

File tree

11 files changed

+476
-36
lines changed

11 files changed

+476
-36
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
17+
from google.adk import Agent
18+
from google.adk.agents.base_agent import BaseAgent
19+
from google.adk.agents.callback_context import CallbackContext
20+
from google.adk.apps import App
21+
from google.adk.models.llm_request import LlmRequest
22+
from google.adk.plugins.base_plugin import BasePlugin
23+
from google.adk.tools.tool_context import ToolContext
24+
from google.genai import types
25+
26+
27+
def roll_die(sides: int, tool_context: ToolContext) -> int:
28+
"""Roll a die and return the rolled result.
29+
30+
Args:
31+
sides: The integer number of sides the die has.
32+
33+
Returns:
34+
An integer of the result of rolling the die.
35+
"""
36+
result = random.randint(1, sides)
37+
if not 'rolls' in tool_context.state:
38+
tool_context.state['rolls'] = []
39+
40+
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
41+
return result
42+
43+
44+
async def check_prime(nums: list[int]) -> str:
45+
"""Check if a given list of numbers are prime.
46+
47+
Args:
48+
nums: The list of numbers to check.
49+
50+
Returns:
51+
A str indicating which number is prime.
52+
"""
53+
primes = set()
54+
for number in nums:
55+
number = int(number)
56+
if number <= 1:
57+
continue
58+
is_prime = True
59+
for i in range(2, int(number**0.5) + 1):
60+
if number % i == 0:
61+
is_prime = False
62+
break
63+
if is_prime:
64+
primes.add(number)
65+
return (
66+
'No prime numbers found.'
67+
if not primes
68+
else f"{', '.join(str(num) for num in primes)} are prime numbers."
69+
)
70+
71+
72+
root_agent = Agent(
73+
model='gemini-2.0-flash',
74+
name='hello_world_agent',
75+
description=(
76+
'hello world agent that can roll a dice of 8 sides and check prime'
77+
' numbers.'
78+
),
79+
instruction="""
80+
You roll dice and answer questions about the outcome of the dice rolls.
81+
You can roll dice of different sizes.
82+
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
83+
It is ok to discuss previous dice roles, and comment on the dice rolls.
84+
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
85+
You should never roll a die on your own.
86+
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
87+
You should not check prime numbers before calling the tool.
88+
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
89+
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
90+
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
91+
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
92+
3. When you respond, you must include the roll_die result from step 1.
93+
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
94+
You should not rely on the previous history on prime results.
95+
""",
96+
tools=[
97+
roll_die,
98+
check_prime,
99+
],
100+
# planner=BuiltInPlanner(
101+
# thinking_config=types.ThinkingConfig(
102+
# include_thoughts=True,
103+
# ),
104+
# ),
105+
generate_content_config=types.GenerateContentConfig(
106+
safety_settings=[
107+
types.SafetySetting( # avoid false alarm about rolling dice.
108+
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
109+
threshold=types.HarmBlockThreshold.OFF,
110+
),
111+
]
112+
),
113+
)
114+
115+
116+
class CountInvocationPlugin(BasePlugin):
117+
"""A custom plugin that counts agent and tool invocations."""
118+
119+
def __init__(self) -> None:
120+
"""Initialize the plugin with counters."""
121+
super().__init__(name='count_invocation')
122+
self.agent_count: int = 0
123+
self.tool_count: int = 0
124+
self.llm_request_count: int = 0
125+
126+
async def before_agent_callback(
127+
self, *, agent: BaseAgent, callback_context: CallbackContext
128+
) -> None:
129+
"""Count agent runs."""
130+
self.agent_count += 1
131+
print(f'[Plugin] Agent run count: {self.agent_count}')
132+
133+
async def before_model_callback(
134+
self, *, callback_context: CallbackContext, llm_request: LlmRequest
135+
) -> None:
136+
"""Count LLM requests."""
137+
self.llm_request_count += 1
138+
print(f'[Plugin] LLM request count: {self.llm_request_count}')
139+
140+
141+
app = App(
142+
name='hello_world_app',
143+
root_agent=root_agent,
144+
plugins=[CountInvocationPlugin()],
145+
)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import time
17+
18+
import agent
19+
from dotenv import load_dotenv
20+
from google.adk.agents.run_config import RunConfig
21+
from google.adk.cli.utils import logs
22+
from google.adk.runners import InMemoryRunner
23+
from google.adk.sessions.session import Session
24+
from google.genai import types
25+
26+
load_dotenv(override=True)
27+
logs.log_to_tmp_folder()
28+
29+
30+
async def main():
31+
app_name = 'my_app'
32+
user_id_1 = 'user1'
33+
runner = InMemoryRunner(
34+
agent=agent.root_agent,
35+
app_name=app_name,
36+
)
37+
session_11 = await runner.session_service.create_session(
38+
app_name=app_name, user_id=user_id_1
39+
)
40+
41+
async def run_prompt(session: Session, new_message: str):
42+
content = types.Content(
43+
role='user', parts=[types.Part.from_text(text=new_message)]
44+
)
45+
print('** User says:', content.model_dump(exclude_none=True))
46+
async for event in runner.run_async(
47+
user_id=user_id_1,
48+
session_id=session.id,
49+
new_message=content,
50+
):
51+
if event.content.parts and event.content.parts[0].text:
52+
print(f'** {event.author}: {event.content.parts[0].text}')
53+
54+
async def run_prompt_bytes(session: Session, new_message: str):
55+
content = types.Content(
56+
role='user',
57+
parts=[
58+
types.Part.from_bytes(
59+
data=str.encode(new_message), mime_type='text/plain'
60+
)
61+
],
62+
)
63+
print('** User says:', content.model_dump(exclude_none=True))
64+
async for event in runner.run_async(
65+
user_id=user_id_1,
66+
session_id=session.id,
67+
new_message=content,
68+
run_config=RunConfig(save_input_blobs_as_artifacts=True),
69+
):
70+
if event.content.parts and event.content.parts[0].text:
71+
print(f'** {event.author}: {event.content.parts[0].text}')
72+
73+
async def check_rolls_in_state(rolls_size: int):
74+
session = await runner.session_service.get_session(
75+
app_name=app_name, user_id=user_id_1, session_id=session_11.id
76+
)
77+
assert len(session.state['rolls']) == rolls_size
78+
for roll in session.state['rolls']:
79+
assert roll > 0 and roll <= 100
80+
81+
start_time = time.time()
82+
print('Start time:', start_time)
83+
print('------------------------------------')
84+
await run_prompt(session_11, 'Hi')
85+
await run_prompt(session_11, 'Roll a die with 100 sides')
86+
await check_rolls_in_state(1)
87+
await run_prompt(session_11, 'Roll a die again with 100 sides.')
88+
await check_rolls_in_state(2)
89+
await run_prompt(session_11, 'What numbers did I got?')
90+
await run_prompt_bytes(session_11, 'Hi bytes')
91+
print(
92+
await runner.artifact_service.list_artifact_keys(
93+
app_name=app_name, user_id=user_id_1, session_id=session_11.id
94+
)
95+
)
96+
end_time = time.time()
97+
print('------------------------------------')
98+
print('End time:', end_time)
99+
print('Total time:', end_time - start_time)
100+
101+
102+
if __name__ == '__main__':
103+
asyncio.run(main())

src/google/adk/apps/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .app import App
16+
17+
__all__ = [
18+
'App',
19+
]

src/google/adk/apps/app.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from abc import ABC
17+
from typing import Optional
18+
19+
from pydantic import BaseModel
20+
from pydantic import ConfigDict
21+
from pydantic import Field
22+
23+
from ..agents.base_agent import BaseAgent
24+
from ..plugins.base_plugin import BasePlugin
25+
from ..utils.feature_decorator import experimental
26+
27+
28+
@experimental
29+
class App(BaseModel):
30+
"""Represents an LLM-backed agentic application.
31+
32+
An `App` is the top-level container for an agentic system powered by LLMs.
33+
It manages a root agent (`root_agent`), which serves as the root of an agent
34+
tree, enabling coordination and communication across all agents in the
35+
hierarchy.
36+
The `plugins` are application-wide components that provide shared capabilities
37+
and services to the entire system.
38+
"""
39+
40+
model_config = ConfigDict(
41+
arbitrary_types_allowed=True,
42+
extra="forbid",
43+
)
44+
45+
name: str
46+
"""The name of the application."""
47+
48+
root_agent: BaseAgent
49+
"""The root agent in the application. One app can only have one root agent."""
50+
51+
plugins: list[BasePlugin] = Field(default_factory=list)
52+
"""The plugins in the application."""

src/google/adk/cli/adk_web_server.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@
5050
from watchdog.observers import Observer
5151

5252
from . import agent_graph
53+
from ..agents.base_agent import BaseAgent
5354
from ..agents.live_request_queue import LiveRequest
5455
from ..agents.live_request_queue import LiveRequestQueue
5556
from ..agents.run_config import RunConfig
5657
from ..agents.run_config import StreamingMode
58+
from ..apps import App
5759
from ..artifacts.base_artifact_service import BaseArtifactService
5860
from ..auth.credential_service.base_credential_service import BaseCredentialService
5961
from ..errors.not_found_error import NotFoundError
@@ -305,10 +307,17 @@ async def get_runner_async(self, app_name: str) -> Runner:
305307
envs.load_dotenv_for_agent(os.path.basename(app_name), self.agents_dir)
306308
if app_name in self.runner_dict:
307309
return self.runner_dict[app_name]
308-
root_agent = self.agent_loader.load_agent(app_name)
310+
agent_or_app = self.agent_loader.load_agent(app_name)
311+
agentic_app = None
312+
if isinstance(agent_or_app, BaseAgent):
313+
agentic_app = App(
314+
name=app_name,
315+
root_agent=agent_or_app,
316+
)
317+
else:
318+
agentic_app = agent_or_app
309319
runner = Runner(
310-
app_name=app_name,
311-
agent=root_agent,
320+
app=agentic_app,
312321
artifact_service=self.artifact_service,
313322
session_service=self.session_service,
314323
memory_service=self.memory_service,
@@ -597,9 +606,10 @@ async def add_session_to_eval_set(
597606
invocations = evals.convert_session_to_eval_invocations(session)
598607

599608
# Populate the session with initial session state.
600-
initial_session_state = create_empty_state(
601-
self.agent_loader.load_agent(app_name)
602-
)
609+
agent_or_app = self.agent_loader.load_agent(app_name)
610+
if isinstance(agent_or_app, App):
611+
agent_or_app = agent_or_app.root_agent
612+
initial_session_state = create_empty_state(agent_or_app)
603613

604614
new_eval_case = EvalCase(
605615
eval_id=req.eval_id,

0 commit comments

Comments
 (0)