diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index e81c62ae25c..bbe77f0ea1a 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -3,7 +3,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, replace +from enum import Enum from typing import Any import voluptuous as vol @@ -13,12 +14,20 @@ from homeassistant.components.conversation.trace import ( ConversationTraceEventType, async_conversation_trace_append, ) +from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError +from homeassistant.util import yaml from homeassistant.util.json import JsonObjectType -from . import area_registry, device_registry, floor_registry, intent +from . import ( + area_registry as ar, + device_registry as dr, + entity_registry as er, + floor_registry as fr, + intent, +) from .singleton import singleton LLM_API_ASSIST = "assist" @@ -140,19 +149,16 @@ class API(ABC): else: raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') - _tool_input = ToolInput( - tool_name=tool.name, - tool_args=tool.parameters(tool_input.tool_args), - platform=tool_input.platform, - context=tool_input.context or Context(), - user_prompt=tool_input.user_prompt, - language=tool_input.language, - assistant=tool_input.assistant, - device_id=tool_input.device_id, + return await tool.async_call( + self.hass, + replace( + tool_input, + tool_name=tool.name, + tool_args=tool.parameters(tool_input.tool_args), + context=tool_input.context or Context(), + ), ) - return await tool.async_call(self.hass, _tool_input) - class IntentTool(Tool): """LLM Tool representing an Intent.""" @@ -209,28 +215,51 @@ class AssistAPI(API): async def async_get_api_prompt(self, tool_input: ToolInput) -> str: """Return the prompt for the API.""" - prompt = ( - "Call the intent tools to control Home Assistant. " - "Just pass the name to the intent." - ) + if tool_input.assistant: + exposed_entities: dict | None = _get_exposed_entities( + self.hass, tool_input.assistant + ) + else: + exposed_entities = None + + if not exposed_entities: + return ( + "Only if the user wants to control a device, tell them to expose entities " + "to their voice assistant in Home Assistant." + ) + + prompt = [ + ( + "Call the intent tools to control Home Assistant. " + "Just pass the name to the intent. " + "When controlling an area, prefer passing area name." + ) + ] if tool_input.device_id: - device_reg = device_registry.async_get(self.hass) + device_reg = dr.async_get(self.hass) device = device_reg.async_get(tool_input.device_id) if device: - area_reg = area_registry.async_get(self.hass) + area_reg = ar.async_get(self.hass) if device.area_id and (area := area_reg.async_get_area(device.area_id)): - floor_reg = floor_registry.async_get(self.hass) + floor_reg = fr.async_get(self.hass) if area.floor_id and ( floor := floor_reg.async_get_floor(area.floor_id) ): - prompt += f" You are in {area.name} ({floor.name})." + prompt.append(f"You are in {area.name} ({floor.name}).") else: - prompt += f" You are in {area.name}." + prompt.append(f"You are in {area.name}.") if tool_input.context and tool_input.context.user_id: user = await self.hass.auth.async_get_user(tool_input.context.user_id) if user: - prompt += f" The user name is {user.name}." - return prompt + prompt.append(f"The user name is {user.name}.") + + if exposed_entities: + prompt.append( + "An overview of the areas and the devices in this smart home:" + ) + prompt.append(yaml.dump(exposed_entities)) + + return "\n".join(prompt) @callback def async_get_tools(self) -> list[Tool]: @@ -240,3 +269,84 @@ class AssistAPI(API): for intent_handler in intent.async_get(self.hass) if intent_handler.intent_type not in self.IGNORE_INTENTS ] + + +def _get_exposed_entities( + hass: HomeAssistant, assistant: str +) -> dict[str, dict[str, Any]]: + """Get exposed entities.""" + area_registry = ar.async_get(hass) + entity_registry = er.async_get(hass) + device_registry = dr.async_get(hass) + interesting_domains = { + "binary_sensor", + "climate", + "cover", + "fan", + "light", + "lock", + "sensor", + "switch", + "weather", + } + interesting_attributes = { + "temperature", + "current_temperature", + "temperature_unit", + "brightness", + "humidity", + "unit_of_measurement", + "device_class", + "current_position", + "percentage", + } + + entities = {} + + for state in hass.states.async_all(): + if state.domain not in interesting_domains: + continue + + if not async_should_expose(hass, assistant, state.entity_id): + continue + + entity_entry = entity_registry.async_get(state.entity_id) + names = [state.name] + area_names = [] + + if entity_entry is not None: + names.extend(entity_entry.aliases) + if entity_entry.area_id and ( + area := area_registry.async_get_area(entity_entry.area_id) + ): + # Entity is in area + area_names.append(area.name) + area_names.extend(area.aliases) + elif entity_entry.device_id and ( + device := device_registry.async_get(entity_entry.device_id) + ): + # Check device area + if device.area_id and ( + area := area_registry.async_get_area(device.area_id) + ): + area_names.append(area.name) + area_names.extend(area.aliases) + + info: dict[str, Any] = { + "names": ", ".join(names), + "state": state.state, + } + + if area_names: + info["areas"] = ", ".join(area_names) + + if attributes := { + attr_name: str(attr_value) if isinstance(attr_value, Enum) else attr_value + for attr_name, attr_value in state.attributes.items() + if attr_name in interesting_attributes + }: + info["attributes"] = attributes + + entities[state.entity_id] = info + + return entities diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index 6ffe3d747d3..b40224b21d0 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -262,7 +262,49 @@ Answer in plain text. Keep it simple and to the point. The current time is 05:00:00. Today's date is 05/24/24. - Call the intent tools to control Home Assistant. Just pass the name to the intent. + Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name. + An overview of the areas and the devices in this smart home: + light.test_device: + names: Test Device + state: unavailable + areas: Test Area + light.test_service: + names: Test Service + state: unavailable + areas: Test Area + light.test_service_2: + names: Test Service + state: unavailable + areas: Test Area + light.test_service_3: + names: Test Service + state: unavailable + areas: Test Area + light.test_device_2: + names: Test Device 2 + state: unavailable + areas: Test Area 2 + light.test_device_3: + names: Test Device 3 + state: unavailable + areas: Test Area 2 + light.test_device_4: + names: Test Device 4 + state: unavailable + areas: Test Area 2 + light.test_device_3_2: + names: Test Device 3 + state: unavailable + areas: Test Area 2 + light.none: + names: None + state: unavailable + areas: Test Area 2 + light.1: + names: '1' + state: unavailable + areas: Test Area 2 + ''', 'role': 'user', }), @@ -318,7 +360,49 @@ Answer in plain text. Keep it simple and to the point. The current time is 05:00:00. Today's date is 05/24/24. - Call the intent tools to control Home Assistant. Just pass the name to the intent. + Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name. + An overview of the areas and the devices in this smart home: + light.test_device: + names: Test Device + state: unavailable + areas: Test Area + light.test_service: + names: Test Service + state: unavailable + areas: Test Area + light.test_service_2: + names: Test Service + state: unavailable + areas: Test Area + light.test_service_3: + names: Test Service + state: unavailable + areas: Test Area + light.test_device_2: + names: Test Device 2 + state: unavailable + areas: Test Area 2 + light.test_device_3: + names: Test Device 3 + state: unavailable + areas: Test Area 2 + light.test_device_4: + names: Test Device 4 + state: unavailable + areas: Test Area 2 + light.test_device_3_2: + names: Test Device 3 + state: unavailable + areas: Test Area 2 + light.none: + names: None + state: unavailable + areas: Test Area 2 + light.1: + names: '1' + state: unavailable + areas: Test Area 2 + ''', 'role': 'user', }), diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 1f11cc58705..ad169d9ae0d 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -17,11 +17,13 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( area_registry as ar, device_registry as dr, + entity_registry as er, intent, llm, ) from tests.common import MockConfigEntry +from tests.typing import WebSocketGenerator @pytest.fixture(autouse=True) @@ -47,9 +49,11 @@ async def test_default_prompt( mock_init_component, area_registry: ar.AreaRegistry, device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, snapshot: SnapshotAssertion, agent_id: str | None, config_entry_options: {}, + hass_ws_client: WebSocketGenerator, ) -> None: """Test that the default prompt works.""" entry = MockConfigEntry(title=None) @@ -64,46 +68,70 @@ async def test_default_prompt( mock_config_entry, options={**mock_config_entry.options, **config_entry_options}, ) + entities = [] - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "1234")}, - name="Test Device", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - ) - for i in range(3): + def create_entity(device: dr.DeviceEntry) -> None: + """Create an entity for a device and track entity_id.""" + entity = entity_registry.async_get_or_create( + "light", + "test", + device.id, + device_id=device.id, + original_name=str(device.name), + suggested_object_id=str(device.name), + ) + entity.write_unavailable_state(hass) + entities.append(entity.entity_id) + + create_entity( device_registry.async_get_or_create( config_entry_id=entry.entry_id, - connections={("test", f"{i}abcd")}, - name="Test Service", + connections={("test", "1234")}, + name="Test Device", manufacturer="Test Manufacturer", model="Test Model", suggested_area="Test Area", - entry_type=dr.DeviceEntryType.SERVICE, ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "5678")}, - name="Test Device 2", - manufacturer="Test Manufacturer 2", - model="Device 2", - suggested_area="Test Area 2", ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", + for i in range(3): + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", f"{i}abcd")}, + name="Test Service", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + entry_type=dr.DeviceEntryType.SERVICE, + ) + ) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "5678")}, + name="Test Device 2", + manufacturer="Test Manufacturer 2", + model="Device 2", + suggested_area="Test Area 2", + ) ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "qwer")}, - name="Test Device 4", - suggested_area="Test Area 2", + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + ) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "qwer")}, + name="Test Device 4", + suggested_area="Test Area 2", + ) ) device = device_registry.async_get_or_create( config_entry_id=entry.entry_id, @@ -116,21 +144,40 @@ async def test_default_prompt( device_registry.async_update_device( device.id, disabled_by=dr.DeviceEntryDisabler.USER ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-no-name")}, - manufacturer="Test Manufacturer NoName", - model="Test Model NoName", - suggested_area="Test Area 2", + create_entity(device) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-no-name")}, + manufacturer="Test Manufacturer NoName", + model="Test Model NoName", + suggested_area="Test Area 2", + ) ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-integer-values")}, - name=1, - manufacturer=2, - model=3, - suggested_area="Test Area 2", + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-integer-values")}, + name=1, + manufacturer=2, + model=3, + suggested_area="Test Area 2", + ) ) + + # Set options for registered entities + ws_client = await hass_ws_client(hass) + await ws_client.send_json_auto_id( + { + "type": "homeassistant/expose_entity", + "assistants": ["conversation"], + "entity_ids": entities, + "should_expose": True, + } + ) + response = await ws_client.receive_json() + assert response["success"] + with ( patch("google.generativeai.GenerativeModel") as mock_model, patch( diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 43eef04734c..97f5e30f6fe 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -11,10 +11,13 @@ from homeassistant.helpers import ( area_registry as ar, config_validation as cv, device_registry as dr, + entity_registry as er, floor_registry as fr, intent, llm, ) +from homeassistant.setup import async_setup_component +from homeassistant.util import yaml from tests.common import MockConfigEntry @@ -158,10 +161,12 @@ async def test_assist_api_description(hass: HomeAssistant) -> None: async def test_assist_api_prompt( hass: HomeAssistant, device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, area_registry: ar.AreaRegistry, floor_registry: fr.FloorRegistry, ) -> None: """Test prompt for the assist API.""" + assert await async_setup_component(hass, "homeassistant", {}) context = Context() tool_input = llm.ToolInput( tool_name=None, @@ -170,41 +175,232 @@ async def test_assist_api_prompt( context=context, user_prompt="test_text", language="*", - assistant="test_assistant", + assistant="conversation", device_id="test_device", ) api = llm.async_get_api(hass, "assist") prompt = await api.async_get_api_prompt(tool_input) assert prompt == ( - "Call the intent tools to control Home Assistant." - " Just pass the name to the intent." + "Only if the user wants to control a device, tell them to expose entities to their " + "voice assistant in Home Assistant." ) + # Expose entities entry = MockConfigEntry(title=None) entry.add_to_hass(hass) - tool_input.device_id = device_registry.async_get_or_create( + device = device_registry.async_get_or_create( config_entry_id=entry.entry_id, connections={("test", "1234")}, - name="Test Device", - manufacturer="Test Manufacturer", - model="Test Model", suggested_area="Test Area", - ).id - prompt = await api.async_get_api_prompt(tool_input) - assert prompt == ( - "Call the intent tools to control Home Assistant." - " Just pass the name to the intent. You are in Test Area." + ) + area = area_registry.async_get_area_by_name("Test Area") + area_registry.async_update(area.id, aliases=["Alternative name"]) + entry1 = entity_registry.async_get_or_create( + "light", + "kitchen", + "mock-id-kitchen", + original_name="Kitchen", + suggested_object_id="kitchen", + ) + entry2 = entity_registry.async_get_or_create( + "light", + "living_room", + "mock-id-living-room", + original_name="Living Room", + suggested_object_id="living_room", + device_id=device.id, + ) + hass.states.async_set(entry1.entity_id, "on", {"friendly_name": "Kitchen"}) + hass.states.async_set(entry2.entity_id, "on", {"friendly_name": "Living Room"}) + + def create_entity(device: dr.DeviceEntry, write_state=True) -> None: + """Create an entity for a device and track entity_id.""" + entity = entity_registry.async_get_or_create( + "light", + "test", + device.id, + device_id=device.id, + original_name=str(device.name or "Unnamed Device"), + suggested_object_id=str(device.name or "unnamed_device"), + ) + if write_state: + entity.write_unavailable_state(hass) + + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "1234")}, + name="Test Device", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + ) + ) + for i in range(3): + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", f"{i}abcd")}, + name="Test Service", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + entry_type=dr.DeviceEntryType.SERVICE, + ) + ) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "5678")}, + name="Test Device 2", + manufacturer="Test Manufacturer 2", + model="Device 2", + suggested_area="Test Area 2", + ) + ) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + ) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "qwer")}, + name="Test Device 4", + suggested_area="Test Area 2", + ) + ) + device2 = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-disabled")}, + name="Test Device 3 - disabled", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + device_registry.async_update_device( + device2.id, disabled_by=dr.DeviceEntryDisabler.USER + ) + create_entity(device2, False) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-no-name")}, + manufacturer="Test Manufacturer NoName", + model="Test Model NoName", + suggested_area="Test Area 2", + ) + ) + create_entity( + device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("test", "9876-integer-values")}, + name=1, + manufacturer=2, + model=3, + suggested_area="Test Area 2", + ) ) + exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant) + assert exposed_entities == { + "light.1": { + "areas": "Test Area 2", + "names": "1", + "state": "unavailable", + }, + entry1.entity_id: { + "names": "Kitchen", + "state": "on", + }, + entry2.entity_id: { + "areas": "Test Area, Alternative name", + "names": "Living Room", + "state": "on", + }, + "light.test_device": { + "areas": "Test Area, Alternative name", + "names": "Test Device", + "state": "unavailable", + }, + "light.test_device_2": { + "areas": "Test Area 2", + "names": "Test Device 2", + "state": "unavailable", + }, + "light.test_device_3": { + "areas": "Test Area 2", + "names": "Test Device 3", + "state": "unavailable", + }, + "light.test_device_4": { + "areas": "Test Area 2", + "names": "Test Device 4", + "state": "unavailable", + }, + "light.test_service": { + "areas": "Test Area, Alternative name", + "names": "Test Service", + "state": "unavailable", + }, + "light.test_service_2": { + "areas": "Test Area, Alternative name", + "names": "Test Service", + "state": "unavailable", + }, + "light.test_service_3": { + "areas": "Test Area, Alternative name", + "names": "Test Service", + "state": "unavailable", + }, + "light.unnamed_device": { + "areas": "Test Area 2", + "names": "Unnamed Device", + "state": "unavailable", + }, + } + exposed_entities_prompt = ( + "An overview of the areas and the devices in this smart home:\n" + + yaml.dump(exposed_entities) + ) + first_part_prompt = ( + "Call the intent tools to control Home Assistant. " + "Just pass the name to the intent. " + "When controlling an area, prefer passing area name." + ) + + prompt = await api.async_get_api_prompt(tool_input) + assert prompt == ( + f"""{first_part_prompt} +{exposed_entities_prompt}""" + ) + + # Fake that request is made from a specific device ID + tool_input.device_id = device.id + prompt = await api.async_get_api_prompt(tool_input) + assert prompt == ( + f"""{first_part_prompt} +You are in Test Area. +{exposed_entities_prompt}""" + ) + + # Add floor floor = floor_registry.async_create("second floor") - area = area_registry.async_get_area_by_name("Test Area") area_registry.async_update(area.id, floor_id=floor.floor_id) prompt = await api.async_get_api_prompt(tool_input) assert prompt == ( - "Call the intent tools to control Home Assistant." - " Just pass the name to the intent. You are in Test Area (second floor)." + f"""{first_part_prompt} +You are in Test Area (second floor). +{exposed_entities_prompt}""" ) + # Add user context.user_id = "12345" mock_user = Mock() mock_user.id = "12345" @@ -212,7 +408,8 @@ async def test_assist_api_prompt( with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user): prompt = await api.async_get_api_prompt(tool_input) assert prompt == ( - "Call the intent tools to control Home Assistant." - " Just pass the name to the intent. You are in Test Area (second floor)." - " The user name is Test User." + f"""{first_part_prompt} +You are in Test Area (second floor). +The user name is Test User. +{exposed_entities_prompt}""" )