Update MCP server to make the stateless API implicit (#140753)

* Update MCP server to not register the stateless API, but use it implicitly as an Assist API replacement

* Ensure backwards compatibility with old registration
Allen Porter 2025-03-17 14:38:21 -07:00 committed by GitHub
parent c9276aedde
commit 412705302d
7 changed files with 50 additions and 46 deletions
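
Summary of the runtime behavior these hunks add: the server no longer looks up a registered "Stateless Assist" API, it builds a StatelessAssistAPI on the fly whenever the config entry names either the legacy stateless id or the regular Assist API. A minimal sketch of that dispatch, assembled from the server.py hunk further down (there the variables come from the enclosing create_server closure; here they are written as explicit parameters for readability):

from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm

from .const import STATELESS_LLM_API
from .llm_api import StatelessAssistAPI


async def get_api_instance(
    hass: HomeAssistant, llm_api_id: str, llm_context: llm.LLMContext
) -> llm.APIInstance:
    """Resolve the configured LLM API, substituting the stateless variant.

    Sketch of the closure added in server.py: both the legacy
    "stateless_assist" id (kept for existing config entries) and the
    built-in Assist API id resolve to the unregistered StatelessAssistAPI.
    """
    if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
        api = StatelessAssistAPI(hass)
        return await api.async_get_api_instance(llm_context)
    # Any other API id still goes through the normal registry lookup.
    return await llm.async_get_api(hass, llm_api_id, llm_context)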

View File

@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType
from . import http, llm_api
from . import http
from .const import DOMAIN
from .session import SessionManager
from .types import MCPServerConfigEntry
@ -25,7 +25,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Model Context Protocol component."""
http.async_register(hass)
llm_api.async_register_api(hass)
return True

View File

@ -16,7 +16,7 @@ from homeassistant.helpers.selector import (
SelectSelectorConfig,
)
from .const import DOMAIN, LLM_API, LLM_API_NAME
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
@ -33,13 +33,6 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Handle the initial step."""
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
if LLM_API not in llm_apis:
# MCP server component is not loaded yet, so make the LLM API a choice.
llm_apis = {
LLM_API: LLM_API_NAME,
**llm_apis,
}
if user_input is not None:
return self.async_create_entry(
title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input

View File

@ -2,5 +2,6 @@
DOMAIN = "mcp_server"
TITLE = "Model Context Protocol Server"
LLM_API = "stateless_assist"
LLM_API_NAME = "Stateless Assist"
# The Stateless API is no longer registered explicitly, but this name may still exist in the
# user's config entry.
STATELESS_LLM_API = "stateless_assist"
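
For illustration only (the two dicts below are hypothetical examples, not code from this commit): a config entry created before this change may still store the old id, while new entries reference the regular Assist API, and both now resolve to the stateless implementation at runtime.

from homeassistant.components.mcp_server.const import STATELESS_LLM_API
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm

# Pre-existing entries: the legacy id is still accepted for compatibility.
legacy_entry_data = {CONF_LLM_HASS_API: STATELESS_LLM_API}  # "stateless_assist"
# Entries created after this change simply point at the Assist API.
new_entry_data = {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}  # "assist"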

View File

@ -1,19 +1,18 @@
"""LLM API for MCP Server."""
"""LLM API for MCP Server.
from homeassistant.core import HomeAssistant, callback
This is a modified version of the AssistAPI that does not include the home state
in the prompt. This API is not registered with the LLM API registry since it is
only used by the MCP Server. The MCP server will substitute this API when the
user selects the Assist API.
"""
from homeassistant.core import callback
from homeassistant.helpers import llm
from homeassistant.util import yaml as yaml_util
from .const import LLM_API, LLM_API_NAME
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
def async_register_api(hass: HomeAssistant) -> None:
"""Register the LLM API."""
llm.async_register_api(hass, StatelessAssistAPI(hass))
class StatelessAssistAPI(llm.AssistAPI):
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
@ -22,12 +21,6 @@ class StatelessAssistAPI(llm.AssistAPI):
actions don't care about the current state, there is little quality loss.
"""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the StatelessAssistAPI."""
super().__init__(hass)
self.id = LLM_API
self.name = LLM_API_NAME
@callback
def _async_get_exposed_entities_prompt(
self, llm_context: llm.LLMContext, exposed_entities: dict | None

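The hunk above is truncated inside _async_get_exposed_entities_prompt, so the body of the override is not shown. As a rough, hypothetical sketch of what such an override does (assuming exposed_entities maps entity ids to info dicts; the real structure and wording may differ): it keeps only the static EXPOSED_ENTITY_FIELDS for each exposed entity and serializes them with yaml_util, which is how live state is kept out of the prompt.

from homeassistant.core import callback
from homeassistant.helpers import llm
from homeassistant.util import yaml as yaml_util

EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}


class StatelessAssistAPI(llm.AssistAPI):
    """Sketch only; the real class lives in llm_api.py."""

    @callback
    def _async_get_exposed_entities_prompt(
        self, llm_context: llm.LLMContext, exposed_entities: dict | None
    ) -> list[str]:
        """Describe exposed entities without including their live state."""
        prompt: list[str] = []
        if exposed_entities:
            prompt.append("An overview of the entities exposed in this smart home:")
            # Assumption: values are per-entity info dicts; keep only the static
            # fields so the current state never reaches the prompt.
            trimmed = [
                {k: v for k, v in info.items() if k in EXPOSED_ENTITY_FIELDS}
                for info in exposed_entities.values()
            ]
            prompt.append(yaml_util.dump(trimmed))
        return prompt
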
View File

@ -21,6 +21,9 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from .const import STATELESS_LLM_API
from .llm_api import StatelessAssistAPI
_LOGGER = logging.getLogger(__name__)
@ -50,13 +53,21 @@ async def create_server(
server = Server("home-assistant")
async def get_api_instance() -> llm.APIInstance:
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
api = StatelessAssistAPI(hass)
return await api.async_get_api_instance(llm_context)
return await llm.async_get_api(hass, llm_api_id, llm_context)
@server.list_prompts() # type: ignore[no-untyped-call, misc]
async def handle_list_prompts() -> list[types.Prompt]:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
llm_api = await get_api_instance()
return [
types.Prompt(
name=llm_api.api.name,
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
description=f"Default prompt for Home Assistant {llm_api.api.name} API",
)
]
@ -64,12 +75,12 @@ async def create_server(
async def handle_get_prompt(
name: str, arguments: dict[str, str] | None
) -> types.GetPromptResult:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
llm_api = await get_api_instance()
if name != llm_api.api.name:
raise ValueError(f"Unknown prompt: {name}")
return types.GetPromptResult(
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
description=f"Default prompt for Home Assistant {llm_api.api.name} API",
messages=[
types.PromptMessage(
role="assistant",
@ -84,13 +95,13 @@ async def create_server(
@server.list_tools() # type: ignore[no-untyped-call, misc]
async def list_tools() -> list[types.Tool]:
"""List available time tools."""
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
llm_api = await get_api_instance()
return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
@server.call_tool() # type: ignore[no-untyped-call, misc]
async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]:
"""Handle calling tools."""
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
llm_api = await get_api_instance()
tool_input = llm.ToolInput(tool_name=name, tool_args=arguments)
_LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)

View File

@ -5,9 +5,10 @@ from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.mcp_server.const import DOMAIN, LLM_API
from homeassistant.components.mcp_server.const import DOMAIN
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from tests.common import MockConfigEntry
@ -21,13 +22,19 @@ def mock_setup_entry() -> Generator[AsyncMock]:
yield mock_setup_entry
@pytest.fixture(name="llm_hass_api")
def llm_hass_api_fixture() -> str:
"""Fixture for the config entry llm_hass_api."""
return llm.LLM_API_ASSIST
@pytest.fixture(name="config_entry")
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
def mock_config_entry(hass: HomeAssistant, llm_hass_api: str) -> MockConfigEntry:
"""Fixture to load the integration."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_LLM_HASS_API: LLM_API,
CONF_LLM_HASS_API: llm_hass_api,
},
)
config_entry.add_to_hass(hass)

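The new llm_hass_api fixture is what the parametrized tests below hook into: a @pytest.mark.parametrize("llm_hass_api", [...]) marker on a test overrides the fixture of the same name, so the config_entry fixture sees each parametrized id in turn, while tests without the marker keep the llm.LLM_API_ASSIST default. A small self-contained illustration of that pytest pattern (the names here are placeholders, not from this commit):

import pytest


@pytest.fixture(name="llm_hass_api")
def llm_hass_api_fixture() -> str:
    """Default id used by tests that do not parametrize it."""
    return "assist"


@pytest.fixture(name="entry_data")
def entry_data_fixture(llm_hass_api: str) -> dict[str, str]:
    """Downstream fixture that consumes whichever id is active."""
    return {"llm_hass_api": llm_hass_api}


@pytest.mark.parametrize("llm_hass_api", ["assist", "stateless_assist"])
def test_both_ids_accepted(entry_data: dict[str, str]) -> None:
    """Runs once per parametrized id, overriding the fixture default."""
    assert entry_data["llm_hass_api"] in ("assist", "stateless_assist")
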
View File

@ -16,6 +16,7 @@ import pytest
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.components.mcp_server.const import STATELESS_LLM_API
from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
@ -24,6 +25,7 @@ from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
llm,
)
from homeassistant.setup import async_setup_component
@ -297,6 +299,7 @@ async def mcp_session(
yield session
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_mcp_tools_list(
hass: HomeAssistant,
setup_integration: None,
@ -319,6 +322,7 @@ async def test_mcp_tools_list(
assert properties.get("name") == {"type": "string"}
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_mcp_tool_call(
hass: HomeAssistant,
setup_integration: None,
@ -371,6 +375,7 @@ async def test_mcp_tool_call_failed(
assert "Error calling tool" in result.content[0].text
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_prompt_list(
hass: HomeAssistant,
setup_integration: None,
@ -384,13 +389,11 @@ async def test_prompt_list(
assert len(result.prompts) == 1
prompt = result.prompts[0]
assert prompt.name == "Stateless Assist"
assert (
prompt.description
== "Default prompt for the Home Assistant LLM API Stateless Assist"
)
assert prompt.name == "Assist"
assert prompt.description == "Default prompt for Home Assistant Assist API"
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_prompt_get(
hass: HomeAssistant,
setup_integration: None,
@ -400,12 +403,9 @@ async def test_prompt_get(
"""Test the get prompt endpoint."""
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.get_prompt(name="Stateless Assist")
result = await session.get_prompt(name="Assist")
assert (
result.description
== "Default prompt for the Home Assistant LLM API Stateless Assist"
)
assert result.description == "Default prompt for Home Assistant Assist API"
assert len(result.messages) == 1
assert result.messages[0].role == "assistant"
assert result.messages[0].content.type == "text"