mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Keep expose setting in sync for assist (#92158)
* Keep expose setting in sync for assist * Fix initialization, add test * Fix tests * Add AgentManager.async_setup * Fix typo --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
2bfa521068
commit
ebd9cd096a
@ -23,7 +23,7 @@ from homeassistant.util import language as language_util
|
||||
|
||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .const import HOME_ASSISTANT_AGENT
|
||||
from .default_agent import DefaultAgent
|
||||
from .default_agent import DefaultAgent, async_setup as async_setup_default_agent
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
@ -93,7 +93,9 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
@core.callback
|
||||
def _get_agent_manager(hass: HomeAssistant) -> AgentManager:
|
||||
"""Get the active agent."""
|
||||
return AgentManager(hass)
|
||||
manager = AgentManager(hass)
|
||||
manager.async_setup()
|
||||
return manager
|
||||
|
||||
|
||||
@core.callback
|
||||
@ -389,7 +391,11 @@ class AgentManager:
|
||||
"""Initialize the conversation agents."""
|
||||
self.hass = hass
|
||||
self._agents: dict[str, AbstractConversationAgent] = {}
|
||||
self._default_agent_init_lock = asyncio.Lock()
|
||||
self._builtin_agent_init_lock = asyncio.Lock()
|
||||
|
||||
def async_setup(self) -> None:
|
||||
"""Set up the conversation agents."""
|
||||
async_setup_default_agent(self.hass)
|
||||
|
||||
async def async_get_agent(
|
||||
self, agent_id: str | None = None
|
||||
@ -402,7 +408,7 @@ class AgentManager:
|
||||
if self._builtin_agent is not None:
|
||||
return self._builtin_agent
|
||||
|
||||
async with self._default_agent_init_lock:
|
||||
async with self._builtin_agent_init_lock:
|
||||
if self._builtin_agent is not None:
|
||||
return self._builtin_agent
|
||||
|
||||
|
@ -73,6 +73,26 @@ def _get_language_variations(language: str) -> Iterable[str]:
|
||||
yield lang
|
||||
|
||||
|
||||
@core.callback
|
||||
def async_setup(hass: core.HomeAssistant) -> None:
|
||||
"""Set up entity registry listener for the default agent."""
|
||||
entity_registry = er.async_get(hass)
|
||||
for entity_id in entity_registry.entities:
|
||||
async_should_expose(hass, DOMAIN, entity_id)
|
||||
|
||||
@core.callback
|
||||
def async_handle_entity_registry_changed(event: core.Event) -> None:
|
||||
"""Set expose flag on newly created entities."""
|
||||
if event.data["action"] == "create":
|
||||
async_should_expose(hass, DOMAIN, event.data["entity_id"])
|
||||
|
||||
hass.bus.async_listen(
|
||||
er.EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
async_handle_entity_registry_changed,
|
||||
run_immediately=True,
|
||||
)
|
||||
|
||||
|
||||
class DefaultAgent(AbstractConversationAgent):
|
||||
"""Default agent for conversation agent."""
|
||||
|
||||
@ -472,10 +492,10 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
return self._slot_lists
|
||||
|
||||
area_ids_with_entities: set[str] = set()
|
||||
all_entities = er.async_get(self.hass)
|
||||
entity_registry = er.async_get(self.hass)
|
||||
entities = [
|
||||
entity
|
||||
for entity in all_entities.entities.values()
|
||||
for entity in entity_registry.entities.values()
|
||||
if async_should_expose(self.hass, DOMAIN, entity.entity_id)
|
||||
]
|
||||
devices = dr.async_get(self.hass)
|
||||
|
@ -4,6 +4,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.homeassistant.exposed_entities import (
|
||||
async_get_assistant_settings,
|
||||
)
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant
|
||||
from homeassistant.helpers import (
|
||||
@ -137,3 +140,34 @@ async def test_conversation_agent(
|
||||
return_value={"homeassistant": ["dwarvish", "elvish", "entish"]},
|
||||
):
|
||||
assert agent.supported_languages == ["dwarvish", "elvish", "entish"]
|
||||
|
||||
|
||||
async def test_expose_flag_automatically_set(
|
||||
hass: HomeAssistant,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test DefaultAgent sets the expose flag on all entities automatically."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
light = entity_registry.async_get_or_create("light", "demo", "1234")
|
||||
test = entity_registry.async_get_or_create("test", "demo", "1234")
|
||||
|
||||
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {}
|
||||
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# After setting up conversation, the expose flag should now be set on all entities
|
||||
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
|
||||
light.entity_id: {"should_expose": True},
|
||||
test.entity_id: {"should_expose": False},
|
||||
}
|
||||
|
||||
# New entities will automatically have the expose flag set
|
||||
new_light = entity_registry.async_get_or_create("light", "demo", "2345")
|
||||
await hass.async_block_till_done()
|
||||
assert async_get_assistant_settings(hass, conversation.DOMAIN) == {
|
||||
light.entity_id: {"should_expose": True},
|
||||
new_light.entity_id: {"should_expose": True},
|
||||
test.entity_id: {"should_expose": False},
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ from homeassistant.const import (
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .const import CALL_SERVICE, FIRE_EVENT, REGISTER_CLEARTEXT, RENDER_TEMPLATE, UPDATE
|
||||
|
||||
@ -28,6 +29,12 @@ from tests.components.conversation.conftest import mock_agent
|
||||
mock_agent = mock_agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def homeassistant(hass):
|
||||
"""Load the homeassistant integration."""
|
||||
await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
|
||||
def encrypt_payload(secret_key, payload, encode_json=True):
|
||||
"""Return a encrypted payload given a key and dictionary of data."""
|
||||
try:
|
||||
@ -1014,7 +1021,7 @@ async def test_reregister_sensor(
|
||||
|
||||
|
||||
async def test_webhook_handle_conversation_process(
|
||||
hass: HomeAssistant, create_registrations, webhook_client, mock_agent
|
||||
hass: HomeAssistant, homeassistant, create_registrations, webhook_client, mock_agent
|
||||
) -> None:
|
||||
"""Test that we can converse."""
|
||||
webhook_client.server.app.router._frozen = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user