diff --git a/homeassistant/components/mcp_server/llm_api.py b/homeassistant/components/mcp_server/llm_api.py deleted file mode 100644 index f7dd4421480..00000000000 --- a/homeassistant/components/mcp_server/llm_api.py +++ /dev/null @@ -1,41 +0,0 @@ -"""LLM API for MCP Server. - -This is a modified version of the AssistAPI that does not include the home state -in the prompt. This API is not registered with the LLM API registry since it is -only used by the MCP Server. The MCP server will substitute this API when the -user selects the Assist API. -""" - -from homeassistant.core import callback -from homeassistant.helpers import llm -from homeassistant.util import yaml as yaml_util - -EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"} - - -class StatelessAssistAPI(llm.AssistAPI): - """LLM API for MCP Server that provides the Assist API without state information in the prompt. - - Syncing the state information is possible, but may put unnecessary load on - the system so we are instead providing the prompt without entity state. Since - actions don't care about the current state, there is little quality loss. - """ - - @callback - def _async_get_exposed_entities_prompt( - self, llm_context: llm.LLMContext, exposed_entities: dict | None - ) -> list[str]: - """Return the prompt for the exposed entities.""" - prompt = [] - - if exposed_entities and exposed_entities["entities"]: - prompt.append( - "An overview of the areas and the devices in this smart home:" - ) - entities = [ - {k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS} - for entity_info in exposed_entities["entities"].values() - ] - prompt.append(yaml_util.dump(list(entities))) - - return prompt diff --git a/homeassistant/components/mcp_server/server.py b/homeassistant/components/mcp_server/server.py index 307fcdda8f3..88b179ae7c2 100644 --- a/homeassistant/components/mcp_server/server.py +++ b/homeassistant/components/mcp_server/server.py @@ -22,7 +22,6 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import llm from .const import STATELESS_LLM_API -from .llm_api import StatelessAssistAPI _LOGGER = logging.getLogger(__name__) @@ -50,15 +49,14 @@ async def create_server( A Model Context Protocol Server object is associated with a single session. The MCP SDK handles the details of the protocol. """ + if llm_api_id == STATELESS_LLM_API: + llm_api_id = llm.LLM_API_ASSIST server = Server("home-assistant") async def get_api_instance() -> llm.APIInstance: - """Substitute the StatelessAssistAPI for the Assist API if selected.""" - if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST): - api = StatelessAssistAPI(hass) - return await api.async_get_api_instance(llm_context) - + """Get the LLM API selected.""" + # Backwards compatibility with old MCP Server config return await llm.async_get_api(hass, llm_api_id, llm_context) @server.list_prompts() # type: ignore[no-untyped-call, misc] diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 5995543914f..7f6fe22ec70 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -316,7 +316,7 @@ class AssistAPI(API): """Return the instance of the API.""" if llm_context.assistant: exposed_entities: dict | None = _get_exposed_entities( - self.hass, llm_context.assistant + self.hass, llm_context.assistant, include_state=False ) else: exposed_entities = None @@ -463,7 +463,9 @@ class AssistAPI(API): def _get_exposed_entities( - hass: HomeAssistant, assistant: str + hass: HomeAssistant, + assistant: str, + include_state: bool = True, ) -> dict[str, dict[str, dict[str, Any]]]: """Get exposed entities. @@ -524,24 +526,28 @@ def _get_exposed_entities( info: dict[str, Any] = { "names": ", ".join(names), "domain": state.domain, - "state": state.state, } + if include_state: + info["state"] = state.state + if description: info["description"] = description if area_names: info["areas"] = ", ".join(area_names) - if attributes := { - attr_name: ( - str(attr_value) - if isinstance(attr_value, (Enum, Decimal, int)) - else attr_value - ) - for attr_name, attr_value in state.attributes.items() - if attr_name in interesting_attributes - }: + if include_state and ( + attributes := { + attr_name: ( + str(attr_value) + if isinstance(attr_value, (Enum, Decimal, int)) + else attr_value + ) + for attr_name, attr_value in state.attributes.items() + if attr_name in interesting_attributes + } + ): info["attributes"] = attributes if state.domain in data: diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 45ed009fcf1..19ada407550 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -622,6 +622,40 @@ async def test_assist_api_prompt( domain: light state: unavailable areas: Test Area 2 +""" + stateless_exposed_entities_prompt = """An overview of the areas and the devices in this smart home: +- names: Kitchen + domain: light +- names: Living Room + domain: light + areas: Test Area, Alternative name +- names: Test Device, my test light + domain: light + areas: Test Area, Alternative name +- names: Test Service + domain: light + areas: Test Area, Alternative name +- names: Test Service + domain: light + areas: Test Area, Alternative name +- names: Test Service + domain: light + areas: Test Area, Alternative name +- names: Test Device 2 + domain: light + areas: Test Area 2 +- names: Test Device 3 + domain: light + areas: Test Area 2 +- names: Test Device 4 + domain: light + areas: Test Area 2 +- names: Unnamed Device + domain: light + areas: Test Area 2 +- names: '1' + domain: light + areas: Test Area 2 """ first_part_prompt = ( "When controlling Home Assistant always call the intent tools. " @@ -640,7 +674,7 @@ async def test_assist_api_prompt( f"""{first_part_prompt} {area_prompt} {no_timer_prompt} -{exposed_entities_prompt}""" +{stateless_exposed_entities_prompt}""" ) # Verify that the get_home_state tool returns the same results as the exposed_entities_prompt @@ -663,7 +697,7 @@ async def test_assist_api_prompt( f"""{first_part_prompt} {area_prompt} {no_timer_prompt} -{exposed_entities_prompt}""" +{stateless_exposed_entities_prompt}""" ) # Add floor @@ -678,7 +712,7 @@ async def test_assist_api_prompt( f"""{first_part_prompt} {area_prompt} {no_timer_prompt} -{exposed_entities_prompt}""" +{stateless_exposed_entities_prompt}""" ) # Register device for timers @@ -689,7 +723,7 @@ async def test_assist_api_prompt( assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} -{exposed_entities_prompt}""" +{stateless_exposed_entities_prompt}""" )