mirror of
https://github.com/home-assistant/core.git
synced 2025-04-27 02:37:50 +00:00
Remove entity state from mcp-server prompt (#137126)
* Create a stateless assist API for MCP server * Update stateless API * Fix areas in exposed entity fields * Add tests that verify areas are returned * Revert the getstate intent * Revert whitespace change * Revert whitespace change * Revert method name changes to avoid breaking openai and google tests
This commit is contained in:
parent
6bf5e95089
commit
1db5da4037
@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from . import http
|
from . import http, llm_api
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .session import SessionManager
|
from .session import SessionManager
|
||||||
from .types import MCPServerConfigEntry
|
from .types import MCPServerConfigEntry
|
||||||
@ -25,6 +25,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
|||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up the Model Context Protocol component."""
|
"""Set up the Model Context Protocol component."""
|
||||||
http.async_register(hass)
|
http.async_register(hass)
|
||||||
|
llm_api.async_register_api(hass)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from homeassistant.helpers.selector import (
|
|||||||
SelectSelectorConfig,
|
SelectSelectorConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN, LLM_API, LLM_API_NAME
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -33,6 +33,12 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Handle the initial step."""
|
"""Handle the initial step."""
|
||||||
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
|
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
|
||||||
|
if LLM_API not in llm_apis:
|
||||||
|
# MCP server component is not loaded yet, so make the LLM API a choice.
|
||||||
|
llm_apis = {
|
||||||
|
LLM_API: LLM_API_NAME,
|
||||||
|
**llm_apis,
|
||||||
|
}
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
|
@ -2,3 +2,5 @@
|
|||||||
|
|
||||||
DOMAIN = "mcp_server"
|
DOMAIN = "mcp_server"
|
||||||
TITLE = "Model Context Protocol Server"
|
TITLE = "Model Context Protocol Server"
|
||||||
|
LLM_API = "stateless_assist"
|
||||||
|
LLM_API_NAME = "Stateless Assist"
|
||||||
|
48
homeassistant/components/mcp_server/llm_api.py
Normal file
48
homeassistant/components/mcp_server/llm_api.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
"""LLM API for MCP Server."""
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers import llm
|
||||||
|
from homeassistant.util import yaml as yaml_util
|
||||||
|
|
||||||
|
from .const import LLM_API, LLM_API_NAME
|
||||||
|
|
||||||
|
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
|
||||||
|
|
||||||
|
|
||||||
|
def async_register_api(hass: HomeAssistant) -> None:
|
||||||
|
"""Register the LLM API."""
|
||||||
|
llm.async_register_api(hass, StatelessAssistAPI(hass))
|
||||||
|
|
||||||
|
|
||||||
|
class StatelessAssistAPI(llm.AssistAPI):
|
||||||
|
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
|
||||||
|
|
||||||
|
Syncing the state information is possible, but may put unnecessary load on
|
||||||
|
the system so we are instead providing the prompt without entity state. Since
|
||||||
|
actions don't care about the current state, there is little quality loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
"""Initialize the StatelessAssistAPI."""
|
||||||
|
super().__init__(hass)
|
||||||
|
self.id = LLM_API
|
||||||
|
self.name = LLM_API_NAME
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_get_exposed_entities_prompt(
|
||||||
|
self, llm_context: llm.LLMContext, exposed_entities: dict | None
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return the prompt for the exposed entities."""
|
||||||
|
prompt = []
|
||||||
|
|
||||||
|
if exposed_entities:
|
||||||
|
prompt.append(
|
||||||
|
"An overview of the areas and the devices in this smart home:"
|
||||||
|
)
|
||||||
|
entities = [
|
||||||
|
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
|
||||||
|
for entity_info in exposed_entities.values()
|
||||||
|
]
|
||||||
|
prompt.append(yaml_util.dump(list(entities)))
|
||||||
|
|
||||||
|
return prompt
|
@ -326,12 +326,21 @@ class AssistAPI(API):
|
|||||||
def _async_get_api_prompt(
|
def _async_get_api_prompt(
|
||||||
self, llm_context: LLMContext, exposed_entities: dict | None
|
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return the prompt for the API."""
|
|
||||||
if not exposed_entities:
|
if not exposed_entities:
|
||||||
return (
|
return (
|
||||||
"Only if the user wants to control a device, tell them to expose entities "
|
"Only if the user wants to control a device, tell them to expose entities "
|
||||||
"to their voice assistant in Home Assistant."
|
"to their voice assistant in Home Assistant."
|
||||||
)
|
)
|
||||||
|
return "\n".join(
|
||||||
|
[
|
||||||
|
*self._async_get_preable(llm_context),
|
||||||
|
*self._async_get_exposed_entities_prompt(llm_context, exposed_entities),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_get_preable(self, llm_context: LLMContext) -> list[str]:
|
||||||
|
"""Return the prompt for the API."""
|
||||||
|
|
||||||
prompt = [
|
prompt = [
|
||||||
(
|
(
|
||||||
@ -371,13 +380,22 @@ class AssistAPI(API):
|
|||||||
):
|
):
|
||||||
prompt.append("This device is not able to start timers.")
|
prompt.append("This device is not able to start timers.")
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_get_exposed_entities_prompt(
|
||||||
|
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return the prompt for the API for exposed entities."""
|
||||||
|
prompt = []
|
||||||
|
|
||||||
if exposed_entities:
|
if exposed_entities:
|
||||||
prompt.append(
|
prompt.append(
|
||||||
"An overview of the areas and the devices in this smart home:"
|
"An overview of the areas and the devices in this smart home:"
|
||||||
)
|
)
|
||||||
prompt.append(yaml_util.dump(list(exposed_entities.values())))
|
prompt.append(yaml_util.dump(list(exposed_entities.values())))
|
||||||
|
|
||||||
return "\n".join(prompt)
|
return prompt
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_tools(
|
def _async_get_tools(
|
||||||
|
@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.mcp_server.const import DOMAIN
|
from homeassistant.components.mcp_server.const import DOMAIN, LLM_API
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import llm
|
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -28,7 +27,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
|||||||
config_entry = MockConfigEntry(
|
config_entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={
|
data={
|
||||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
CONF_LLM_HASS_API: LLM_API,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
|
@ -20,7 +20,11 @@ from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
|
|||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
|
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
from homeassistant.helpers import (
|
||||||
|
area_registry as ar,
|
||||||
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry, setup_test_component_platform
|
from tests.common import MockConfigEntry, setup_test_component_platform
|
||||||
@ -45,6 +49,11 @@ INITIALIZE_MESSAGE = {
|
|||||||
}
|
}
|
||||||
EVENT_PREFIX = "event: "
|
EVENT_PREFIX = "event: "
|
||||||
DATA_PREFIX = "data: "
|
DATA_PREFIX = "data: "
|
||||||
|
EXPECTED_PROMPT_SUFFIX = """
|
||||||
|
- names: Kitchen Light
|
||||||
|
domain: light
|
||||||
|
areas: Kitchen
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -59,11 +68,13 @@ async def mock_entities(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
setup_integration: None,
|
setup_integration: None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fixture to expose entities to the conversation agent."""
|
"""Fixture to expose entities to the conversation agent."""
|
||||||
entity = MockLight("kitchen", STATE_OFF)
|
entity = MockLight("Kitchen Light", STATE_OFF)
|
||||||
entity.entity_id = TEST_ENTITY
|
entity.entity_id = TEST_ENTITY
|
||||||
|
entity.unique_id = "test-light-unique-id"
|
||||||
setup_test_component_platform(hass, LIGHT_DOMAIN, [entity])
|
setup_test_component_platform(hass, LIGHT_DOMAIN, [entity])
|
||||||
|
|
||||||
assert await async_setup_component(
|
assert await async_setup_component(
|
||||||
@ -71,6 +82,9 @@ async def mock_entities(
|
|||||||
LIGHT_DOMAIN,
|
LIGHT_DOMAIN,
|
||||||
{LIGHT_DOMAIN: [{"platform": "test"}]},
|
{LIGHT_DOMAIN: [{"platform": "test"}]},
|
||||||
)
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
kitchen = area_registry.async_get_or_create("Kitchen")
|
||||||
|
entity_registry.async_update_entity(TEST_ENTITY, area_id=kitchen.id)
|
||||||
|
|
||||||
async_expose_entity(hass, CONVERSATION_DOMAIN, TEST_ENTITY, True)
|
async_expose_entity(hass, CONVERSATION_DOMAIN, TEST_ENTITY, True)
|
||||||
|
|
||||||
@ -320,7 +334,7 @@ async def test_mcp_tool_call(
|
|||||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||||
result = await session.call_tool(
|
result = await session.call_tool(
|
||||||
name="HassTurnOn",
|
name="HassTurnOn",
|
||||||
arguments={"name": "kitchen"},
|
arguments={"name": "kitchen light"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not result.isError
|
assert not result.isError
|
||||||
@ -370,8 +384,11 @@ async def test_prompt_list(
|
|||||||
|
|
||||||
assert len(result.prompts) == 1
|
assert len(result.prompts) == 1
|
||||||
prompt = result.prompts[0]
|
prompt = result.prompts[0]
|
||||||
assert prompt.name == "Assist"
|
assert prompt.name == "Stateless Assist"
|
||||||
assert prompt.description == "Default prompt for the Home Assistant LLM API Assist"
|
assert (
|
||||||
|
prompt.description
|
||||||
|
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_prompt_get(
|
async def test_prompt_get(
|
||||||
@ -383,13 +400,17 @@ async def test_prompt_get(
|
|||||||
"""Test the get prompt endpoint."""
|
"""Test the get prompt endpoint."""
|
||||||
|
|
||||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||||
result = await session.get_prompt(name="Assist")
|
result = await session.get_prompt(name="Stateless Assist")
|
||||||
|
|
||||||
assert result.description == "Default prompt for the Home Assistant LLM API Assist"
|
assert (
|
||||||
|
result.description
|
||||||
|
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
||||||
|
)
|
||||||
assert len(result.messages) == 1
|
assert len(result.messages) == 1
|
||||||
assert result.messages[0].role == "assistant"
|
assert result.messages[0].role == "assistant"
|
||||||
assert result.messages[0].content.type == "text"
|
assert result.messages[0].content.type == "text"
|
||||||
assert "When controlling Home Assistant" in result.messages[0].content.text
|
assert "When controlling Home Assistant" in result.messages[0].content.text
|
||||||
|
assert result.messages[0].content.text.endswith(EXPECTED_PROMPT_SUFFIX)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_unknwon_prompt(
|
async def test_get_unknwon_prompt(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user