Update MCP server to make the stateless API implicit (#140753)

* Update MCP server to not register the stateless API, but use it implicitly as an Assist API replacement

* Ensure backwards compatibility with old registration
This commit is contained in:
Allen Porter 2025-03-17 14:38:21 -07:00 committed by GitHub
parent c9276aedde
commit 412705302d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 50 additions and 46 deletions

View File

@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import http, llm_api from . import http
from .const import DOMAIN from .const import DOMAIN
from .session import SessionManager from .session import SessionManager
from .types import MCPServerConfigEntry from .types import MCPServerConfigEntry
@ -25,7 +25,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Model Context Protocol component.""" """Set up the Model Context Protocol component."""
http.async_register(hass) http.async_register(hass)
llm_api.async_register_api(hass)
return True return True

View File

@ -16,7 +16,7 @@ from homeassistant.helpers.selector import (
SelectSelectorConfig, SelectSelectorConfig,
) )
from .const import DOMAIN, LLM_API, LLM_API_NAME from .const import DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -33,13 +33,6 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the initial step.""" """Handle the initial step."""
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)} llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
if LLM_API not in llm_apis:
# MCP server component is not loaded yet, so make the LLM API a choice.
llm_apis = {
LLM_API: LLM_API_NAME,
**llm_apis,
}
if user_input is not None: if user_input is not None:
return self.async_create_entry( return self.async_create_entry(
title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input

View File

@ -2,5 +2,6 @@
DOMAIN = "mcp_server" DOMAIN = "mcp_server"
TITLE = "Model Context Protocol Server" TITLE = "Model Context Protocol Server"
LLM_API = "stateless_assist" # The Stateless API is no longer registered explicitly, but this name may still exist in the
LLM_API_NAME = "Stateless Assist" # users config entry.
STATELESS_LLM_API = "stateless_assist"

View File

@ -1,19 +1,18 @@
"""LLM API for MCP Server.""" """LLM API for MCP Server.
from homeassistant.core import HomeAssistant, callback This is a modified version of the AssistAPI that does not include the home state
in the prompt. This API is not registered with the LLM API registry since it is
only used by the MCP Server. The MCP server will substitute this API when the
user selects the Assist API.
"""
from homeassistant.core import callback
from homeassistant.helpers import llm from homeassistant.helpers import llm
from homeassistant.util import yaml as yaml_util from homeassistant.util import yaml as yaml_util
from .const import LLM_API, LLM_API_NAME
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"} EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
def async_register_api(hass: HomeAssistant) -> None:
"""Register the LLM API."""
llm.async_register_api(hass, StatelessAssistAPI(hass))
class StatelessAssistAPI(llm.AssistAPI): class StatelessAssistAPI(llm.AssistAPI):
"""LLM API for MCP Server that provides the Assist API without state information in the prompt. """LLM API for MCP Server that provides the Assist API without state information in the prompt.
@ -22,12 +21,6 @@ class StatelessAssistAPI(llm.AssistAPI):
actions don't care about the current state, there is little quality loss. actions don't care about the current state, there is little quality loss.
""" """
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the StatelessAssistAPI."""
super().__init__(hass)
self.id = LLM_API
self.name = LLM_API_NAME
@callback @callback
def _async_get_exposed_entities_prompt( def _async_get_exposed_entities_prompt(
self, llm_context: llm.LLMContext, exposed_entities: dict | None self, llm_context: llm.LLMContext, exposed_entities: dict | None

View File

@ -21,6 +21,9 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm from homeassistant.helpers import llm
from .const import STATELESS_LLM_API
from .llm_api import StatelessAssistAPI
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -50,13 +53,21 @@ async def create_server(
server = Server("home-assistant") server = Server("home-assistant")
async def get_api_instance() -> llm.APIInstance:
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
api = StatelessAssistAPI(hass)
return await api.async_get_api_instance(llm_context)
return await llm.async_get_api(hass, llm_api_id, llm_context)
@server.list_prompts() # type: ignore[no-untyped-call, misc] @server.list_prompts() # type: ignore[no-untyped-call, misc]
async def handle_list_prompts() -> list[types.Prompt]: async def handle_list_prompts() -> list[types.Prompt]:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) llm_api = await get_api_instance()
return [ return [
types.Prompt( types.Prompt(
name=llm_api.api.name, name=llm_api.api.name,
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}", description=f"Default prompt for Home Assistant {llm_api.api.name} API",
) )
] ]
@ -64,12 +75,12 @@ async def create_server(
async def handle_get_prompt( async def handle_get_prompt(
name: str, arguments: dict[str, str] | None name: str, arguments: dict[str, str] | None
) -> types.GetPromptResult: ) -> types.GetPromptResult:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) llm_api = await get_api_instance()
if name != llm_api.api.name: if name != llm_api.api.name:
raise ValueError(f"Unknown prompt: {name}") raise ValueError(f"Unknown prompt: {name}")
return types.GetPromptResult( return types.GetPromptResult(
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}", description=f"Default prompt for Home Assistant {llm_api.api.name} API",
messages=[ messages=[
types.PromptMessage( types.PromptMessage(
role="assistant", role="assistant",
@ -84,13 +95,13 @@ async def create_server(
@server.list_tools() # type: ignore[no-untyped-call, misc] @server.list_tools() # type: ignore[no-untyped-call, misc]
async def list_tools() -> list[types.Tool]: async def list_tools() -> list[types.Tool]:
"""List available time tools.""" """List available time tools."""
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) llm_api = await get_api_instance()
return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools] return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
@server.call_tool() # type: ignore[no-untyped-call, misc] @server.call_tool() # type: ignore[no-untyped-call, misc]
async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]: async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]:
"""Handle calling tools.""" """Handle calling tools."""
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) llm_api = await get_api_instance()
tool_input = llm.ToolInput(tool_name=name, tool_args=arguments) tool_input = llm.ToolInput(tool_name=name, tool_args=arguments)
_LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args) _LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)

View File

@ -5,9 +5,10 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from homeassistant.components.mcp_server.const import DOMAIN, LLM_API from homeassistant.components.mcp_server.const import DOMAIN
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -21,13 +22,19 @@ def mock_setup_entry() -> Generator[AsyncMock]:
yield mock_setup_entry yield mock_setup_entry
@pytest.fixture(name="llm_hass_api")
def llm_hass_api_fixture() -> str:
"""Fixture for the config entry llm_hass_api."""
return llm.LLM_API_ASSIST
@pytest.fixture(name="config_entry") @pytest.fixture(name="config_entry")
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: def mock_config_entry(hass: HomeAssistant, llm_hass_api: str) -> MockConfigEntry:
"""Fixture to load the integration.""" """Fixture to load the integration."""
config_entry = MockConfigEntry( config_entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={ data={
CONF_LLM_HASS_API: LLM_API, CONF_LLM_HASS_API: llm_hass_api,
}, },
) )
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)

View File

@ -16,6 +16,7 @@ import pytest
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.components.mcp_server.const import STATELESS_LLM_API
from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
@ -24,6 +25,7 @@ from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
llm,
) )
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -297,6 +299,7 @@ async def mcp_session(
yield session yield session
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_mcp_tools_list( async def test_mcp_tools_list(
hass: HomeAssistant, hass: HomeAssistant,
setup_integration: None, setup_integration: None,
@ -319,6 +322,7 @@ async def test_mcp_tools_list(
assert properties.get("name") == {"type": "string"} assert properties.get("name") == {"type": "string"}
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_mcp_tool_call( async def test_mcp_tool_call(
hass: HomeAssistant, hass: HomeAssistant,
setup_integration: None, setup_integration: None,
@ -371,6 +375,7 @@ async def test_mcp_tool_call_failed(
assert "Error calling tool" in result.content[0].text assert "Error calling tool" in result.content[0].text
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_prompt_list( async def test_prompt_list(
hass: HomeAssistant, hass: HomeAssistant,
setup_integration: None, setup_integration: None,
@ -384,13 +389,11 @@ async def test_prompt_list(
assert len(result.prompts) == 1 assert len(result.prompts) == 1
prompt = result.prompts[0] prompt = result.prompts[0]
assert prompt.name == "Stateless Assist" assert prompt.name == "Assist"
assert ( assert prompt.description == "Default prompt for Home Assistant Assist API"
prompt.description
== "Default prompt for the Home Assistant LLM API Stateless Assist"
)
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
async def test_prompt_get( async def test_prompt_get(
hass: HomeAssistant, hass: HomeAssistant,
setup_integration: None, setup_integration: None,
@ -400,12 +403,9 @@ async def test_prompt_get(
"""Test the get prompt endpoint.""" """Test the get prompt endpoint."""
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.get_prompt(name="Stateless Assist") result = await session.get_prompt(name="Assist")
assert ( assert result.description == "Default prompt for Home Assistant Assist API"
result.description
== "Default prompt for the Home Assistant LLM API Stateless Assist"
)
assert len(result.messages) == 1 assert len(result.messages) == 1
assert result.messages[0].role == "assistant" assert result.messages[0].role == "assistant"
assert result.messages[0].content.type == "text" assert result.messages[0].content.type == "text"