Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Omit state from the Assist LLM prompts #141034

Merged
merged 3 commits into from
Mar 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 0 additions & 41 deletions homeassistant/components/mcp_server/llm_api.py

This file was deleted.

10 changes: 4 additions & 6 deletions homeassistant/components/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from homeassistant.helpers import llm

from .const import STATELESS_LLM_API
from .llm_api import StatelessAssistAPI

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,15 +49,14 @@ async def create_server(
A Model Context Protocol Server object is associated with a single session.
The MCP SDK handles the details of the protocol.
"""
if llm_api_id == STATELESS_LLM_API:
llm_api_id = llm.LLM_API_ASSIST

server = Server("home-assistant")

async def get_api_instance() -> llm.APIInstance:
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
api = StatelessAssistAPI(hass)
return await api.async_get_api_instance(llm_context)

"""Get the LLM API selected."""
# Backwards compatibility with old MCP Server config
return await llm.async_get_api(hass, llm_api_id, llm_context)

@server.list_prompts() # type: ignore[no-untyped-call, misc]
Expand Down
30 changes: 18 additions & 12 deletions homeassistant/helpers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
"""Return the instance of the API."""
if llm_context.assistant:
exposed_entities: dict | None = _get_exposed_entities(
self.hass, llm_context.assistant
self.hass, llm_context.assistant, include_state=False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I correct in my understanding that the change affects only MCP server right now, since it looks like it's the only one using this flag?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, all are changed.
(1) All prompts now have no state
(2) MCP now uses the assist API only
(3) State is provided via new get_home_state tools via #140971

)
else:
exposed_entities = None
Expand Down Expand Up @@ -463,7 +463,9 @@ def _async_get_tools(


def _get_exposed_entities(
hass: HomeAssistant, assistant: str
hass: HomeAssistant,
assistant: str,
include_state: bool = True,
) -> dict[str, dict[str, dict[str, Any]]]:
"""Get exposed entities.

Expand Down Expand Up @@ -524,24 +526,28 @@ def _get_exposed_entities(
info: dict[str, Any] = {
"names": ", ".join(names),
"domain": state.domain,
"state": state.state,
}

if include_state:
info["state"] = state.state

if description:
info["description"] = description

if area_names:
info["areas"] = ", ".join(area_names)

if attributes := {
attr_name: (
str(attr_value)
if isinstance(attr_value, (Enum, Decimal, int))
else attr_value
)
for attr_name, attr_value in state.attributes.items()
if attr_name in interesting_attributes
}:
if include_state and (
attributes := {
attr_name: (
str(attr_value)
if isinstance(attr_value, (Enum, Decimal, int))
else attr_value
)
for attr_name, attr_value in state.attributes.items()
if attr_name in interesting_attributes
}
):
info["attributes"] = attributes

if state.domain in data:
Expand Down
42 changes: 38 additions & 4 deletions tests/helpers/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,40 @@ def create_entity(
domain: light
state: unavailable
areas: Test Area 2
"""
stateless_exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
- names: Kitchen
domain: light
- names: Living Room
domain: light
areas: Test Area, Alternative name
- names: Test Device, my test light
domain: light
areas: Test Area, Alternative name
- names: Test Service
domain: light
areas: Test Area, Alternative name
- names: Test Service
domain: light
areas: Test Area, Alternative name
- names: Test Service
domain: light
areas: Test Area, Alternative name
- names: Test Device 2
domain: light
areas: Test Area 2
- names: Test Device 3
domain: light
areas: Test Area 2
- names: Test Device 4
domain: light
areas: Test Area 2
- names: Unnamed Device
domain: light
areas: Test Area 2
- names: '1'
domain: light
areas: Test Area 2
"""
first_part_prompt = (
"When controlling Home Assistant always call the intent tools. "
Expand All @@ -640,7 +674,7 @@ def create_entity(
f"""{first_part_prompt}
{area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)

# Verify that the get_home_state tool returns the same results as the exposed_entities_prompt
Expand All @@ -663,7 +697,7 @@ def create_entity(
f"""{first_part_prompt}
{area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)

# Add floor
Expand All @@ -678,7 +712,7 @@ def create_entity(
f"""{first_part_prompt}
{area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)

# Register device for timers
Expand All @@ -689,7 +723,7 @@ def create_entity(
assert api.api_prompt == (
f"""{first_part_prompt}
{area_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)


Expand Down