mirror of
https://github.com/home-assistant/core.git
synced 2025-11-11 20:10:12 +00:00
Use satellite entity area in the default agent (#152762)
This commit is contained in:
@@ -153,8 +153,8 @@ class IntentCacheKey:
|
|||||||
language: str
|
language: str
|
||||||
"""Language of text."""
|
"""Language of text."""
|
||||||
|
|
||||||
device_id: str | None
|
satellite_id: str | None
|
||||||
"""Device id from user input."""
|
"""Satellite id from user input."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -443,9 +443,15 @@ class DefaultAgent(ConversationEntity):
|
|||||||
}
|
}
|
||||||
for entity in result.entities_list
|
for entity in result.entities_list
|
||||||
}
|
}
|
||||||
device_area = self._get_device_area(user_input.device_id)
|
|
||||||
if device_area:
|
satellite_id = user_input.satellite_id
|
||||||
slots["preferred_area_id"] = {"value": device_area.id}
|
device_id = user_input.device_id
|
||||||
|
satellite_area, device_id = self._get_satellite_area_and_device(
|
||||||
|
satellite_id, device_id
|
||||||
|
)
|
||||||
|
if satellite_area is not None:
|
||||||
|
slots["preferred_area_id"] = {"value": satellite_area.id}
|
||||||
|
|
||||||
async_conversation_trace_append(
|
async_conversation_trace_append(
|
||||||
ConversationTraceEventType.TOOL_CALL,
|
ConversationTraceEventType.TOOL_CALL,
|
||||||
{
|
{
|
||||||
@@ -467,8 +473,8 @@ class DefaultAgent(ConversationEntity):
|
|||||||
user_input.context,
|
user_input.context,
|
||||||
language,
|
language,
|
||||||
assistant=DOMAIN,
|
assistant=DOMAIN,
|
||||||
device_id=user_input.device_id,
|
device_id=device_id,
|
||||||
satellite_id=user_input.satellite_id,
|
satellite_id=satellite_id,
|
||||||
conversation_agent_id=user_input.agent_id,
|
conversation_agent_id=user_input.agent_id,
|
||||||
)
|
)
|
||||||
except intent.MatchFailedError as match_error:
|
except intent.MatchFailedError as match_error:
|
||||||
@@ -534,7 +540,9 @@ class DefaultAgent(ConversationEntity):
|
|||||||
|
|
||||||
# Try cache first
|
# Try cache first
|
||||||
cache_key = IntentCacheKey(
|
cache_key = IntentCacheKey(
|
||||||
text=user_input.text, language=language, device_id=user_input.device_id
|
text=user_input.text,
|
||||||
|
language=language,
|
||||||
|
satellite_id=user_input.satellite_id,
|
||||||
)
|
)
|
||||||
cache_value = self._intent_cache.get(cache_key)
|
cache_value = self._intent_cache.get(cache_key)
|
||||||
if cache_value is not None:
|
if cache_value is not None:
|
||||||
@@ -1304,28 +1312,40 @@ class DefaultAgent(ConversationEntity):
|
|||||||
self, user_input: ConversationInput
|
self, user_input: ConversationInput
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Return intent recognition context for user input."""
|
"""Return intent recognition context for user input."""
|
||||||
if not user_input.device_id:
|
satellite_area, _ = self._get_satellite_area_and_device(
|
||||||
|
user_input.satellite_id, user_input.device_id
|
||||||
|
)
|
||||||
|
if satellite_area is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
device_area = self._get_device_area(user_input.device_id)
|
return {"area": {"value": satellite_area.name, "text": satellite_area.name}}
|
||||||
if device_area is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return {"area": {"value": device_area.name, "text": device_area.name}}
|
def _get_satellite_area_and_device(
|
||||||
|
self, satellite_id: str | None, device_id: str | None = None
|
||||||
|
) -> tuple[ar.AreaEntry | None, str | None]:
|
||||||
|
"""Return area entry and device id."""
|
||||||
|
hass = self.hass
|
||||||
|
|
||||||
def _get_device_area(self, device_id: str | None) -> ar.AreaEntry | None:
|
area_id: str | None = None
|
||||||
"""Return area object for given device identifier."""
|
|
||||||
if device_id is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
devices = dr.async_get(self.hass)
|
if (
|
||||||
device = devices.async_get(device_id)
|
satellite_id is not None
|
||||||
if (device is None) or (device.area_id is None):
|
and (entity_entry := er.async_get(hass).async_get(satellite_id)) is not None
|
||||||
return None
|
):
|
||||||
|
area_id = entity_entry.area_id
|
||||||
|
device_id = entity_entry.device_id
|
||||||
|
|
||||||
areas = ar.async_get(self.hass)
|
if (
|
||||||
|
area_id is None
|
||||||
|
and device_id is not None
|
||||||
|
and (device_entry := dr.async_get(hass).async_get(device_id)) is not None
|
||||||
|
):
|
||||||
|
area_id = device_entry.area_id
|
||||||
|
|
||||||
return areas.async_get_area(device.area_id)
|
if area_id is None:
|
||||||
|
return None, device_id
|
||||||
|
|
||||||
|
return ar.async_get(hass).async_get_area(area_id), device_id
|
||||||
|
|
||||||
def _get_error_text(
|
def _get_error_text(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
|
from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
|
||||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant
|
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv, entity_registry as er
|
||||||
from homeassistant.helpers.script import ScriptRunResult
|
from homeassistant.helpers.script import ScriptRunResult
|
||||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
||||||
@@ -71,6 +71,8 @@ async def async_attach_trigger(
|
|||||||
trigger_data = trigger_info["trigger_data"]
|
trigger_data = trigger_info["trigger_data"]
|
||||||
sentences = config.get(CONF_COMMAND, [])
|
sentences = config.get(CONF_COMMAND, [])
|
||||||
|
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
|
||||||
job = HassJob(action)
|
job = HassJob(action)
|
||||||
|
|
||||||
async def call_action(
|
async def call_action(
|
||||||
@@ -92,6 +94,14 @@ async def async_attach_trigger(
|
|||||||
for entity_name, entity in result.entities.items()
|
for entity_name, entity in result.entities.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
satellite_id = user_input.satellite_id
|
||||||
|
device_id = user_input.device_id
|
||||||
|
if (
|
||||||
|
satellite_id is not None
|
||||||
|
and (satellite_entry := ent_reg.async_get(satellite_id)) is not None
|
||||||
|
):
|
||||||
|
device_id = satellite_entry.device_id
|
||||||
|
|
||||||
trigger_input: dict[str, Any] = { # Satisfy type checker
|
trigger_input: dict[str, Any] = { # Satisfy type checker
|
||||||
**trigger_data,
|
**trigger_data,
|
||||||
"platform": DOMAIN,
|
"platform": DOMAIN,
|
||||||
@@ -100,8 +110,8 @@ async def async_attach_trigger(
|
|||||||
"slots": { # direct access to values
|
"slots": { # direct access to values
|
||||||
entity_name: entity["value"] for entity_name, entity in details.items()
|
entity_name: entity["value"] for entity_name, entity in details.items()
|
||||||
},
|
},
|
||||||
"device_id": user_input.device_id,
|
"device_id": device_id,
|
||||||
"satellite_id": user_input.satellite_id,
|
"satellite_id": satellite_id,
|
||||||
"user_input": user_input.as_dict(),
|
"user_input": user_input.as_dict(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -522,13 +522,13 @@ async def test_respond_intent(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("init_components")
|
@pytest.mark.usefixtures("init_components")
|
||||||
async def test_device_area_context(
|
async def test_satellite_area_context(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
area_registry: ar.AreaRegistry,
|
area_registry: ar.AreaRegistry,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that including a device_id will target a specific area."""
|
"""Test that including a satellite will target a specific area."""
|
||||||
turn_on_calls = async_mock_service(hass, "light", "turn_on")
|
turn_on_calls = async_mock_service(hass, "light", "turn_on")
|
||||||
turn_off_calls = async_mock_service(hass, "light", "turn_off")
|
turn_off_calls = async_mock_service(hass, "light", "turn_off")
|
||||||
|
|
||||||
@@ -560,12 +560,12 @@ async def test_device_area_context(
|
|||||||
entry = MockConfigEntry()
|
entry = MockConfigEntry()
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
kitchen_satellite = device_registry.async_get_or_create(
|
kitchen_satellite = entity_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
"assist_satellite", "demo", "kitchen"
|
||||||
connections=set(),
|
)
|
||||||
identifiers={("demo", "id-satellite-kitchen")},
|
entity_registry.async_update_entity(
|
||||||
|
kitchen_satellite.entity_id, area_id=area_kitchen.id
|
||||||
)
|
)
|
||||||
device_registry.async_update_device(kitchen_satellite.id, area_id=area_kitchen.id)
|
|
||||||
|
|
||||||
bedroom_satellite = device_registry.async_get_or_create(
|
bedroom_satellite = device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
@@ -581,7 +581,7 @@ async def test_device_area_context(
|
|||||||
None,
|
None,
|
||||||
Context(),
|
Context(),
|
||||||
None,
|
None,
|
||||||
device_id=kitchen_satellite.id,
|
satellite_id=kitchen_satellite.entity_id,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
@@ -605,7 +605,7 @@ async def test_device_area_context(
|
|||||||
None,
|
None,
|
||||||
Context(),
|
Context(),
|
||||||
None,
|
None,
|
||||||
device_id=kitchen_satellite.id,
|
satellite_id=kitchen_satellite.entity_id,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
|||||||
Reference in New Issue
Block a user