Allow core integrations to describe their triggers (#147075)

Co-authored-by: Abílio Costa <abmantis@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2025-06-25 18:35:15 +02:00 committed by GitHub
parent d8258924f7
commit 1fb587bf03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 908 additions and 13 deletions

View File

@ -89,6 +89,7 @@ from .helpers import (
restore_state,
template,
translation,
trigger,
)
from .helpers.dispatcher import async_dispatcher_send_internal
from .helpers.storage import get_internal_store_manager
@ -452,6 +453,7 @@ async def async_load_base_functionality(hass: core.HomeAssistant) -> None:
create_eager_task(restore_state.async_load(hass)),
create_eager_task(hass.config_entries.async_initialize()),
create_eager_task(async_get_system_info(hass)),
create_eager_task(trigger.async_setup(hass)),
)

View File

@ -9,5 +9,10 @@
"reload": {
"service": "mdi:reload"
}
},
"triggers": {
"mqtt": {
"trigger": "mdi:swap-horizontal"
}
}
}

View File

@ -992,6 +992,23 @@
"description": "Reloads MQTT entities from the YAML-configuration."
}
},
"triggers": {
"mqtt": {
"name": "MQTT",
"description": "When a specific message is received on a given MQTT topic.",
"description_configured": "When an MQTT message has been received",
"fields": {
"payload": {
"name": "Payload",
"description": "The payload to trigger on."
},
"topic": {
"name": "Topic",
"description": "MQTT topic to listen to."
}
}
}
},
"exceptions": {
"addon_start_failed": {
"message": "Failed to correctly start {addon} add-on."

View File

@ -0,0 +1,14 @@
# Describes the format for MQTT triggers
mqtt:
fields:
payload:
example: "on"
required: false
selector:
text:
topic:
example: "living_room/switch/ac"
required: true
selector:
text:

View File

@ -52,7 +52,13 @@ from homeassistant.helpers.json import (
json_bytes,
json_fragment,
)
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.service import (
async_get_all_descriptions as async_get_all_service_descriptions,
)
from homeassistant.helpers.trigger import (
async_get_all_descriptions as async_get_all_trigger_descriptions,
async_subscribe_platform_events as async_subscribe_trigger_platform_events,
)
from homeassistant.loader import (
IntegrationNotFound,
async_get_integration,
@ -68,9 +74,10 @@ from homeassistant.util.json import format_unserializable_data
from . import const, decorators, messages
from .connection import ActiveConnection
from .messages import construct_result_message
from .messages import construct_event_message, construct_result_message
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json"
ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_trigger_descriptions_json"
_LOGGER = logging.getLogger(__name__)
@ -96,6 +103,7 @@ def async_register_commands(
async_reg(hass, handle_subscribe_bootstrap_integrations)
async_reg(hass, handle_subscribe_events)
async_reg(hass, handle_subscribe_trigger)
async_reg(hass, handle_subscribe_trigger_platforms)
async_reg(hass, handle_test_condition)
async_reg(hass, handle_unsubscribe_events)
async_reg(hass, handle_validate_config)
@ -493,9 +501,9 @@ def _send_handle_entities_init_response(
)
async def _async_get_all_descriptions_json(hass: HomeAssistant) -> bytes:
async def _async_get_all_service_descriptions_json(hass: HomeAssistant) -> bytes:
"""Return JSON of descriptions (i.e. user documentation) for all service calls."""
descriptions = await async_get_all_descriptions(hass)
descriptions = await async_get_all_service_descriptions(hass)
if ALL_SERVICE_DESCRIPTIONS_JSON_CACHE in hass.data:
cached_descriptions, cached_json_payload = hass.data[
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE
@ -514,10 +522,57 @@ async def handle_get_services(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get services command."""
payload = await _async_get_all_descriptions_json(hass)
payload = await _async_get_all_service_descriptions_json(hass)
connection.send_message(construct_result_message(msg["id"], payload))
async def _async_get_all_trigger_descriptions_json(hass: HomeAssistant) -> bytes:
"""Return JSON of descriptions (i.e. user documentation) for all triggers."""
descriptions = await async_get_all_trigger_descriptions(hass)
if ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE in hass.data:
cached_descriptions, cached_json_payload = hass.data[
ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE
]
# If the descriptions are the same, return the cached JSON payload
if cached_descriptions is descriptions:
return cast(bytes, cached_json_payload)
json_payload = json_bytes(
{
trigger: description
for trigger, description in descriptions.items()
if description is not None
}
)
hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] = (descriptions, json_payload)
return json_payload
@decorators.websocket_command({vol.Required("type"): "trigger_platforms/subscribe"})
@decorators.async_response
async def handle_subscribe_trigger_platforms(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe triggers command."""
async def on_new_triggers(new_triggers: set[str]) -> None:
"""Forward new triggers to websocket."""
descriptions = await async_get_all_trigger_descriptions(hass)
new_trigger_descriptions = {}
for trigger in new_triggers:
if (description := descriptions[trigger]) is not None:
new_trigger_descriptions[trigger] = description
if not new_trigger_descriptions:
return
connection.send_event(msg["id"], new_trigger_descriptions)
connection.subscriptions[msg["id"]] = async_subscribe_trigger_platform_events(
hass, on_new_triggers
)
connection.send_result(msg["id"])
triggers_json = await _async_get_all_trigger_descriptions_json(hass)
connection.send_message(construct_event_message(msg["id"], triggers_json))
@callback
@decorators.websocket_command({vol.Required("type"): "get_config"})
def handle_get_config(

View File

@ -109,6 +109,19 @@ def event_message(iden: int, event: Any) -> dict[str, Any]:
return {"id": iden, "type": "event", "event": event}
def construct_event_message(iden: int, event: bytes) -> bytes:
"""Construct an event message JSON."""
return b"".join(
(
b'{"id":',
str(iden).encode(),
b',"type":"event","event":',
event,
b"}",
)
)
def cached_event_message(message_id_as_bytes: bytes, event: Event) -> bytes:
"""Return an event message.

View File

@ -5,11 +5,11 @@ from __future__ import annotations
import abc
import asyncio
from collections import defaultdict
from collections.abc import Callable, Coroutine
from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass, field
import functools
import logging
from typing import Any, Protocol, TypedDict, cast
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
import voluptuous as vol
@ -29,13 +29,24 @@ from homeassistant.core import (
is_callback,
)
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.loader import (
Integration,
IntegrationNotFound,
async_get_integration,
async_get_integrations,
)
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE
from . import config_validation as cv
from .integration_platform import async_process_integration_platforms
from .template import Template
from .typing import ConfigType, TemplateVarsType
_LOGGER = logging.getLogger(__name__)
_PLATFORM_ALIASES = {
"device": "device_automation",
"event": "homeassistant",
@ -49,6 +60,99 @@ DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = Has
"pluggable_actions"
)
TRIGGER_DESCRIPTION_CACHE: HassKey[dict[str, dict[str, Any] | None]] = HassKey(
"trigger_description_cache"
)
TRIGGER_PLATFORM_SUBSCRIPTIONS: HassKey[
list[Callable[[set[str]], Coroutine[Any, Any, None]]]
] = HassKey("trigger_platform_subscriptions")
TRIGGERS: HassKey[dict[str, str]] = HassKey("triggers")
# Basic schemas to sanity check the trigger descriptions,
# full validation is done by hassfest.triggers
_FIELD_SCHEMA = vol.Schema(
{},
extra=vol.ALLOW_EXTRA,
)
_TRIGGER_SCHEMA = vol.Schema(
{
vol.Optional("fields"): vol.Schema({str: _FIELD_SCHEMA}),
},
extra=vol.ALLOW_EXTRA,
)
def starts_with_dot(key: str) -> str:
"""Check if key starts with dot."""
if not key.startswith("."):
raise vol.Invalid("Key does not start with .")
return key
_TRIGGERS_SCHEMA = vol.Schema(
{
vol.Remove(vol.All(str, starts_with_dot)): object,
cv.slug: vol.Any(None, _TRIGGER_SCHEMA),
}
)
async def async_setup(hass: HomeAssistant) -> None:
"""Set up the trigger helper."""
hass.data[TRIGGER_DESCRIPTION_CACHE] = {}
hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS] = []
hass.data[TRIGGERS] = {}
await async_process_integration_platforms(
hass, "trigger", _register_trigger_platform, wait_for_platforms=True
)
@callback
def async_subscribe_platform_events(
hass: HomeAssistant,
on_event: Callable[[set[str]], Coroutine[Any, Any, None]],
) -> Callable[[], None]:
"""Subscribe to trigger platform events."""
trigger_platform_event_subscriptions = hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]
def remove_subscription() -> None:
trigger_platform_event_subscriptions.remove(on_event)
trigger_platform_event_subscriptions.append(on_event)
return remove_subscription
async def _register_trigger_platform(
hass: HomeAssistant, integration_domain: str, platform: TriggerProtocol
) -> None:
"""Register a trigger platform."""
new_triggers: set[str] = set()
if hasattr(platform, "async_get_triggers"):
for trigger_key in await platform.async_get_triggers(hass):
hass.data[TRIGGERS][trigger_key] = integration_domain
new_triggers.add(trigger_key)
elif hasattr(platform, "async_validate_trigger_config") or hasattr(
platform, "TRIGGER_SCHEMA"
):
hass.data[TRIGGERS][integration_domain] = integration_domain
new_triggers.add(integration_domain)
else:
_LOGGER.debug(
"Integration %s does not provide trigger support, skipping",
integration_domain,
)
return
tasks: list[asyncio.Task[None]] = [
create_eager_task(listener(new_triggers))
for listener in hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]
]
await asyncio.gather(*tasks)
class Trigger(abc.ABC):
"""Trigger class."""
@ -409,3 +513,107 @@ async def async_initialize_triggers(
remove()
return remove_triggers
def _load_triggers_file(hass: HomeAssistant, integration: Integration) -> JSON_TYPE:
"""Load triggers file for an integration."""
try:
return cast(
JSON_TYPE,
_TRIGGERS_SCHEMA(
load_yaml_dict(str(integration.file_path / "triggers.yaml"))
),
)
except FileNotFoundError:
_LOGGER.warning(
"Unable to find triggers.yaml for the %s integration", integration.domain
)
return {}
except (HomeAssistantError, vol.Invalid) as ex:
_LOGGER.warning(
"Unable to parse triggers.yaml for the %s integration: %s",
integration.domain,
ex,
)
return {}
def _load_triggers_files(
hass: HomeAssistant, integrations: Iterable[Integration]
) -> dict[str, JSON_TYPE]:
"""Load trigger files for multiple integrations."""
return {
integration.domain: _load_triggers_file(hass, integration)
for integration in integrations
}
async def async_get_all_descriptions(
hass: HomeAssistant,
) -> dict[str, dict[str, Any] | None]:
"""Return descriptions (i.e. user documentation) for all triggers."""
descriptions_cache = hass.data[TRIGGER_DESCRIPTION_CACHE]
triggers = hass.data[TRIGGERS]
# See if there are new triggers not seen before.
# Any trigger that we saw before already has an entry in description_cache.
all_triggers = set(triggers)
previous_all_triggers = set(descriptions_cache)
# If the triggers are the same, we can return the cache
if previous_all_triggers == all_triggers:
return descriptions_cache
# Files we loaded for missing descriptions
new_triggers_descriptions: dict[str, JSON_TYPE] = {}
# We try to avoid making a copy in the event the cache is good,
# but now we must make a copy in case new triggers get added
# while we are loading the missing ones so we do not
# add the new ones to the cache without their descriptions
triggers = triggers.copy()
if missing_triggers := all_triggers.difference(descriptions_cache):
domains_with_missing_triggers = {
triggers[missing_trigger] for missing_trigger in missing_triggers
}
ints_or_excs = await async_get_integrations(hass, domains_with_missing_triggers)
integrations: list[Integration] = []
for domain, int_or_exc in ints_or_excs.items():
if type(int_or_exc) is Integration and int_or_exc.has_triggers:
integrations.append(int_or_exc)
continue
if TYPE_CHECKING:
assert isinstance(int_or_exc, Exception)
_LOGGER.debug(
"Failed to load triggers.yaml for integration: %s",
domain,
exc_info=int_or_exc,
)
if integrations:
new_triggers_descriptions = await hass.async_add_executor_job(
_load_triggers_files, hass, integrations
)
# Make a copy of the old cache and add missing descriptions to it
new_descriptions_cache = descriptions_cache.copy()
for missing_trigger in missing_triggers:
domain = triggers[missing_trigger]
if (
yaml_description := new_triggers_descriptions.get(domain, {}).get( # type: ignore[union-attr]
missing_trigger
)
) is None:
_LOGGER.debug(
"No trigger descriptions found for trigger %s, skipping",
missing_trigger,
)
new_descriptions_cache[missing_trigger] = None
continue
description = {"fields": yaml_description.get("fields", {})}
new_descriptions_cache[missing_trigger] = description
hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache
return new_descriptions_cache

View File

@ -857,15 +857,20 @@ class Integration:
# True.
return self.manifest.get("import_executor", True)
@cached_property
def has_services(self) -> bool:
"""Return if the integration has services."""
return "services.yaml" in self._top_level_files
@cached_property
def has_translations(self) -> bool:
"""Return if the integration has translations."""
return "translations" in self._top_level_files
@cached_property
def has_services(self) -> bool:
"""Return if the integration has services."""
return "services.yaml" in self._top_level_files
def has_triggers(self) -> bool:
"""Return if the integration has triggers."""
return "triggers.yaml" in self._top_level_files
@property
def mqtt(self) -> list[str] | None:

View File

@ -28,6 +28,7 @@ from . import (
services,
ssdp,
translations,
triggers,
usb,
zeroconf,
)
@ -49,6 +50,7 @@ INTEGRATION_PLUGINS = [
services,
ssdp,
translations,
triggers,
usb,
zeroconf,
config_flow, # This needs to run last, after translations are processed

View File

@ -120,6 +120,16 @@ CUSTOM_INTEGRATION_SERVICE_ICONS_SCHEMA = cv.schema_with_slug_keys(
)
TRIGGER_ICONS_SCHEMA = cv.schema_with_slug_keys(
vol.Schema(
{
vol.Optional("trigger"): icon_value_validator,
}
),
slug_validator=translation_key_validator,
)
def icon_schema(
core_integration: bool, integration_type: str, no_entity_platform: bool
) -> vol.Schema:
@ -164,6 +174,7 @@ def icon_schema(
vol.Optional("services"): CORE_SERVICE_ICONS_SCHEMA
if core_integration
else CUSTOM_INTEGRATION_SERVICE_ICONS_SCHEMA,
vol.Optional("triggers"): TRIGGER_ICONS_SCHEMA,
}
)

View File

@ -416,6 +416,22 @@ def gen_strings_schema(config: Config, integration: Integration) -> vol.Schema:
},
slug_validator=translation_key_validator,
),
vol.Optional("triggers"): cv.schema_with_slug_keys(
{
vol.Required("name"): translation_value_validator,
vol.Required("description"): translation_value_validator,
vol.Required("description_configured"): translation_value_validator,
vol.Optional("fields"): cv.schema_with_slug_keys(
{
vol.Required("name"): str,
vol.Required("description"): translation_value_validator,
vol.Optional("example"): translation_value_validator,
},
slug_validator=translation_key_validator,
),
},
slug_validator=translation_key_validator,
),
vol.Optional("conversation"): {
vol.Required("agent"): {
vol.Required("done"): translation_value_validator,

238
script/hassfest/triggers.py Normal file
View File

@ -0,0 +1,238 @@
"""Validate triggers."""
from __future__ import annotations
import contextlib
import json
import pathlib
import re
from typing import Any
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import CONF_SELECTOR
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, selector, trigger
from homeassistant.util.yaml import load_yaml_dict
from .model import Config, Integration
def exists(value: Any) -> Any:
"""Check if value exists."""
if value is None:
raise vol.Invalid("Value cannot be None")
return value
FIELD_SCHEMA = vol.Schema(
{
vol.Optional("example"): exists,
vol.Optional("default"): exists,
vol.Optional("required"): bool,
vol.Optional(CONF_SELECTOR): selector.validate_selector,
}
)
TRIGGER_SCHEMA = vol.Any(
vol.Schema(
{
vol.Optional("fields"): vol.Schema({str: FIELD_SCHEMA}),
}
),
None,
)
TRIGGERS_SCHEMA = vol.Schema(
{
vol.Remove(vol.All(str, trigger.starts_with_dot)): object,
cv.slug: TRIGGER_SCHEMA,
}
)
NON_MIGRATED_INTEGRATIONS = {
"calendar",
"conversation",
"device_automation",
"geo_location",
"homeassistant",
"knx",
"lg_netcast",
"litejet",
"persistent_notification",
"samsungtv",
"sun",
"tag",
"template",
"webhook",
"webostv",
"zone",
"zwave_js",
}
def grep_dir(path: pathlib.Path, glob_pattern: str, search_pattern: str) -> bool:
"""Recursively go through a dir and it's children and find the regex."""
pattern = re.compile(search_pattern)
for fil in path.glob(glob_pattern):
if not fil.is_file():
continue
if pattern.search(fil.read_text()):
return True
return False
def validate_triggers(config: Config, integration: Integration) -> None: # noqa: C901
"""Validate triggers."""
try:
data = load_yaml_dict(str(integration.path / "triggers.yaml"))
except FileNotFoundError:
# Find if integration uses triggers
has_triggers = grep_dir(
integration.path,
"**/trigger.py",
r"async_attach_trigger|async_get_triggers",
)
if has_triggers and integration.domain not in NON_MIGRATED_INTEGRATIONS:
integration.add_error(
"triggers", "Registers triggers but has no triggers.yaml"
)
return
except HomeAssistantError:
integration.add_error("triggers", "Invalid triggers.yaml")
return
try:
triggers = TRIGGERS_SCHEMA(data)
except vol.Invalid as err:
integration.add_error(
"triggers", f"Invalid triggers.yaml: {humanize_error(data, err)}"
)
return
icons_file = integration.path / "icons.json"
icons = {}
if icons_file.is_file():
with contextlib.suppress(ValueError):
icons = json.loads(icons_file.read_text())
trigger_icons = icons.get("triggers", {})
# Try loading translation strings
if integration.core:
strings_file = integration.path / "strings.json"
else:
# For custom integrations, use the en.json file
strings_file = integration.path / "translations/en.json"
strings = {}
if strings_file.is_file():
with contextlib.suppress(ValueError):
strings = json.loads(strings_file.read_text())
error_msg_suffix = "in the translations file"
if not integration.core:
error_msg_suffix = f"and is not {error_msg_suffix}"
# For each trigger in the integration:
# 1. Check if the trigger description is set, if not,
# check if it's in the strings file else add an error.
# 2. Check if the trigger has an icon set in icons.json.
# raise an error if not.,
for trigger_name, trigger_schema in triggers.items():
if integration.core and trigger_name not in trigger_icons:
# This is enforced for Core integrations only
integration.add_error(
"triggers",
f"Trigger {trigger_name} has no icon in icons.json.",
)
if trigger_schema is None:
continue
if "name" not in trigger_schema and integration.core:
try:
strings["triggers"][trigger_name]["name"]
except KeyError:
integration.add_error(
"triggers",
f"Trigger {trigger_name} has no name {error_msg_suffix}",
)
if "description" not in trigger_schema and integration.core:
try:
strings["triggers"][trigger_name]["description"]
except KeyError:
integration.add_error(
"triggers",
f"Trigger {trigger_name} has no description {error_msg_suffix}",
)
# The same check is done for the description in each of the fields of the
# trigger schema.
for field_name, field_schema in trigger_schema.get("fields", {}).items():
if "fields" in field_schema:
# This is a section
continue
if "name" not in field_schema and integration.core:
try:
strings["triggers"][trigger_name]["fields"][field_name]["name"]
except KeyError:
integration.add_error(
"triggers",
(
f"Trigger {trigger_name} has a field {field_name} with no "
f"name {error_msg_suffix}"
),
)
if "description" not in field_schema and integration.core:
try:
strings["triggers"][trigger_name]["fields"][field_name][
"description"
]
except KeyError:
integration.add_error(
"triggers",
(
f"Trigger {trigger_name} has a field {field_name} with no "
f"description {error_msg_suffix}"
),
)
if "selector" in field_schema:
with contextlib.suppress(KeyError):
translation_key = field_schema["selector"]["select"][
"translation_key"
]
try:
strings["selector"][translation_key]
except KeyError:
integration.add_error(
"triggers",
f"Trigger {trigger_name} has a field {field_name} with a selector with a translation key {translation_key} that is not in the translations file",
)
# The same check is done for the description in each of the sections of the
# trigger schema.
for section_name, section_schema in trigger_schema.get("fields", {}).items():
if "fields" not in section_schema:
# This is not a section
continue
if "name" not in section_schema and integration.core:
try:
strings["triggers"][trigger_name]["sections"][section_name]["name"]
except KeyError:
integration.add_error(
"triggers",
f"Trigger {trigger_name} has a section {section_name} with no name {error_msg_suffix}",
)
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Handle dependencies for integrations."""
# check triggers.yaml is valid
for integration in integrations.values():
validate_triggers(config, integration)

View File

@ -87,6 +87,7 @@ from homeassistant.helpers import (
restore_state as rs,
storage,
translation,
trigger,
)
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
@ -295,6 +296,7 @@ async def async_test_home_assistant(
# Load the registries
entity.async_setup(hass)
loader.async_setup(hass)
await trigger.async_setup(hass)
# setup translation cache instead of calling translation.async_setup(hass)
hass.data[translation.TRANSLATION_FLATTEN_CACHE] = translation._TranslationCache(

View File

@ -2,6 +2,7 @@
import asyncio
from copy import deepcopy
import io
import logging
from typing import Any
from unittest.mock import ANY, AsyncMock, Mock, patch
@ -19,6 +20,7 @@ from homeassistant.components.websocket_api.auth import (
)
from homeassistant.components.websocket_api.commands import (
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE,
ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE,
)
from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL
from homeassistant.config_entries import ConfigEntryState
@ -28,9 +30,10 @@ from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.loader import async_get_integration
from homeassistant.loader import Integration, async_get_integration
from homeassistant.setup import async_set_domains_to_be_loaded, async_setup_component
from homeassistant.util.json import json_loads
from homeassistant.util.yaml.loader import parse_yaml
from tests.common import (
MockConfigEntry,
@ -707,6 +710,91 @@ async def test_get_services(
assert hass.data[ALL_SERVICE_DESCRIPTIONS_JSON_CACHE] is old_cache
@patch("annotatedyaml.loader.load_yaml")
@patch.object(Integration, "has_triggers", return_value=True)
async def test_subscribe_triggers(
mock_has_triggers: Mock,
mock_load_yaml: Mock,
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
) -> None:
"""Test trigger_platforms/subscribe command."""
sun_trigger_descriptions = """
sun: {}
"""
tag_trigger_descriptions = """
tag: {}
"""
def _load_yaml(fname, secrets=None):
if fname.endswith("sun/triggers.yaml"):
trigger_descriptions = sun_trigger_descriptions
elif fname.endswith("tag/triggers.yaml"):
trigger_descriptions = tag_trigger_descriptions
else:
raise FileNotFoundError
with io.StringIO(trigger_descriptions) as file:
return parse_yaml(file)
mock_load_yaml.side_effect = _load_yaml
assert await async_setup_component(hass, "sun", {})
assert await async_setup_component(hass, "system_health", {})
await hass.async_block_till_done()
assert ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE not in hass.data
await websocket_client.send_json_auto_id({"type": "trigger_platforms/subscribe"})
# Test start subscription with initial event
msg = await websocket_client.receive_json()
assert msg == {"id": 1, "result": None, "success": True, "type": "result"}
msg = await websocket_client.receive_json()
assert msg == {"event": {"sun": {"fields": {}}}, "id": 1, "type": "event"}
old_cache = hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE]
# Test we receive an event when a new platform is loaded, if it has descriptions
assert await async_setup_component(hass, "calendar", {})
assert await async_setup_component(hass, "tag", {})
await hass.async_block_till_done()
msg = await websocket_client.receive_json()
assert msg == {
"event": {"tag": {"fields": {}}},
"id": 1,
"type": "event",
}
# Initiate a second subscription to check the cache is updated because of the new
# trigger
await websocket_client.send_json_auto_id({"type": "trigger_platforms/subscribe"})
msg = await websocket_client.receive_json()
assert msg == {"id": 2, "result": None, "success": True, "type": "result"}
msg = await websocket_client.receive_json()
assert msg == {
"event": {"sun": {"fields": {}}, "tag": {"fields": {}}},
"id": 2,
"type": "event",
}
assert hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] is not old_cache
# Initiate a third subscription to check the cache is not updated because no new
# trigger was added
old_cache = hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE]
await websocket_client.send_json_auto_id({"type": "trigger_platforms/subscribe"})
msg = await websocket_client.receive_json()
assert msg == {"id": 3, "result": None, "success": True, "type": "result"}
msg = await websocket_client.receive_json()
assert msg == {
"event": {"sun": {"fields": {}}, "tag": {"fields": {}}},
"id": 3,
"type": "event",
}
assert hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] is old_cache
async def test_get_config(
hass: HomeAssistant, websocket_client: MockHAClientWebSocket
) -> None:

View File

@ -1,10 +1,15 @@
"""The tests for the trigger helper."""
import io
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch
import pytest
from pytest_unordered import unordered
import voluptuous as vol
from homeassistant.components.sun import DOMAIN as DOMAIN_SUN
from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH
from homeassistant.components.tag import DOMAIN as DOMAIN_TAG
from homeassistant.core import (
CALLBACK_TYPE,
Context,
@ -12,6 +17,8 @@ from homeassistant.core import (
ServiceCall,
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import trigger
from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS,
PluggableAction,
@ -23,9 +30,11 @@ from homeassistant.helpers.trigger import (
async_validate_trigger_config,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import Integration, async_get_integration
from homeassistant.setup import async_setup_component
from homeassistant.util.yaml.loader import parse_yaml
from tests.common import MockModule, mock_integration, mock_platform
from tests.common import MockModule, MockPlatform, mock_integration, mock_platform
async def test_bad_trigger_platform(hass: HomeAssistant) -> None:
@ -519,3 +528,213 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
with pytest.raises(KeyError):
await async_initialize_triggers(hass, config_3, cb_action, "test", "", log_cb)
@pytest.mark.parametrize(
"sun_trigger_descriptions",
[
"""
sun:
fields:
event:
example: sunrise
selector:
select:
options:
- sunrise
- sunset
offset:
selector:
time: null
""",
"""
.anchor: &anchor
- sunrise
- sunset
sun:
fields:
event:
example: sunrise
selector:
select:
options: *anchor
offset:
selector:
time: null
""",
],
)
async def test_async_get_all_descriptions(
hass: HomeAssistant, sun_trigger_descriptions: str
) -> None:
"""Test async_get_all_descriptions."""
tag_trigger_descriptions = """
tag: {}
"""
assert await async_setup_component(hass, DOMAIN_SUN, {})
assert await async_setup_component(hass, DOMAIN_SYSTEM_HEALTH, {})
await hass.async_block_till_done()
def _load_yaml(fname, secrets=None):
if fname.endswith("sun/triggers.yaml"):
trigger_descriptions = sun_trigger_descriptions
elif fname.endswith("tag/triggers.yaml"):
trigger_descriptions = tag_trigger_descriptions
with io.StringIO(trigger_descriptions) as file:
return parse_yaml(file)
with (
patch(
"homeassistant.helpers.trigger._load_triggers_files",
side_effect=trigger._load_triggers_files,
) as proxy_load_triggers_files,
patch(
"annotatedyaml.loader.load_yaml",
side_effect=_load_yaml,
),
patch.object(Integration, "has_triggers", return_value=True),
):
descriptions = await trigger.async_get_all_descriptions(hass)
# Test we only load triggers.yaml for integrations with triggers,
# system_health has no triggers
assert proxy_load_triggers_files.mock_calls[0][1][1] == unordered(
[
await async_get_integration(hass, DOMAIN_SUN),
]
)
# system_health does not have triggers and should not be in descriptions
assert descriptions == {
DOMAIN_SUN: {
"fields": {
"event": {
"example": "sunrise",
"selector": {"select": {"options": ["sunrise", "sunset"]}},
},
"offset": {"selector": {"time": None}},
}
}
}
# Verify the cache returns the same object
assert await trigger.async_get_all_descriptions(hass) is descriptions
# Load the tag integration and check a new cache object is created
assert await async_setup_component(hass, DOMAIN_TAG, {})
await hass.async_block_till_done()
with (
patch(
"annotatedyaml.loader.load_yaml",
side_effect=_load_yaml,
),
patch.object(Integration, "has_triggers", return_value=True),
):
new_descriptions = await trigger.async_get_all_descriptions(hass)
assert new_descriptions is not descriptions
assert new_descriptions == {
DOMAIN_SUN: {
"fields": {
"event": {
"example": "sunrise",
"selector": {"select": {"options": ["sunrise", "sunset"]}},
},
"offset": {"selector": {"time": None}},
}
},
DOMAIN_TAG: {
"fields": {},
},
}
# Verify the cache returns the same object
assert await trigger.async_get_all_descriptions(hass) is new_descriptions
@pytest.mark.parametrize(
("yaml_error", "expected_message"),
[
(
FileNotFoundError("Blah"),
"Unable to find triggers.yaml for the sun integration",
),
(
HomeAssistantError("Test error"),
"Unable to parse triggers.yaml for the sun integration: Test error",
),
],
)
async def test_async_get_all_descriptions_with_yaml_error(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
yaml_error: Exception,
expected_message: str,
) -> None:
"""Test async_get_all_descriptions."""
assert await async_setup_component(hass, DOMAIN_SUN, {})
await hass.async_block_till_done()
def _load_yaml_dict(fname, secrets=None):
raise yaml_error
with (
patch(
"homeassistant.helpers.trigger.load_yaml_dict",
side_effect=_load_yaml_dict,
),
patch.object(Integration, "has_triggers", return_value=True),
):
descriptions = await trigger.async_get_all_descriptions(hass)
assert descriptions == {DOMAIN_SUN: None}
assert expected_message in caplog.text
async def test_async_get_all_descriptions_with_bad_description(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test async_get_all_descriptions."""
sun_service_descriptions = """
sun:
fields: not_a_dict
"""
assert await async_setup_component(hass, DOMAIN_SUN, {})
await hass.async_block_till_done()
def _load_yaml(fname, secrets=None):
with io.StringIO(sun_service_descriptions) as file:
return parse_yaml(file)
with (
patch(
"annotatedyaml.loader.load_yaml",
side_effect=_load_yaml,
),
patch.object(Integration, "has_triggers", return_value=True),
):
descriptions = await trigger.async_get_all_descriptions(hass)
assert descriptions == {DOMAIN_SUN: None}
assert (
"Unable to parse triggers.yaml for the sun integration: "
"expected a dictionary for dictionary value @ data['sun']['fields']"
) in caplog.text
async def test_invalid_trigger_platform(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test invalid trigger platform."""
mock_integration(hass, MockModule("test", async_setup=AsyncMock(return_value=True)))
mock_platform(hass, "test.trigger", MockPlatform())
await async_setup_component(hass, "test", {})
assert "Integration test does not provide trigger support, skipping" in caplog.text