mirror of
https://github.com/home-assistant/core.git
synced 2025-07-12 15:57:06 +00:00
Omit state from the Assist LLM prompts (#141034)
* Omit state from the Assist LLM prompts * Add back the stateful prompt
This commit is contained in:
parent
61e30d0e91
commit
4e2dfba45f
@ -1,41 +0,0 @@
|
|||||||
"""LLM API for MCP Server.
|
|
||||||
|
|
||||||
This is a modified version of the AssistAPI that does not include the home state
|
|
||||||
in the prompt. This API is not registered with the LLM API registry since it is
|
|
||||||
only used by the MCP Server. The MCP server will substitute this API when the
|
|
||||||
user selects the Assist API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from homeassistant.core import callback
|
|
||||||
from homeassistant.helpers import llm
|
|
||||||
from homeassistant.util import yaml as yaml_util
|
|
||||||
|
|
||||||
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
|
|
||||||
|
|
||||||
|
|
||||||
class StatelessAssistAPI(llm.AssistAPI):
|
|
||||||
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
|
|
||||||
|
|
||||||
Syncing the state information is possible, but may put unnecessary load on
|
|
||||||
the system so we are instead providing the prompt without entity state. Since
|
|
||||||
actions don't care about the current state, there is little quality loss.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_get_exposed_entities_prompt(
|
|
||||||
self, llm_context: llm.LLMContext, exposed_entities: dict | None
|
|
||||||
) -> list[str]:
|
|
||||||
"""Return the prompt for the exposed entities."""
|
|
||||||
prompt = []
|
|
||||||
|
|
||||||
if exposed_entities and exposed_entities["entities"]:
|
|
||||||
prompt.append(
|
|
||||||
"An overview of the areas and the devices in this smart home:"
|
|
||||||
)
|
|
||||||
entities = [
|
|
||||||
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
|
|
||||||
for entity_info in exposed_entities["entities"].values()
|
|
||||||
]
|
|
||||||
prompt.append(yaml_util.dump(list(entities)))
|
|
||||||
|
|
||||||
return prompt
|
|
@ -22,7 +22,6 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
from .const import STATELESS_LLM_API
|
from .const import STATELESS_LLM_API
|
||||||
from .llm_api import StatelessAssistAPI
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -50,15 +49,14 @@ async def create_server(
|
|||||||
A Model Context Protocol Server object is associated with a single session.
|
A Model Context Protocol Server object is associated with a single session.
|
||||||
The MCP SDK handles the details of the protocol.
|
The MCP SDK handles the details of the protocol.
|
||||||
"""
|
"""
|
||||||
|
if llm_api_id == STATELESS_LLM_API:
|
||||||
|
llm_api_id = llm.LLM_API_ASSIST
|
||||||
|
|
||||||
server = Server("home-assistant")
|
server = Server("home-assistant")
|
||||||
|
|
||||||
async def get_api_instance() -> llm.APIInstance:
|
async def get_api_instance() -> llm.APIInstance:
|
||||||
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
|
"""Get the LLM API selected."""
|
||||||
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
|
# Backwards compatibility with old MCP Server config
|
||||||
api = StatelessAssistAPI(hass)
|
|
||||||
return await api.async_get_api_instance(llm_context)
|
|
||||||
|
|
||||||
return await llm.async_get_api(hass, llm_api_id, llm_context)
|
return await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||||
|
|
||||||
@server.list_prompts() # type: ignore[no-untyped-call, misc]
|
@server.list_prompts() # type: ignore[no-untyped-call, misc]
|
||||||
|
@ -316,7 +316,7 @@ class AssistAPI(API):
|
|||||||
"""Return the instance of the API."""
|
"""Return the instance of the API."""
|
||||||
if llm_context.assistant:
|
if llm_context.assistant:
|
||||||
exposed_entities: dict | None = _get_exposed_entities(
|
exposed_entities: dict | None = _get_exposed_entities(
|
||||||
self.hass, llm_context.assistant
|
self.hass, llm_context.assistant, include_state=False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exposed_entities = None
|
exposed_entities = None
|
||||||
@ -463,7 +463,9 @@ class AssistAPI(API):
|
|||||||
|
|
||||||
|
|
||||||
def _get_exposed_entities(
|
def _get_exposed_entities(
|
||||||
hass: HomeAssistant, assistant: str
|
hass: HomeAssistant,
|
||||||
|
assistant: str,
|
||||||
|
include_state: bool = True,
|
||||||
) -> dict[str, dict[str, dict[str, Any]]]:
|
) -> dict[str, dict[str, dict[str, Any]]]:
|
||||||
"""Get exposed entities.
|
"""Get exposed entities.
|
||||||
|
|
||||||
@ -524,24 +526,28 @@ def _get_exposed_entities(
|
|||||||
info: dict[str, Any] = {
|
info: dict[str, Any] = {
|
||||||
"names": ", ".join(names),
|
"names": ", ".join(names),
|
||||||
"domain": state.domain,
|
"domain": state.domain,
|
||||||
"state": state.state,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if include_state:
|
||||||
|
info["state"] = state.state
|
||||||
|
|
||||||
if description:
|
if description:
|
||||||
info["description"] = description
|
info["description"] = description
|
||||||
|
|
||||||
if area_names:
|
if area_names:
|
||||||
info["areas"] = ", ".join(area_names)
|
info["areas"] = ", ".join(area_names)
|
||||||
|
|
||||||
if attributes := {
|
if include_state and (
|
||||||
attr_name: (
|
attributes := {
|
||||||
str(attr_value)
|
attr_name: (
|
||||||
if isinstance(attr_value, (Enum, Decimal, int))
|
str(attr_value)
|
||||||
else attr_value
|
if isinstance(attr_value, (Enum, Decimal, int))
|
||||||
)
|
else attr_value
|
||||||
for attr_name, attr_value in state.attributes.items()
|
)
|
||||||
if attr_name in interesting_attributes
|
for attr_name, attr_value in state.attributes.items()
|
||||||
}:
|
if attr_name in interesting_attributes
|
||||||
|
}
|
||||||
|
):
|
||||||
info["attributes"] = attributes
|
info["attributes"] = attributes
|
||||||
|
|
||||||
if state.domain in data:
|
if state.domain in data:
|
||||||
|
@ -622,6 +622,40 @@ async def test_assist_api_prompt(
|
|||||||
domain: light
|
domain: light
|
||||||
state: unavailable
|
state: unavailable
|
||||||
areas: Test Area 2
|
areas: Test Area 2
|
||||||
|
"""
|
||||||
|
stateless_exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
|
||||||
|
- names: Kitchen
|
||||||
|
domain: light
|
||||||
|
- names: Living Room
|
||||||
|
domain: light
|
||||||
|
areas: Test Area, Alternative name
|
||||||
|
- names: Test Device, my test light
|
||||||
|
domain: light
|
||||||
|
areas: Test Area, Alternative name
|
||||||
|
- names: Test Service
|
||||||
|
domain: light
|
||||||
|
areas: Test Area, Alternative name
|
||||||
|
- names: Test Service
|
||||||
|
domain: light
|
||||||
|
areas: Test Area, Alternative name
|
||||||
|
- names: Test Service
|
||||||
|
domain: light
|
||||||
|
areas: Test Area, Alternative name
|
||||||
|
- names: Test Device 2
|
||||||
|
domain: light
|
||||||
|
areas: Test Area 2
|
||||||
|
- names: Test Device 3
|
||||||
|
domain: light
|
||||||
|
areas: Test Area 2
|
||||||
|
- names: Test Device 4
|
||||||
|
domain: light
|
||||||
|
areas: Test Area 2
|
||||||
|
- names: Unnamed Device
|
||||||
|
domain: light
|
||||||
|
areas: Test Area 2
|
||||||
|
- names: '1'
|
||||||
|
domain: light
|
||||||
|
areas: Test Area 2
|
||||||
"""
|
"""
|
||||||
first_part_prompt = (
|
first_part_prompt = (
|
||||||
"When controlling Home Assistant always call the intent tools. "
|
"When controlling Home Assistant always call the intent tools. "
|
||||||
@ -640,7 +674,7 @@ async def test_assist_api_prompt(
|
|||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
{no_timer_prompt}
|
{no_timer_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{stateless_exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify that the get_home_state tool returns the same results as the exposed_entities_prompt
|
# Verify that the get_home_state tool returns the same results as the exposed_entities_prompt
|
||||||
@ -663,7 +697,7 @@ async def test_assist_api_prompt(
|
|||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
{no_timer_prompt}
|
{no_timer_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{stateless_exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add floor
|
# Add floor
|
||||||
@ -678,7 +712,7 @@ async def test_assist_api_prompt(
|
|||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
{no_timer_prompt}
|
{no_timer_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{stateless_exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register device for timers
|
# Register device for timers
|
||||||
@ -689,7 +723,7 @@ async def test_assist_api_prompt(
|
|||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{stateless_exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user