Use satellite entity area in the default agent (#152762)

This commit is contained in:
Artur Pragacz
2025-09-22 20:34:31 +02:00
committed by GitHub
parent 7b7265a6b0
commit 4eaf6784af
3 changed files with 65 additions and 35 deletions

View File

@@ -153,8 +153,8 @@ class IntentCacheKey:
language: str language: str
"""Language of text.""" """Language of text."""
device_id: str | None satellite_id: str | None
"""Device id from user input.""" """Satellite id from user input."""
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -443,9 +443,15 @@ class DefaultAgent(ConversationEntity):
} }
for entity in result.entities_list for entity in result.entities_list
} }
device_area = self._get_device_area(user_input.device_id)
if device_area: satellite_id = user_input.satellite_id
slots["preferred_area_id"] = {"value": device_area.id} device_id = user_input.device_id
satellite_area, device_id = self._get_satellite_area_and_device(
satellite_id, device_id
)
if satellite_area is not None:
slots["preferred_area_id"] = {"value": satellite_area.id}
async_conversation_trace_append( async_conversation_trace_append(
ConversationTraceEventType.TOOL_CALL, ConversationTraceEventType.TOOL_CALL,
{ {
@@ -467,8 +473,8 @@ class DefaultAgent(ConversationEntity):
user_input.context, user_input.context,
language, language,
assistant=DOMAIN, assistant=DOMAIN,
device_id=user_input.device_id, device_id=device_id,
satellite_id=user_input.satellite_id, satellite_id=satellite_id,
conversation_agent_id=user_input.agent_id, conversation_agent_id=user_input.agent_id,
) )
except intent.MatchFailedError as match_error: except intent.MatchFailedError as match_error:
@@ -534,7 +540,9 @@ class DefaultAgent(ConversationEntity):
# Try cache first # Try cache first
cache_key = IntentCacheKey( cache_key = IntentCacheKey(
text=user_input.text, language=language, device_id=user_input.device_id text=user_input.text,
language=language,
satellite_id=user_input.satellite_id,
) )
cache_value = self._intent_cache.get(cache_key) cache_value = self._intent_cache.get(cache_key)
if cache_value is not None: if cache_value is not None:
@@ -1304,28 +1312,40 @@ class DefaultAgent(ConversationEntity):
self, user_input: ConversationInput self, user_input: ConversationInput
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Return intent recognition context for user input.""" """Return intent recognition context for user input."""
if not user_input.device_id: satellite_area, _ = self._get_satellite_area_and_device(
user_input.satellite_id, user_input.device_id
)
if satellite_area is None:
return None return None
device_area = self._get_device_area(user_input.device_id) return {"area": {"value": satellite_area.name, "text": satellite_area.name}}
if device_area is None:
return None
return {"area": {"value": device_area.name, "text": device_area.name}} def _get_satellite_area_and_device(
self, satellite_id: str | None, device_id: str | None = None
) -> tuple[ar.AreaEntry | None, str | None]:
"""Return area entry and device id."""
hass = self.hass
def _get_device_area(self, device_id: str | None) -> ar.AreaEntry | None: area_id: str | None = None
"""Return area object for given device identifier."""
if device_id is None:
return None
devices = dr.async_get(self.hass) if (
device = devices.async_get(device_id) satellite_id is not None
if (device is None) or (device.area_id is None): and (entity_entry := er.async_get(hass).async_get(satellite_id)) is not None
return None ):
area_id = entity_entry.area_id
device_id = entity_entry.device_id
areas = ar.async_get(self.hass) if (
area_id is None
and device_id is not None
and (device_entry := dr.async_get(hass).async_get(device_id)) is not None
):
area_id = device_entry.area_id
return areas.async_get_area(device.area_id) if area_id is None:
return None, device_id
return ar.async_get(hass).async_get_area(area_id), device_id
def _get_error_text( def _get_error_text(
self, self,

View File

@@ -15,7 +15,7 @@ import voluptuous as vol
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.script import ScriptRunResult from homeassistant.helpers.script import ScriptRunResult
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import UNDEFINED, ConfigType from homeassistant.helpers.typing import UNDEFINED, ConfigType
@@ -71,6 +71,8 @@ async def async_attach_trigger(
trigger_data = trigger_info["trigger_data"] trigger_data = trigger_info["trigger_data"]
sentences = config.get(CONF_COMMAND, []) sentences = config.get(CONF_COMMAND, [])
ent_reg = er.async_get(hass)
job = HassJob(action) job = HassJob(action)
async def call_action( async def call_action(
@@ -92,6 +94,14 @@ async def async_attach_trigger(
for entity_name, entity in result.entities.items() for entity_name, entity in result.entities.items()
} }
satellite_id = user_input.satellite_id
device_id = user_input.device_id
if (
satellite_id is not None
and (satellite_entry := ent_reg.async_get(satellite_id)) is not None
):
device_id = satellite_entry.device_id
trigger_input: dict[str, Any] = { # Satisfy type checker trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data, **trigger_data,
"platform": DOMAIN, "platform": DOMAIN,
@@ -100,8 +110,8 @@ async def async_attach_trigger(
"slots": { # direct access to values "slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items() entity_name: entity["value"] for entity_name, entity in details.items()
}, },
"device_id": user_input.device_id, "device_id": device_id,
"satellite_id": user_input.satellite_id, "satellite_id": satellite_id,
"user_input": user_input.as_dict(), "user_input": user_input.as_dict(),
} }

View File

@@ -522,13 +522,13 @@ async def test_respond_intent(hass: HomeAssistant) -> None:
@pytest.mark.usefixtures("init_components") @pytest.mark.usefixtures("init_components")
async def test_device_area_context( async def test_satellite_area_context(
hass: HomeAssistant, hass: HomeAssistant,
area_registry: ar.AreaRegistry, area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test that including a device_id will target a specific area.""" """Test that including a satellite will target a specific area."""
turn_on_calls = async_mock_service(hass, "light", "turn_on") turn_on_calls = async_mock_service(hass, "light", "turn_on")
turn_off_calls = async_mock_service(hass, "light", "turn_off") turn_off_calls = async_mock_service(hass, "light", "turn_off")
@@ -560,12 +560,12 @@ async def test_device_area_context(
entry = MockConfigEntry() entry = MockConfigEntry()
entry.add_to_hass(hass) entry.add_to_hass(hass)
kitchen_satellite = device_registry.async_get_or_create( kitchen_satellite = entity_registry.async_get_or_create(
config_entry_id=entry.entry_id, "assist_satellite", "demo", "kitchen"
connections=set(), )
identifiers={("demo", "id-satellite-kitchen")}, entity_registry.async_update_entity(
kitchen_satellite.entity_id, area_id=area_kitchen.id
) )
device_registry.async_update_device(kitchen_satellite.id, area_id=area_kitchen.id)
bedroom_satellite = device_registry.async_get_or_create( bedroom_satellite = device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
@@ -581,7 +581,7 @@ async def test_device_area_context(
None, None,
Context(), Context(),
None, None,
device_id=kitchen_satellite.id, satellite_id=kitchen_satellite.entity_id,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
@@ -605,7 +605,7 @@ async def test_device_area_context(
None, None,
Context(), Context(),
None, None,
device_id=kitchen_satellite.id, satellite_id=kitchen_satellite.entity_id,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE