diff --git a/homeassistant/components/mcp_server/__init__.py b/homeassistant/components/mcp_server/__init__.py index 941eccbe528..e523f46228f 100644 --- a/homeassistant/components/mcp_server/__init__.py +++ b/homeassistant/components/mcp_server/__init__.py @@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType -from . import http, llm_api +from . import http from .const import DOMAIN from .session import SessionManager from .types import MCPServerConfigEntry @@ -25,7 +25,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Model Context Protocol component.""" http.async_register(hass) - llm_api.async_register_api(hass) return True diff --git a/homeassistant/components/mcp_server/config_flow.py b/homeassistant/components/mcp_server/config_flow.py index 8d8d311b874..e8df68de5e2 100644 --- a/homeassistant/components/mcp_server/config_flow.py +++ b/homeassistant/components/mcp_server/config_flow.py @@ -16,7 +16,7 @@ from homeassistant.helpers.selector import ( SelectSelectorConfig, ) -from .const import DOMAIN, LLM_API, LLM_API_NAME +from .const import DOMAIN _LOGGER = logging.getLogger(__name__) @@ -33,13 +33,6 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Handle the initial step.""" llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)} - if LLM_API not in llm_apis: - # MCP server component is not loaded yet, so make the LLM API a choice. - llm_apis = { - LLM_API: LLM_API_NAME, - **llm_apis, - } - if user_input is not None: return self.async_create_entry( title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input diff --git a/homeassistant/components/mcp_server/const.py b/homeassistant/components/mcp_server/const.py index 8958ac36616..3f2e12cbb6a 100644 --- a/homeassistant/components/mcp_server/const.py +++ b/homeassistant/components/mcp_server/const.py @@ -2,5 +2,6 @@ DOMAIN = "mcp_server" TITLE = "Model Context Protocol Server" -LLM_API = "stateless_assist" -LLM_API_NAME = "Stateless Assist" +# The Stateless API is no longer registered explicitly, but this name may still exist in the +# users config entry. +STATELESS_LLM_API = "stateless_assist" diff --git a/homeassistant/components/mcp_server/llm_api.py b/homeassistant/components/mcp_server/llm_api.py index 5c29b29153e..f7dd4421480 100644 --- a/homeassistant/components/mcp_server/llm_api.py +++ b/homeassistant/components/mcp_server/llm_api.py @@ -1,19 +1,18 @@ -"""LLM API for MCP Server.""" +"""LLM API for MCP Server. -from homeassistant.core import HomeAssistant, callback +This is a modified version of the AssistAPI that does not include the home state +in the prompt. This API is not registered with the LLM API registry since it is +only used by the MCP Server. The MCP server will substitute this API when the +user selects the Assist API. +""" + +from homeassistant.core import callback from homeassistant.helpers import llm from homeassistant.util import yaml as yaml_util -from .const import LLM_API, LLM_API_NAME - EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"} -def async_register_api(hass: HomeAssistant) -> None: - """Register the LLM API.""" - llm.async_register_api(hass, StatelessAssistAPI(hass)) - - class StatelessAssistAPI(llm.AssistAPI): """LLM API for MCP Server that provides the Assist API without state information in the prompt. @@ -22,12 +21,6 @@ class StatelessAssistAPI(llm.AssistAPI): actions don't care about the current state, there is little quality loss. """ - def __init__(self, hass: HomeAssistant) -> None: - """Initialize the StatelessAssistAPI.""" - super().__init__(hass) - self.id = LLM_API - self.name = LLM_API_NAME - @callback def _async_get_exposed_entities_prompt( self, llm_context: llm.LLMContext, exposed_entities: dict | None diff --git a/homeassistant/components/mcp_server/server.py b/homeassistant/components/mcp_server/server.py index ba21abd722c..307fcdda8f3 100644 --- a/homeassistant/components/mcp_server/server.py +++ b/homeassistant/components/mcp_server/server.py @@ -21,6 +21,9 @@ from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import llm +from .const import STATELESS_LLM_API +from .llm_api import StatelessAssistAPI + _LOGGER = logging.getLogger(__name__) @@ -50,13 +53,21 @@ async def create_server( server = Server("home-assistant") + async def get_api_instance() -> llm.APIInstance: + """Substitute the StatelessAssistAPI for the Assist API if selected.""" + if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST): + api = StatelessAssistAPI(hass) + return await api.async_get_api_instance(llm_context) + + return await llm.async_get_api(hass, llm_api_id, llm_context) + @server.list_prompts() # type: ignore[no-untyped-call, misc] async def handle_list_prompts() -> list[types.Prompt]: - llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + llm_api = await get_api_instance() return [ types.Prompt( name=llm_api.api.name, - description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}", + description=f"Default prompt for Home Assistant {llm_api.api.name} API", ) ] @@ -64,12 +75,12 @@ async def create_server( async def handle_get_prompt( name: str, arguments: dict[str, str] | None ) -> types.GetPromptResult: - llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + llm_api = await get_api_instance() if name != llm_api.api.name: raise ValueError(f"Unknown prompt: {name}") return types.GetPromptResult( - description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}", + description=f"Default prompt for Home Assistant {llm_api.api.name} API", messages=[ types.PromptMessage( role="assistant", @@ -84,13 +95,13 @@ async def create_server( @server.list_tools() # type: ignore[no-untyped-call, misc] async def list_tools() -> list[types.Tool]: """List available time tools.""" - llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + llm_api = await get_api_instance() return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools] @server.call_tool() # type: ignore[no-untyped-call, misc] async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]: """Handle calling tools.""" - llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + llm_api = await get_api_instance() tool_input = llm.ToolInput(tool_name=name, tool_args=arguments) _LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args) diff --git a/tests/components/mcp_server/conftest.py b/tests/components/mcp_server/conftest.py index 5ec67fb6ce3..b5e25d9fe50 100644 --- a/tests/components/mcp_server/conftest.py +++ b/tests/components/mcp_server/conftest.py @@ -5,9 +5,10 @@ from unittest.mock import AsyncMock, patch import pytest -from homeassistant.components.mcp_server.const import DOMAIN, LLM_API +from homeassistant.components.mcp_server.const import DOMAIN from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm from tests.common import MockConfigEntry @@ -21,13 +22,19 @@ def mock_setup_entry() -> Generator[AsyncMock]: yield mock_setup_entry +@pytest.fixture(name="llm_hass_api") +def llm_hass_api_fixture() -> str: + """Fixture for the config entry llm_hass_api.""" + return llm.LLM_API_ASSIST + + @pytest.fixture(name="config_entry") -def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: +def mock_config_entry(hass: HomeAssistant, llm_hass_api: str) -> MockConfigEntry: """Fixture to load the integration.""" config_entry = MockConfigEntry( domain=DOMAIN, data={ - CONF_LLM_HASS_API: LLM_API, + CONF_LLM_HASS_API: llm_hass_api, }, ) config_entry.add_to_hass(hass) diff --git a/tests/components/mcp_server/test_http.py b/tests/components/mcp_server/test_http.py index 905bfaa11d7..70efd211b57 100644 --- a/tests/components/mcp_server/test_http.py +++ b/tests/components/mcp_server/test_http.py @@ -16,6 +16,7 @@ import pytest from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN +from homeassistant.components.mcp_server.const import STATELESS_LLM_API from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API from homeassistant.config_entries import ConfigEntryState from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON @@ -24,6 +25,7 @@ from homeassistant.helpers import ( area_registry as ar, device_registry as dr, entity_registry as er, + llm, ) from homeassistant.setup import async_setup_component @@ -297,6 +299,7 @@ async def mcp_session( yield session +@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API]) async def test_mcp_tools_list( hass: HomeAssistant, setup_integration: None, @@ -319,6 +322,7 @@ async def test_mcp_tools_list( assert properties.get("name") == {"type": "string"} +@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API]) async def test_mcp_tool_call( hass: HomeAssistant, setup_integration: None, @@ -371,6 +375,7 @@ async def test_mcp_tool_call_failed( assert "Error calling tool" in result.content[0].text +@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API]) async def test_prompt_list( hass: HomeAssistant, setup_integration: None, @@ -384,13 +389,11 @@ async def test_prompt_list( assert len(result.prompts) == 1 prompt = result.prompts[0] - assert prompt.name == "Stateless Assist" - assert ( - prompt.description - == "Default prompt for the Home Assistant LLM API Stateless Assist" - ) + assert prompt.name == "Assist" + assert prompt.description == "Default prompt for Home Assistant Assist API" +@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API]) async def test_prompt_get( hass: HomeAssistant, setup_integration: None, @@ -400,12 +403,9 @@ async def test_prompt_get( """Test the get prompt endpoint.""" async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: - result = await session.get_prompt(name="Stateless Assist") + result = await session.get_prompt(name="Assist") - assert ( - result.description - == "Default prompt for the Home Assistant LLM API Stateless Assist" - ) + assert result.description == "Default prompt for Home Assistant Assist API" assert len(result.messages) == 1 assert result.messages[0].role == "assistant" assert result.messages[0].content.type == "text"