Omit state from the Assist LLM prompts (#141034)

* Omit state from the Assist LLM prompts

* Add back the stateful prompt
This commit is contained in:
Allen Porter 2025-03-22 12:41:51 -07:00 committed by GitHub
parent 61e30d0e91
commit 4e2dfba45f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 60 additions and 63 deletions

View File

@ -1,41 +0,0 @@
"""LLM API for MCP Server.
This is a modified version of the AssistAPI that does not include the home state
in the prompt. This API is not registered with the LLM API registry since it is
only used by the MCP Server. The MCP server will substitute this API when the
user selects the Assist API.
"""
from homeassistant.core import callback
from homeassistant.helpers import llm
from homeassistant.util import yaml as yaml_util
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
class StatelessAssistAPI(llm.AssistAPI):
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
Syncing the state information is possible, but may put unnecessary load on
the system so we are instead providing the prompt without entity state. Since
actions don't care about the current state, there is little quality loss.
"""
@callback
def _async_get_exposed_entities_prompt(
self, llm_context: llm.LLMContext, exposed_entities: dict | None
) -> list[str]:
"""Return the prompt for the exposed entities."""
prompt = []
if exposed_entities and exposed_entities["entities"]:
prompt.append(
"An overview of the areas and the devices in this smart home:"
)
entities = [
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
for entity_info in exposed_entities["entities"].values()
]
prompt.append(yaml_util.dump(list(entities)))
return prompt

View File

@ -22,7 +22,6 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from .const import STATELESS_LLM_API
from .llm_api import StatelessAssistAPI
_LOGGER = logging.getLogger(__name__)
@ -50,15 +49,14 @@ async def create_server(
A Model Context Protocol Server object is associated with a single session.
The MCP SDK handles the details of the protocol.
"""
if llm_api_id == STATELESS_LLM_API:
llm_api_id = llm.LLM_API_ASSIST
server = Server("home-assistant")
async def get_api_instance() -> llm.APIInstance:
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
api = StatelessAssistAPI(hass)
return await api.async_get_api_instance(llm_context)
"""Get the LLM API selected."""
# Backwards compatibility with old MCP Server config
return await llm.async_get_api(hass, llm_api_id, llm_context)
@server.list_prompts() # type: ignore[no-untyped-call, misc]

View File

@ -316,7 +316,7 @@ class AssistAPI(API):
"""Return the instance of the API."""
if llm_context.assistant:
exposed_entities: dict | None = _get_exposed_entities(
self.hass, llm_context.assistant
self.hass, llm_context.assistant, include_state=False
)
else:
exposed_entities = None
@ -463,7 +463,9 @@ class AssistAPI(API):
def _get_exposed_entities(
hass: HomeAssistant, assistant: str
hass: HomeAssistant,
assistant: str,
include_state: bool = True,
) -> dict[str, dict[str, dict[str, Any]]]:
"""Get exposed entities.
@ -524,24 +526,28 @@ def _get_exposed_entities(
info: dict[str, Any] = {
"names": ", ".join(names),
"domain": state.domain,
"state": state.state,
}
if include_state:
info["state"] = state.state
if description:
info["description"] = description
if area_names:
info["areas"] = ", ".join(area_names)
if attributes := {
attr_name: (
str(attr_value)
if isinstance(attr_value, (Enum, Decimal, int))
else attr_value
)
for attr_name, attr_value in state.attributes.items()
if attr_name in interesting_attributes
}:
if include_state and (
attributes := {
attr_name: (
str(attr_value)
if isinstance(attr_value, (Enum, Decimal, int))
else attr_value
)
for attr_name, attr_value in state.attributes.items()
if attr_name in interesting_attributes
}
):
info["attributes"] = attributes
if state.domain in data:

View File

@ -622,6 +622,40 @@ async def test_assist_api_prompt(
domain: light
state: unavailable
areas: Test Area 2
"""
stateless_exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
- names: Kitchen
domain: light
- names: Living Room
domain: light
areas: Test Area, Alternative name
- names: Test Device, my test light
domain: light
areas: Test Area, Alternative name
- names: Test Service
domain: light
areas: Test Area, Alternative name
- names: Test Service
domain: light
areas: Test Area, Alternative name
- names: Test Service
domain: light
areas: Test Area, Alternative name
- names: Test Device 2
domain: light
areas: Test Area 2
- names: Test Device 3
domain: light
areas: Test Area 2
- names: Test Device 4
domain: light
areas: Test Area 2
- names: Unnamed Device
domain: light
areas: Test Area 2
- names: '1'
domain: light
areas: Test Area 2
"""
first_part_prompt = (
"When controlling Home Assistant always call the intent tools. "
@ -640,7 +674,7 @@ async def test_assist_api_prompt(
f"""{first_part_prompt}
{area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)
# Verify that the get_home_state tool returns the same results as the exposed_entities_prompt
@ -663,7 +697,7 @@ async def test_assist_api_prompt(
f"""{first_part_prompt}
{area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)
# Add floor
@ -678,7 +712,7 @@ async def test_assist_api_prompt(
f"""{first_part_prompt}
{area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)
# Register device for timers
@ -689,7 +723,7 @@ async def test_assist_api_prompt(
assert api.api_prompt == (
f"""{first_part_prompt}
{area_prompt}
{exposed_entities_prompt}"""
{stateless_exposed_entities_prompt}"""
)