mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Update MCP server to make the stateless API implicit (#140753)
* Update MCP server to not register the stateless API, but use it implicitly as an Assist API replacement * Ensure backwards compatibility with old registration
This commit is contained in:
parent
c9276aedde
commit
412705302d
@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import http, llm_api
|
||||
from . import http
|
||||
from .const import DOMAIN
|
||||
from .session import SessionManager
|
||||
from .types import MCPServerConfigEntry
|
||||
@ -25,7 +25,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the Model Context Protocol component."""
|
||||
http.async_register(hass)
|
||||
llm_api.async_register_api(hass)
|
||||
return True
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ from homeassistant.helpers.selector import (
|
||||
SelectSelectorConfig,
|
||||
)
|
||||
|
||||
from .const import DOMAIN, LLM_API, LLM_API_NAME
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -33,13 +33,6 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the initial step."""
|
||||
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
|
||||
if LLM_API not in llm_apis:
|
||||
# MCP server component is not loaded yet, so make the LLM API a choice.
|
||||
llm_apis = {
|
||||
LLM_API: LLM_API_NAME,
|
||||
**llm_apis,
|
||||
}
|
||||
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input
|
||||
|
@ -2,5 +2,6 @@
|
||||
|
||||
DOMAIN = "mcp_server"
|
||||
TITLE = "Model Context Protocol Server"
|
||||
LLM_API = "stateless_assist"
|
||||
LLM_API_NAME = "Stateless Assist"
|
||||
# The Stateless API is no longer registered explicitly, but this name may still exist in the
|
||||
# users config entry.
|
||||
STATELESS_LLM_API = "stateless_assist"
|
||||
|
@ -1,19 +1,18 @@
|
||||
"""LLM API for MCP Server."""
|
||||
"""LLM API for MCP Server.
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
This is a modified version of the AssistAPI that does not include the home state
|
||||
in the prompt. This API is not registered with the LLM API registry since it is
|
||||
only used by the MCP Server. The MCP server will substitute this API when the
|
||||
user selects the Assist API.
|
||||
"""
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.util import yaml as yaml_util
|
||||
|
||||
from .const import LLM_API, LLM_API_NAME
|
||||
|
||||
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
|
||||
|
||||
|
||||
def async_register_api(hass: HomeAssistant) -> None:
|
||||
"""Register the LLM API."""
|
||||
llm.async_register_api(hass, StatelessAssistAPI(hass))
|
||||
|
||||
|
||||
class StatelessAssistAPI(llm.AssistAPI):
|
||||
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
|
||||
|
||||
@ -22,12 +21,6 @@ class StatelessAssistAPI(llm.AssistAPI):
|
||||
actions don't care about the current state, there is little quality loss.
|
||||
"""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the StatelessAssistAPI."""
|
||||
super().__init__(hass)
|
||||
self.id = LLM_API
|
||||
self.name = LLM_API_NAME
|
||||
|
||||
@callback
|
||||
def _async_get_exposed_entities_prompt(
|
||||
self, llm_context: llm.LLMContext, exposed_entities: dict | None
|
||||
|
@ -21,6 +21,9 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from .const import STATELESS_LLM_API
|
||||
from .llm_api import StatelessAssistAPI
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -50,13 +53,21 @@ async def create_server(
|
||||
|
||||
server = Server("home-assistant")
|
||||
|
||||
async def get_api_instance() -> llm.APIInstance:
|
||||
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
|
||||
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
|
||||
api = StatelessAssistAPI(hass)
|
||||
return await api.async_get_api_instance(llm_context)
|
||||
|
||||
return await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
|
||||
@server.list_prompts() # type: ignore[no-untyped-call, misc]
|
||||
async def handle_list_prompts() -> list[types.Prompt]:
|
||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
llm_api = await get_api_instance()
|
||||
return [
|
||||
types.Prompt(
|
||||
name=llm_api.api.name,
|
||||
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
|
||||
description=f"Default prompt for Home Assistant {llm_api.api.name} API",
|
||||
)
|
||||
]
|
||||
|
||||
@ -64,12 +75,12 @@ async def create_server(
|
||||
async def handle_get_prompt(
|
||||
name: str, arguments: dict[str, str] | None
|
||||
) -> types.GetPromptResult:
|
||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
llm_api = await get_api_instance()
|
||||
if name != llm_api.api.name:
|
||||
raise ValueError(f"Unknown prompt: {name}")
|
||||
|
||||
return types.GetPromptResult(
|
||||
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
|
||||
description=f"Default prompt for Home Assistant {llm_api.api.name} API",
|
||||
messages=[
|
||||
types.PromptMessage(
|
||||
role="assistant",
|
||||
@ -84,13 +95,13 @@ async def create_server(
|
||||
@server.list_tools() # type: ignore[no-untyped-call, misc]
|
||||
async def list_tools() -> list[types.Tool]:
|
||||
"""List available time tools."""
|
||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
llm_api = await get_api_instance()
|
||||
return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
||||
|
||||
@server.call_tool() # type: ignore[no-untyped-call, misc]
|
||||
async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]:
|
||||
"""Handle calling tools."""
|
||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
llm_api = await get_api_instance()
|
||||
tool_input = llm.ToolInput(tool_name=name, tool_args=arguments)
|
||||
_LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
|
||||
|
||||
|
@ -5,9 +5,10 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.mcp_server.const import DOMAIN, LLM_API
|
||||
from homeassistant.components.mcp_server.const import DOMAIN
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@ -21,13 +22,19 @@ def mock_setup_entry() -> Generator[AsyncMock]:
|
||||
yield mock_setup_entry
|
||||
|
||||
|
||||
@pytest.fixture(name="llm_hass_api")
|
||||
def llm_hass_api_fixture() -> str:
|
||||
"""Fixture for the config entry llm_hass_api."""
|
||||
return llm.LLM_API_ASSIST
|
||||
|
||||
|
||||
@pytest.fixture(name="config_entry")
|
||||
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
def mock_config_entry(hass: HomeAssistant, llm_hass_api: str) -> MockConfigEntry:
|
||||
"""Fixture to load the integration."""
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_LLM_HASS_API: LLM_API,
|
||||
CONF_LLM_HASS_API: llm_hass_api,
|
||||
},
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
|
@ -16,6 +16,7 @@ import pytest
|
||||
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
from homeassistant.components.mcp_server.const import STATELESS_LLM_API
|
||||
from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
|
||||
@ -24,6 +25,7 @@ from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
llm,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
@ -297,6 +299,7 @@ async def mcp_session(
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||
async def test_mcp_tools_list(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
@ -319,6 +322,7 @@ async def test_mcp_tools_list(
|
||||
assert properties.get("name") == {"type": "string"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||
async def test_mcp_tool_call(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
@ -371,6 +375,7 @@ async def test_mcp_tool_call_failed(
|
||||
assert "Error calling tool" in result.content[0].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||
async def test_prompt_list(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
@ -384,13 +389,11 @@ async def test_prompt_list(
|
||||
|
||||
assert len(result.prompts) == 1
|
||||
prompt = result.prompts[0]
|
||||
assert prompt.name == "Stateless Assist"
|
||||
assert (
|
||||
prompt.description
|
||||
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
||||
)
|
||||
assert prompt.name == "Assist"
|
||||
assert prompt.description == "Default prompt for Home Assistant Assist API"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||
async def test_prompt_get(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
@ -400,12 +403,9 @@ async def test_prompt_get(
|
||||
"""Test the get prompt endpoint."""
|
||||
|
||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||
result = await session.get_prompt(name="Stateless Assist")
|
||||
result = await session.get_prompt(name="Assist")
|
||||
|
||||
assert (
|
||||
result.description
|
||||
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
||||
)
|
||||
assert result.description == "Default prompt for Home Assistant Assist API"
|
||||
assert len(result.messages) == 1
|
||||
assert result.messages[0].role == "assistant"
|
||||
assert result.messages[0].content.type == "text"
|
||||
|
Loading…
x
Reference in New Issue
Block a user