mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 22:37:11 +00:00
Omit state from the Assist LLM prompts (#141034)
* Omit state from the Assist LLM prompts * Add back the stateful prompt
This commit is contained in:
parent
61e30d0e91
commit
4e2dfba45f
@ -1,41 +0,0 @@
|
||||
"""LLM API for MCP Server.
|
||||
|
||||
This is a modified version of the AssistAPI that does not include the home state
|
||||
in the prompt. This API is not registered with the LLM API registry since it is
|
||||
only used by the MCP Server. The MCP server will substitute this API when the
|
||||
user selects the Assist API.
|
||||
"""
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.util import yaml as yaml_util
|
||||
|
||||
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
|
||||
|
||||
|
||||
class StatelessAssistAPI(llm.AssistAPI):
|
||||
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
|
||||
|
||||
Syncing the state information is possible, but may put unnecessary load on
|
||||
the system so we are instead providing the prompt without entity state. Since
|
||||
actions don't care about the current state, there is little quality loss.
|
||||
"""
|
||||
|
||||
@callback
|
||||
def _async_get_exposed_entities_prompt(
|
||||
self, llm_context: llm.LLMContext, exposed_entities: dict | None
|
||||
) -> list[str]:
|
||||
"""Return the prompt for the exposed entities."""
|
||||
prompt = []
|
||||
|
||||
if exposed_entities and exposed_entities["entities"]:
|
||||
prompt.append(
|
||||
"An overview of the areas and the devices in this smart home:"
|
||||
)
|
||||
entities = [
|
||||
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
|
||||
for entity_info in exposed_entities["entities"].values()
|
||||
]
|
||||
prompt.append(yaml_util.dump(list(entities)))
|
||||
|
||||
return prompt
|
@ -22,7 +22,6 @@ from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from .const import STATELESS_LLM_API
|
||||
from .llm_api import StatelessAssistAPI
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -50,15 +49,14 @@ async def create_server(
|
||||
A Model Context Protocol Server object is associated with a single session.
|
||||
The MCP SDK handles the details of the protocol.
|
||||
"""
|
||||
if llm_api_id == STATELESS_LLM_API:
|
||||
llm_api_id = llm.LLM_API_ASSIST
|
||||
|
||||
server = Server("home-assistant")
|
||||
|
||||
async def get_api_instance() -> llm.APIInstance:
|
||||
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
|
||||
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
|
||||
api = StatelessAssistAPI(hass)
|
||||
return await api.async_get_api_instance(llm_context)
|
||||
|
||||
"""Get the LLM API selected."""
|
||||
# Backwards compatibility with old MCP Server config
|
||||
return await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
|
||||
@server.list_prompts() # type: ignore[no-untyped-call, misc]
|
||||
|
@ -316,7 +316,7 @@ class AssistAPI(API):
|
||||
"""Return the instance of the API."""
|
||||
if llm_context.assistant:
|
||||
exposed_entities: dict | None = _get_exposed_entities(
|
||||
self.hass, llm_context.assistant
|
||||
self.hass, llm_context.assistant, include_state=False
|
||||
)
|
||||
else:
|
||||
exposed_entities = None
|
||||
@ -463,7 +463,9 @@ class AssistAPI(API):
|
||||
|
||||
|
||||
def _get_exposed_entities(
|
||||
hass: HomeAssistant, assistant: str
|
||||
hass: HomeAssistant,
|
||||
assistant: str,
|
||||
include_state: bool = True,
|
||||
) -> dict[str, dict[str, dict[str, Any]]]:
|
||||
"""Get exposed entities.
|
||||
|
||||
@ -524,24 +526,28 @@ def _get_exposed_entities(
|
||||
info: dict[str, Any] = {
|
||||
"names": ", ".join(names),
|
||||
"domain": state.domain,
|
||||
"state": state.state,
|
||||
}
|
||||
|
||||
if include_state:
|
||||
info["state"] = state.state
|
||||
|
||||
if description:
|
||||
info["description"] = description
|
||||
|
||||
if area_names:
|
||||
info["areas"] = ", ".join(area_names)
|
||||
|
||||
if attributes := {
|
||||
attr_name: (
|
||||
str(attr_value)
|
||||
if isinstance(attr_value, (Enum, Decimal, int))
|
||||
else attr_value
|
||||
)
|
||||
for attr_name, attr_value in state.attributes.items()
|
||||
if attr_name in interesting_attributes
|
||||
}:
|
||||
if include_state and (
|
||||
attributes := {
|
||||
attr_name: (
|
||||
str(attr_value)
|
||||
if isinstance(attr_value, (Enum, Decimal, int))
|
||||
else attr_value
|
||||
)
|
||||
for attr_name, attr_value in state.attributes.items()
|
||||
if attr_name in interesting_attributes
|
||||
}
|
||||
):
|
||||
info["attributes"] = attributes
|
||||
|
||||
if state.domain in data:
|
||||
|
@ -622,6 +622,40 @@ async def test_assist_api_prompt(
|
||||
domain: light
|
||||
state: unavailable
|
||||
areas: Test Area 2
|
||||
"""
|
||||
stateless_exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
|
||||
- names: Kitchen
|
||||
domain: light
|
||||
- names: Living Room
|
||||
domain: light
|
||||
areas: Test Area, Alternative name
|
||||
- names: Test Device, my test light
|
||||
domain: light
|
||||
areas: Test Area, Alternative name
|
||||
- names: Test Service
|
||||
domain: light
|
||||
areas: Test Area, Alternative name
|
||||
- names: Test Service
|
||||
domain: light
|
||||
areas: Test Area, Alternative name
|
||||
- names: Test Service
|
||||
domain: light
|
||||
areas: Test Area, Alternative name
|
||||
- names: Test Device 2
|
||||
domain: light
|
||||
areas: Test Area 2
|
||||
- names: Test Device 3
|
||||
domain: light
|
||||
areas: Test Area 2
|
||||
- names: Test Device 4
|
||||
domain: light
|
||||
areas: Test Area 2
|
||||
- names: Unnamed Device
|
||||
domain: light
|
||||
areas: Test Area 2
|
||||
- names: '1'
|
||||
domain: light
|
||||
areas: Test Area 2
|
||||
"""
|
||||
first_part_prompt = (
|
||||
"When controlling Home Assistant always call the intent tools. "
|
||||
@ -640,7 +674,7 @@ async def test_assist_api_prompt(
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
{stateless_exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Verify that the get_home_state tool returns the same results as the exposed_entities_prompt
|
||||
@ -663,7 +697,7 @@ async def test_assist_api_prompt(
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
{stateless_exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Add floor
|
||||
@ -678,7 +712,7 @@ async def test_assist_api_prompt(
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
{stateless_exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Register device for timers
|
||||
@ -689,7 +723,7 @@ async def test_assist_api_prompt(
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
{stateless_exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user