Keep expose setting in sync for assist (#92158)

* Keep expose setting in sync for assist

* Fix initialization, add test

* Fix tests

* Add AgentManager.async_setup

* Fix typo

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2023-04-28 15:59:21 +02:00 committed by GitHub
parent 2bfa521068
commit ebd9cd096a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 7 deletions

View File

@ -23,7 +23,7 @@ from homeassistant.util import language as language_util
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import HOME_ASSISTANT_AGENT
from .default_agent import DefaultAgent
from .default_agent import DefaultAgent, async_setup as async_setup_default_agent
__all__ = [
"DOMAIN",
@ -93,7 +93,9 @@ CONFIG_SCHEMA = vol.Schema(
@core.callback
def _get_agent_manager(hass: HomeAssistant) -> AgentManager:
"""Get the active agent."""
return AgentManager(hass)
manager = AgentManager(hass)
manager.async_setup()
return manager
@core.callback
@ -389,7 +391,11 @@ class AgentManager:
"""Initialize the conversation agents."""
self.hass = hass
self._agents: dict[str, AbstractConversationAgent] = {}
self._default_agent_init_lock = asyncio.Lock()
self._builtin_agent_init_lock = asyncio.Lock()
def async_setup(self) -> None:
"""Set up the conversation agents."""
async_setup_default_agent(self.hass)
async def async_get_agent(
self, agent_id: str | None = None
@ -402,7 +408,7 @@ class AgentManager:
if self._builtin_agent is not None:
return self._builtin_agent
async with self._default_agent_init_lock:
async with self._builtin_agent_init_lock:
if self._builtin_agent is not None:
return self._builtin_agent

View File

@ -73,6 +73,26 @@ def _get_language_variations(language: str) -> Iterable[str]:
yield lang
@core.callback
def async_setup(hass: core.HomeAssistant) -> None:
"""Set up entity registry listener for the default agent."""
entity_registry = er.async_get(hass)
for entity_id in entity_registry.entities:
async_should_expose(hass, DOMAIN, entity_id)
@core.callback
def async_handle_entity_registry_changed(event: core.Event) -> None:
"""Set expose flag on newly created entities."""
if event.data["action"] == "create":
async_should_expose(hass, DOMAIN, event.data["entity_id"])
hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
async_handle_entity_registry_changed,
run_immediately=True,
)
class DefaultAgent(AbstractConversationAgent):
"""Default agent for conversation agent."""
@ -472,10 +492,10 @@ class DefaultAgent(AbstractConversationAgent):
return self._slot_lists
area_ids_with_entities: set[str] = set()
all_entities = er.async_get(self.hass)
entity_registry = er.async_get(self.hass)
entities = [
entity
for entity in all_entities.entities.values()
for entity in entity_registry.entities.values()
if async_should_expose(self.hass, DOMAIN, entity.entity_id)
]
devices = dr.async_get(self.hass)

View File

@ -4,6 +4,9 @@ from unittest.mock import patch
import pytest
from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import (
async_get_assistant_settings,
)
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant
from homeassistant.helpers import (
@ -137,3 +140,34 @@ async def test_conversation_agent(
return_value={"homeassistant": ["dwarvish", "elvish", "entish"]},
):
assert agent.supported_languages == ["dwarvish", "elvish", "entish"]
async def test_expose_flag_automatically_set(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
) -> None:
"""Test DefaultAgent sets the expose flag on all entities automatically."""
assert await async_setup_component(hass, "homeassistant", {})
light = entity_registry.async_get_or_create("light", "demo", "1234")
test = entity_registry.async_get_or_create("test", "demo", "1234")
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {}
assert await async_setup_component(hass, "conversation", {})
await hass.async_block_till_done()
# After setting up conversation, the expose flag should now be set on all entities
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
light.entity_id: {"should_expose": True},
test.entity_id: {"should_expose": False},
}
# New entities will automatically have the expose flag set
new_light = entity_registry.async_get_or_create("light", "demo", "2345")
await hass.async_block_till_done()
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
light.entity_id: {"should_expose": True},
new_light.entity_id: {"should_expose": True},
test.entity_id: {"should_expose": False},
}

View File

@ -18,6 +18,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component
from .const import CALL_SERVICE, FIRE_EVENT, REGISTER_CLEARTEXT, RENDER_TEMPLATE, UPDATE
@ -28,6 +29,12 @@ from tests.components.conversation.conftest import mock_agent
mock_agent = mock_agent
@pytest.fixture
async def homeassistant(hass):
"""Load the homeassistant integration."""
await async_setup_component(hass, "homeassistant", {})
def encrypt_payload(secret_key, payload, encode_json=True):
"""Return a encrypted payload given a key and dictionary of data."""
try:
@ -1014,7 +1021,7 @@ async def test_reregister_sensor(
async def test_webhook_handle_conversation_process(
hass: HomeAssistant, create_registrations, webhook_client, mock_agent
hass: HomeAssistant, homeassistant, create_registrations, webhook_client, mock_agent
) -> None:
"""Test that we can converse."""
webhook_client.server.app.router._frozen = False