diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 810c1f1e8d2..afe8ea6f356 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -89,6 +89,7 @@ from .helpers import ( restore_state, template, translation, + trigger, ) from .helpers.dispatcher import async_dispatcher_send_internal from .helpers.storage import get_internal_store_manager @@ -452,6 +453,7 @@ async def async_load_base_functionality(hass: core.HomeAssistant) -> None: create_eager_task(restore_state.async_load(hass)), create_eager_task(hass.config_entries.async_initialize()), create_eager_task(async_get_system_info(hass)), + create_eager_task(trigger.async_setup(hass)), ) diff --git a/homeassistant/components/mqtt/icons.json b/homeassistant/components/mqtt/icons.json index 73cbf22b629..46a588a5667 100644 --- a/homeassistant/components/mqtt/icons.json +++ b/homeassistant/components/mqtt/icons.json @@ -9,5 +9,10 @@ "reload": { "service": "mdi:reload" } + }, + "triggers": { + "mqtt": { + "trigger": "mdi:swap-horizontal" + } } } diff --git a/homeassistant/components/mqtt/strings.json b/homeassistant/components/mqtt/strings.json index ed7da6fc112..9c7a2fcea96 100644 --- a/homeassistant/components/mqtt/strings.json +++ b/homeassistant/components/mqtt/strings.json @@ -992,6 +992,23 @@ "description": "Reloads MQTT entities from the YAML-configuration." } }, + "triggers": { + "mqtt": { + "name": "MQTT", + "description": "When a specific message is received on a given MQTT topic.", + "description_configured": "When an MQTT message has been received", + "fields": { + "payload": { + "name": "Payload", + "description": "The payload to trigger on." + }, + "topic": { + "name": "Topic", + "description": "MQTT topic to listen to." + } + } + } + }, "exceptions": { "addon_start_failed": { "message": "Failed to correctly start {addon} add-on." diff --git a/homeassistant/components/mqtt/triggers.yaml b/homeassistant/components/mqtt/triggers.yaml new file mode 100644 index 00000000000..d3998674d58 --- /dev/null +++ b/homeassistant/components/mqtt/triggers.yaml @@ -0,0 +1,14 @@ +# Describes the format for MQTT triggers + +mqtt: + fields: + payload: + example: "on" + required: false + selector: + text: + topic: + example: "living_room/switch/ac" + required: true + selector: + text: diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 498a986e806..701a9a659b1 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -52,7 +52,13 @@ from homeassistant.helpers.json import ( json_bytes, json_fragment, ) -from homeassistant.helpers.service import async_get_all_descriptions +from homeassistant.helpers.service import ( + async_get_all_descriptions as async_get_all_service_descriptions, +) +from homeassistant.helpers.trigger import ( + async_get_all_descriptions as async_get_all_trigger_descriptions, + async_subscribe_platform_events as async_subscribe_trigger_platform_events, +) from homeassistant.loader import ( IntegrationNotFound, async_get_integration, @@ -68,9 +74,10 @@ from homeassistant.util.json import format_unserializable_data from . import const, decorators, messages from .connection import ActiveConnection -from .messages import construct_result_message +from .messages import construct_event_message, construct_result_message ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json" +ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_trigger_descriptions_json" _LOGGER = logging.getLogger(__name__) @@ -96,6 +103,7 @@ def async_register_commands( async_reg(hass, handle_subscribe_bootstrap_integrations) async_reg(hass, handle_subscribe_events) async_reg(hass, handle_subscribe_trigger) + async_reg(hass, handle_subscribe_trigger_platforms) async_reg(hass, handle_test_condition) async_reg(hass, handle_unsubscribe_events) async_reg(hass, handle_validate_config) @@ -493,9 +501,9 @@ def _send_handle_entities_init_response( ) -async def _async_get_all_descriptions_json(hass: HomeAssistant) -> bytes: +async def _async_get_all_service_descriptions_json(hass: HomeAssistant) -> bytes: """Return JSON of descriptions (i.e. user documentation) for all service calls.""" - descriptions = await async_get_all_descriptions(hass) + descriptions = await async_get_all_service_descriptions(hass) if ALL_SERVICE_DESCRIPTIONS_JSON_CACHE in hass.data: cached_descriptions, cached_json_payload = hass.data[ ALL_SERVICE_DESCRIPTIONS_JSON_CACHE @@ -514,10 +522,57 @@ async def handle_get_services( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Handle get services command.""" - payload = await _async_get_all_descriptions_json(hass) + payload = await _async_get_all_service_descriptions_json(hass) connection.send_message(construct_result_message(msg["id"], payload)) +async def _async_get_all_trigger_descriptions_json(hass: HomeAssistant) -> bytes: + """Return JSON of descriptions (i.e. user documentation) for all triggers.""" + descriptions = await async_get_all_trigger_descriptions(hass) + if ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE in hass.data: + cached_descriptions, cached_json_payload = hass.data[ + ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE + ] + # If the descriptions are the same, return the cached JSON payload + if cached_descriptions is descriptions: + return cast(bytes, cached_json_payload) + json_payload = json_bytes( + { + trigger: description + for trigger, description in descriptions.items() + if description is not None + } + ) + hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] = (descriptions, json_payload) + return json_payload + + +@decorators.websocket_command({vol.Required("type"): "trigger_platforms/subscribe"}) +@decorators.async_response +async def handle_subscribe_trigger_platforms( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle subscribe triggers command.""" + + async def on_new_triggers(new_triggers: set[str]) -> None: + """Forward new triggers to websocket.""" + descriptions = await async_get_all_trigger_descriptions(hass) + new_trigger_descriptions = {} + for trigger in new_triggers: + if (description := descriptions[trigger]) is not None: + new_trigger_descriptions[trigger] = description + if not new_trigger_descriptions: + return + connection.send_event(msg["id"], new_trigger_descriptions) + + connection.subscriptions[msg["id"]] = async_subscribe_trigger_platform_events( + hass, on_new_triggers + ) + connection.send_result(msg["id"]) + triggers_json = await _async_get_all_trigger_descriptions_json(hass) + connection.send_message(construct_event_message(msg["id"], triggers_json)) + + @callback @decorators.websocket_command({vol.Required("type"): "get_config"}) def handle_get_config( diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index 6ae7de2c4b7..88d29f243d5 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -109,6 +109,19 @@ def event_message(iden: int, event: Any) -> dict[str, Any]: return {"id": iden, "type": "event", "event": event} +def construct_event_message(iden: int, event: bytes) -> bytes: + """Construct an event message JSON.""" + return b"".join( + ( + b'{"id":', + str(iden).encode(), + b',"type":"event","event":', + event, + b"}", + ) + ) + + def cached_event_message(message_id_as_bytes: bytes, event: Event) -> bytes: """Return an event message. diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 853b5aaf812..66d1560ac70 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -5,11 +5,11 @@ from __future__ import annotations import abc import asyncio from collections import defaultdict -from collections.abc import Callable, Coroutine +from collections.abc import Callable, Coroutine, Iterable from dataclasses import dataclass, field import functools import logging -from typing import Any, Protocol, TypedDict, cast +from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast import voluptuous as vol @@ -29,13 +29,24 @@ from homeassistant.core import ( is_callback, ) from homeassistant.exceptions import HomeAssistantError, TemplateError -from homeassistant.loader import IntegrationNotFound, async_get_integration +from homeassistant.loader import ( + Integration, + IntegrationNotFound, + async_get_integration, + async_get_integrations, +) from homeassistant.util.async_ import create_eager_task from homeassistant.util.hass_dict import HassKey +from homeassistant.util.yaml import load_yaml_dict +from homeassistant.util.yaml.loader import JSON_TYPE +from . import config_validation as cv +from .integration_platform import async_process_integration_platforms from .template import Template from .typing import ConfigType, TemplateVarsType +_LOGGER = logging.getLogger(__name__) + _PLATFORM_ALIASES = { "device": "device_automation", "event": "homeassistant", @@ -49,6 +60,99 @@ DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = Has "pluggable_actions" ) +TRIGGER_DESCRIPTION_CACHE: HassKey[dict[str, dict[str, Any] | None]] = HassKey( + "trigger_description_cache" +) +TRIGGER_PLATFORM_SUBSCRIPTIONS: HassKey[ + list[Callable[[set[str]], Coroutine[Any, Any, None]]] +] = HassKey("trigger_platform_subscriptions") +TRIGGERS: HassKey[dict[str, str]] = HassKey("triggers") + + +# Basic schemas to sanity check the trigger descriptions, +# full validation is done by hassfest.triggers +_FIELD_SCHEMA = vol.Schema( + {}, + extra=vol.ALLOW_EXTRA, +) + +_TRIGGER_SCHEMA = vol.Schema( + { + vol.Optional("fields"): vol.Schema({str: _FIELD_SCHEMA}), + }, + extra=vol.ALLOW_EXTRA, +) + + +def starts_with_dot(key: str) -> str: + """Check if key starts with dot.""" + if not key.startswith("."): + raise vol.Invalid("Key does not start with .") + return key + + +_TRIGGERS_SCHEMA = vol.Schema( + { + vol.Remove(vol.All(str, starts_with_dot)): object, + cv.slug: vol.Any(None, _TRIGGER_SCHEMA), + } +) + + +async def async_setup(hass: HomeAssistant) -> None: + """Set up the trigger helper.""" + hass.data[TRIGGER_DESCRIPTION_CACHE] = {} + hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS] = [] + hass.data[TRIGGERS] = {} + await async_process_integration_platforms( + hass, "trigger", _register_trigger_platform, wait_for_platforms=True + ) + + +@callback +def async_subscribe_platform_events( + hass: HomeAssistant, + on_event: Callable[[set[str]], Coroutine[Any, Any, None]], +) -> Callable[[], None]: + """Subscribe to trigger platform events.""" + trigger_platform_event_subscriptions = hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS] + + def remove_subscription() -> None: + trigger_platform_event_subscriptions.remove(on_event) + + trigger_platform_event_subscriptions.append(on_event) + return remove_subscription + + +async def _register_trigger_platform( + hass: HomeAssistant, integration_domain: str, platform: TriggerProtocol +) -> None: + """Register a trigger platform.""" + + new_triggers: set[str] = set() + + if hasattr(platform, "async_get_triggers"): + for trigger_key in await platform.async_get_triggers(hass): + hass.data[TRIGGERS][trigger_key] = integration_domain + new_triggers.add(trigger_key) + elif hasattr(platform, "async_validate_trigger_config") or hasattr( + platform, "TRIGGER_SCHEMA" + ): + hass.data[TRIGGERS][integration_domain] = integration_domain + new_triggers.add(integration_domain) + else: + _LOGGER.debug( + "Integration %s does not provide trigger support, skipping", + integration_domain, + ) + return + + tasks: list[asyncio.Task[None]] = [ + create_eager_task(listener(new_triggers)) + for listener in hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS] + ] + await asyncio.gather(*tasks) + class Trigger(abc.ABC): """Trigger class.""" @@ -409,3 +513,107 @@ async def async_initialize_triggers( remove() return remove_triggers + + +def _load_triggers_file(hass: HomeAssistant, integration: Integration) -> JSON_TYPE: + """Load triggers file for an integration.""" + try: + return cast( + JSON_TYPE, + _TRIGGERS_SCHEMA( + load_yaml_dict(str(integration.file_path / "triggers.yaml")) + ), + ) + except FileNotFoundError: + _LOGGER.warning( + "Unable to find triggers.yaml for the %s integration", integration.domain + ) + return {} + except (HomeAssistantError, vol.Invalid) as ex: + _LOGGER.warning( + "Unable to parse triggers.yaml for the %s integration: %s", + integration.domain, + ex, + ) + return {} + + +def _load_triggers_files( + hass: HomeAssistant, integrations: Iterable[Integration] +) -> dict[str, JSON_TYPE]: + """Load trigger files for multiple integrations.""" + return { + integration.domain: _load_triggers_file(hass, integration) + for integration in integrations + } + + +async def async_get_all_descriptions( + hass: HomeAssistant, +) -> dict[str, dict[str, Any] | None]: + """Return descriptions (i.e. user documentation) for all triggers.""" + descriptions_cache = hass.data[TRIGGER_DESCRIPTION_CACHE] + + triggers = hass.data[TRIGGERS] + # See if there are new triggers not seen before. + # Any trigger that we saw before already has an entry in description_cache. + all_triggers = set(triggers) + previous_all_triggers = set(descriptions_cache) + # If the triggers are the same, we can return the cache + if previous_all_triggers == all_triggers: + return descriptions_cache + + # Files we loaded for missing descriptions + new_triggers_descriptions: dict[str, JSON_TYPE] = {} + # We try to avoid making a copy in the event the cache is good, + # but now we must make a copy in case new triggers get added + # while we are loading the missing ones so we do not + # add the new ones to the cache without their descriptions + triggers = triggers.copy() + + if missing_triggers := all_triggers.difference(descriptions_cache): + domains_with_missing_triggers = { + triggers[missing_trigger] for missing_trigger in missing_triggers + } + ints_or_excs = await async_get_integrations(hass, domains_with_missing_triggers) + integrations: list[Integration] = [] + for domain, int_or_exc in ints_or_excs.items(): + if type(int_or_exc) is Integration and int_or_exc.has_triggers: + integrations.append(int_or_exc) + continue + if TYPE_CHECKING: + assert isinstance(int_or_exc, Exception) + _LOGGER.debug( + "Failed to load triggers.yaml for integration: %s", + domain, + exc_info=int_or_exc, + ) + + if integrations: + new_triggers_descriptions = await hass.async_add_executor_job( + _load_triggers_files, hass, integrations + ) + + # Make a copy of the old cache and add missing descriptions to it + new_descriptions_cache = descriptions_cache.copy() + for missing_trigger in missing_triggers: + domain = triggers[missing_trigger] + + if ( + yaml_description := new_triggers_descriptions.get(domain, {}).get( # type: ignore[union-attr] + missing_trigger + ) + ) is None: + _LOGGER.debug( + "No trigger descriptions found for trigger %s, skipping", + missing_trigger, + ) + new_descriptions_cache[missing_trigger] = None + continue + + description = {"fields": yaml_description.get("fields", {})} + + new_descriptions_cache[missing_trigger] = description + + hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache + return new_descriptions_cache diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 6a3061b0d2a..ae3709e383b 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -857,15 +857,20 @@ class Integration: # True. return self.manifest.get("import_executor", True) + @cached_property + def has_services(self) -> bool: + """Return if the integration has services.""" + return "services.yaml" in self._top_level_files + @cached_property def has_translations(self) -> bool: """Return if the integration has translations.""" return "translations" in self._top_level_files @cached_property - def has_services(self) -> bool: - """Return if the integration has services.""" - return "services.yaml" in self._top_level_files + def has_triggers(self) -> bool: + """Return if the integration has triggers.""" + return "triggers.yaml" in self._top_level_files @property def mqtt(self) -> list[str] | None: diff --git a/script/hassfest/__main__.py b/script/hassfest/__main__.py index 277696c669b..05c0d455af6 100644 --- a/script/hassfest/__main__.py +++ b/script/hassfest/__main__.py @@ -28,6 +28,7 @@ from . import ( services, ssdp, translations, + triggers, usb, zeroconf, ) @@ -49,6 +50,7 @@ INTEGRATION_PLUGINS = [ services, ssdp, translations, + triggers, usb, zeroconf, config_flow, # This needs to run last, after translations are processed diff --git a/script/hassfest/icons.py b/script/hassfest/icons.py index 563fe0edb93..6abe338e45b 100644 --- a/script/hassfest/icons.py +++ b/script/hassfest/icons.py @@ -120,6 +120,16 @@ CUSTOM_INTEGRATION_SERVICE_ICONS_SCHEMA = cv.schema_with_slug_keys( ) +TRIGGER_ICONS_SCHEMA = cv.schema_with_slug_keys( + vol.Schema( + { + vol.Optional("trigger"): icon_value_validator, + } + ), + slug_validator=translation_key_validator, +) + + def icon_schema( core_integration: bool, integration_type: str, no_entity_platform: bool ) -> vol.Schema: @@ -164,6 +174,7 @@ def icon_schema( vol.Optional("services"): CORE_SERVICE_ICONS_SCHEMA if core_integration else CUSTOM_INTEGRATION_SERVICE_ICONS_SCHEMA, + vol.Optional("triggers"): TRIGGER_ICONS_SCHEMA, } ) diff --git a/script/hassfest/translations.py b/script/hassfest/translations.py index 34c06abb451..913f7df2e7a 100644 --- a/script/hassfest/translations.py +++ b/script/hassfest/translations.py @@ -416,6 +416,22 @@ def gen_strings_schema(config: Config, integration: Integration) -> vol.Schema: }, slug_validator=translation_key_validator, ), + vol.Optional("triggers"): cv.schema_with_slug_keys( + { + vol.Required("name"): translation_value_validator, + vol.Required("description"): translation_value_validator, + vol.Required("description_configured"): translation_value_validator, + vol.Optional("fields"): cv.schema_with_slug_keys( + { + vol.Required("name"): str, + vol.Required("description"): translation_value_validator, + vol.Optional("example"): translation_value_validator, + }, + slug_validator=translation_key_validator, + ), + }, + slug_validator=translation_key_validator, + ), vol.Optional("conversation"): { vol.Required("agent"): { vol.Required("done"): translation_value_validator, diff --git a/script/hassfest/triggers.py b/script/hassfest/triggers.py new file mode 100644 index 00000000000..ff6654f2789 --- /dev/null +++ b/script/hassfest/triggers.py @@ -0,0 +1,238 @@ +"""Validate triggers.""" + +from __future__ import annotations + +import contextlib +import json +import pathlib +import re +from typing import Any + +import voluptuous as vol +from voluptuous.humanize import humanize_error + +from homeassistant.const import CONF_SELECTOR +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_validation as cv, selector, trigger +from homeassistant.util.yaml import load_yaml_dict + +from .model import Config, Integration + + +def exists(value: Any) -> Any: + """Check if value exists.""" + if value is None: + raise vol.Invalid("Value cannot be None") + return value + + +FIELD_SCHEMA = vol.Schema( + { + vol.Optional("example"): exists, + vol.Optional("default"): exists, + vol.Optional("required"): bool, + vol.Optional(CONF_SELECTOR): selector.validate_selector, + } +) + +TRIGGER_SCHEMA = vol.Any( + vol.Schema( + { + vol.Optional("fields"): vol.Schema({str: FIELD_SCHEMA}), + } + ), + None, +) + +TRIGGERS_SCHEMA = vol.Schema( + { + vol.Remove(vol.All(str, trigger.starts_with_dot)): object, + cv.slug: TRIGGER_SCHEMA, + } +) + +NON_MIGRATED_INTEGRATIONS = { + "calendar", + "conversation", + "device_automation", + "geo_location", + "homeassistant", + "knx", + "lg_netcast", + "litejet", + "persistent_notification", + "samsungtv", + "sun", + "tag", + "template", + "webhook", + "webostv", + "zone", + "zwave_js", +} + + +def grep_dir(path: pathlib.Path, glob_pattern: str, search_pattern: str) -> bool: + """Recursively go through a dir and it's children and find the regex.""" + pattern = re.compile(search_pattern) + + for fil in path.glob(glob_pattern): + if not fil.is_file(): + continue + + if pattern.search(fil.read_text()): + return True + + return False + + +def validate_triggers(config: Config, integration: Integration) -> None: # noqa: C901 + """Validate triggers.""" + try: + data = load_yaml_dict(str(integration.path / "triggers.yaml")) + except FileNotFoundError: + # Find if integration uses triggers + has_triggers = grep_dir( + integration.path, + "**/trigger.py", + r"async_attach_trigger|async_get_triggers", + ) + + if has_triggers and integration.domain not in NON_MIGRATED_INTEGRATIONS: + integration.add_error( + "triggers", "Registers triggers but has no triggers.yaml" + ) + return + except HomeAssistantError: + integration.add_error("triggers", "Invalid triggers.yaml") + return + + try: + triggers = TRIGGERS_SCHEMA(data) + except vol.Invalid as err: + integration.add_error( + "triggers", f"Invalid triggers.yaml: {humanize_error(data, err)}" + ) + return + + icons_file = integration.path / "icons.json" + icons = {} + if icons_file.is_file(): + with contextlib.suppress(ValueError): + icons = json.loads(icons_file.read_text()) + trigger_icons = icons.get("triggers", {}) + + # Try loading translation strings + if integration.core: + strings_file = integration.path / "strings.json" + else: + # For custom integrations, use the en.json file + strings_file = integration.path / "translations/en.json" + + strings = {} + if strings_file.is_file(): + with contextlib.suppress(ValueError): + strings = json.loads(strings_file.read_text()) + + error_msg_suffix = "in the translations file" + if not integration.core: + error_msg_suffix = f"and is not {error_msg_suffix}" + + # For each trigger in the integration: + # 1. Check if the trigger description is set, if not, + # check if it's in the strings file else add an error. + # 2. Check if the trigger has an icon set in icons.json. + # raise an error if not., + for trigger_name, trigger_schema in triggers.items(): + if integration.core and trigger_name not in trigger_icons: + # This is enforced for Core integrations only + integration.add_error( + "triggers", + f"Trigger {trigger_name} has no icon in icons.json.", + ) + if trigger_schema is None: + continue + if "name" not in trigger_schema and integration.core: + try: + strings["triggers"][trigger_name]["name"] + except KeyError: + integration.add_error( + "triggers", + f"Trigger {trigger_name} has no name {error_msg_suffix}", + ) + + if "description" not in trigger_schema and integration.core: + try: + strings["triggers"][trigger_name]["description"] + except KeyError: + integration.add_error( + "triggers", + f"Trigger {trigger_name} has no description {error_msg_suffix}", + ) + + # The same check is done for the description in each of the fields of the + # trigger schema. + for field_name, field_schema in trigger_schema.get("fields", {}).items(): + if "fields" in field_schema: + # This is a section + continue + if "name" not in field_schema and integration.core: + try: + strings["triggers"][trigger_name]["fields"][field_name]["name"] + except KeyError: + integration.add_error( + "triggers", + ( + f"Trigger {trigger_name} has a field {field_name} with no " + f"name {error_msg_suffix}" + ), + ) + + if "description" not in field_schema and integration.core: + try: + strings["triggers"][trigger_name]["fields"][field_name][ + "description" + ] + except KeyError: + integration.add_error( + "triggers", + ( + f"Trigger {trigger_name} has a field {field_name} with no " + f"description {error_msg_suffix}" + ), + ) + + if "selector" in field_schema: + with contextlib.suppress(KeyError): + translation_key = field_schema["selector"]["select"][ + "translation_key" + ] + try: + strings["selector"][translation_key] + except KeyError: + integration.add_error( + "triggers", + f"Trigger {trigger_name} has a field {field_name} with a selector with a translation key {translation_key} that is not in the translations file", + ) + + # The same check is done for the description in each of the sections of the + # trigger schema. + for section_name, section_schema in trigger_schema.get("fields", {}).items(): + if "fields" not in section_schema: + # This is not a section + continue + if "name" not in section_schema and integration.core: + try: + strings["triggers"][trigger_name]["sections"][section_name]["name"] + except KeyError: + integration.add_error( + "triggers", + f"Trigger {trigger_name} has a section {section_name} with no name {error_msg_suffix}", + ) + + +def validate(integrations: dict[str, Integration], config: Config) -> None: + """Handle dependencies for integrations.""" + # check triggers.yaml is valid + for integration in integrations.values(): + validate_triggers(config, integration) diff --git a/tests/common.py b/tests/common.py index 40d6e4d79d3..ff64dcb33a7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -87,6 +87,7 @@ from homeassistant.helpers import ( restore_state as rs, storage, translation, + trigger, ) from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, @@ -295,6 +296,7 @@ async def async_test_home_assistant( # Load the registries entity.async_setup(hass) loader.async_setup(hass) + await trigger.async_setup(hass) # setup translation cache instead of calling translation.async_setup(hass) hass.data[translation.TRANSLATION_FLATTEN_CACHE] = translation._TranslationCache( diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 6e4fa34ed26..bfb8c917f71 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -2,6 +2,7 @@ import asyncio from copy import deepcopy +import io import logging from typing import Any from unittest.mock import ANY, AsyncMock, Mock, patch @@ -19,6 +20,7 @@ from homeassistant.components.websocket_api.auth import ( ) from homeassistant.components.websocket_api.commands import ( ALL_SERVICE_DESCRIPTIONS_JSON_CACHE, + ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE, ) from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL from homeassistant.config_entries import ConfigEntryState @@ -28,9 +30,10 @@ from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.helpers import device_registry as dr from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_track_state_change_event -from homeassistant.loader import async_get_integration +from homeassistant.loader import Integration, async_get_integration from homeassistant.setup import async_set_domains_to_be_loaded, async_setup_component from homeassistant.util.json import json_loads +from homeassistant.util.yaml.loader import parse_yaml from tests.common import ( MockConfigEntry, @@ -707,6 +710,91 @@ async def test_get_services( assert hass.data[ALL_SERVICE_DESCRIPTIONS_JSON_CACHE] is old_cache +@patch("annotatedyaml.loader.load_yaml") +@patch.object(Integration, "has_triggers", return_value=True) +async def test_subscribe_triggers( + mock_has_triggers: Mock, + mock_load_yaml: Mock, + hass: HomeAssistant, + websocket_client: MockHAClientWebSocket, +) -> None: + """Test trigger_platforms/subscribe command.""" + sun_trigger_descriptions = """ + sun: {} + """ + tag_trigger_descriptions = """ + tag: {} + """ + + def _load_yaml(fname, secrets=None): + if fname.endswith("sun/triggers.yaml"): + trigger_descriptions = sun_trigger_descriptions + elif fname.endswith("tag/triggers.yaml"): + trigger_descriptions = tag_trigger_descriptions + else: + raise FileNotFoundError + with io.StringIO(trigger_descriptions) as file: + return parse_yaml(file) + + mock_load_yaml.side_effect = _load_yaml + + assert await async_setup_component(hass, "sun", {}) + assert await async_setup_component(hass, "system_health", {}) + await hass.async_block_till_done() + + assert ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE not in hass.data + + await websocket_client.send_json_auto_id({"type": "trigger_platforms/subscribe"}) + + # Test start subscription with initial event + msg = await websocket_client.receive_json() + assert msg == {"id": 1, "result": None, "success": True, "type": "result"} + msg = await websocket_client.receive_json() + assert msg == {"event": {"sun": {"fields": {}}}, "id": 1, "type": "event"} + + old_cache = hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] + + # Test we receive an event when a new platform is loaded, if it has descriptions + assert await async_setup_component(hass, "calendar", {}) + assert await async_setup_component(hass, "tag", {}) + await hass.async_block_till_done() + msg = await websocket_client.receive_json() + assert msg == { + "event": {"tag": {"fields": {}}}, + "id": 1, + "type": "event", + } + + # Initiate a second subscription to check the cache is updated because of the new + # trigger + await websocket_client.send_json_auto_id({"type": "trigger_platforms/subscribe"}) + msg = await websocket_client.receive_json() + assert msg == {"id": 2, "result": None, "success": True, "type": "result"} + msg = await websocket_client.receive_json() + assert msg == { + "event": {"sun": {"fields": {}}, "tag": {"fields": {}}}, + "id": 2, + "type": "event", + } + + assert hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] is not old_cache + + # Initiate a third subscription to check the cache is not updated because no new + # trigger was added + old_cache = hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] + await websocket_client.send_json_auto_id({"type": "trigger_platforms/subscribe"}) + msg = await websocket_client.receive_json() + assert msg == {"id": 3, "result": None, "success": True, "type": "result"} + msg = await websocket_client.receive_json() + assert msg == { + "event": {"sun": {"fields": {}}, "tag": {"fields": {}}}, + "id": 3, + "type": "event", + } + + assert hass.data[ALL_TRIGGER_DESCRIPTIONS_JSON_CACHE] is old_cache + + async def test_get_config( hass: HomeAssistant, websocket_client: MockHAClientWebSocket ) -> None: diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index f5a2b549f89..27cde92d14f 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -1,10 +1,15 @@ """The tests for the trigger helper.""" +import io from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch import pytest +from pytest_unordered import unordered import voluptuous as vol +from homeassistant.components.sun import DOMAIN as DOMAIN_SUN +from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH +from homeassistant.components.tag import DOMAIN as DOMAIN_TAG from homeassistant.core import ( CALLBACK_TYPE, Context, @@ -12,6 +17,8 @@ from homeassistant.core import ( ServiceCall, callback, ) +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import trigger from homeassistant.helpers.trigger import ( DATA_PLUGGABLE_ACTIONS, PluggableAction, @@ -23,9 +30,11 @@ from homeassistant.helpers.trigger import ( async_validate_trigger_config, ) from homeassistant.helpers.typing import ConfigType +from homeassistant.loader import Integration, async_get_integration from homeassistant.setup import async_setup_component +from homeassistant.util.yaml.loader import parse_yaml -from tests.common import MockModule, mock_integration, mock_platform +from tests.common import MockModule, MockPlatform, mock_integration, mock_platform async def test_bad_trigger_platform(hass: HomeAssistant) -> None: @@ -519,3 +528,213 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None: with pytest.raises(KeyError): await async_initialize_triggers(hass, config_3, cb_action, "test", "", log_cb) + + +@pytest.mark.parametrize( + "sun_trigger_descriptions", + [ + """ + sun: + fields: + event: + example: sunrise + selector: + select: + options: + - sunrise + - sunset + offset: + selector: + time: null + """, + """ + .anchor: &anchor + - sunrise + - sunset + sun: + fields: + event: + example: sunrise + selector: + select: + options: *anchor + offset: + selector: + time: null + """, + ], +) +async def test_async_get_all_descriptions( + hass: HomeAssistant, sun_trigger_descriptions: str +) -> None: + """Test async_get_all_descriptions.""" + tag_trigger_descriptions = """ + tag: {} + """ + + assert await async_setup_component(hass, DOMAIN_SUN, {}) + assert await async_setup_component(hass, DOMAIN_SYSTEM_HEALTH, {}) + await hass.async_block_till_done() + + def _load_yaml(fname, secrets=None): + if fname.endswith("sun/triggers.yaml"): + trigger_descriptions = sun_trigger_descriptions + elif fname.endswith("tag/triggers.yaml"): + trigger_descriptions = tag_trigger_descriptions + with io.StringIO(trigger_descriptions) as file: + return parse_yaml(file) + + with ( + patch( + "homeassistant.helpers.trigger._load_triggers_files", + side_effect=trigger._load_triggers_files, + ) as proxy_load_triggers_files, + patch( + "annotatedyaml.loader.load_yaml", + side_effect=_load_yaml, + ), + patch.object(Integration, "has_triggers", return_value=True), + ): + descriptions = await trigger.async_get_all_descriptions(hass) + + # Test we only load triggers.yaml for integrations with triggers, + # system_health has no triggers + assert proxy_load_triggers_files.mock_calls[0][1][1] == unordered( + [ + await async_get_integration(hass, DOMAIN_SUN), + ] + ) + + # system_health does not have triggers and should not be in descriptions + assert descriptions == { + DOMAIN_SUN: { + "fields": { + "event": { + "example": "sunrise", + "selector": {"select": {"options": ["sunrise", "sunset"]}}, + }, + "offset": {"selector": {"time": None}}, + } + } + } + + # Verify the cache returns the same object + assert await trigger.async_get_all_descriptions(hass) is descriptions + + # Load the tag integration and check a new cache object is created + assert await async_setup_component(hass, DOMAIN_TAG, {}) + await hass.async_block_till_done() + + with ( + patch( + "annotatedyaml.loader.load_yaml", + side_effect=_load_yaml, + ), + patch.object(Integration, "has_triggers", return_value=True), + ): + new_descriptions = await trigger.async_get_all_descriptions(hass) + assert new_descriptions is not descriptions + assert new_descriptions == { + DOMAIN_SUN: { + "fields": { + "event": { + "example": "sunrise", + "selector": {"select": {"options": ["sunrise", "sunset"]}}, + }, + "offset": {"selector": {"time": None}}, + } + }, + DOMAIN_TAG: { + "fields": {}, + }, + } + + # Verify the cache returns the same object + assert await trigger.async_get_all_descriptions(hass) is new_descriptions + + +@pytest.mark.parametrize( + ("yaml_error", "expected_message"), + [ + ( + FileNotFoundError("Blah"), + "Unable to find triggers.yaml for the sun integration", + ), + ( + HomeAssistantError("Test error"), + "Unable to parse triggers.yaml for the sun integration: Test error", + ), + ], +) +async def test_async_get_all_descriptions_with_yaml_error( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + yaml_error: Exception, + expected_message: str, +) -> None: + """Test async_get_all_descriptions.""" + assert await async_setup_component(hass, DOMAIN_SUN, {}) + await hass.async_block_till_done() + + def _load_yaml_dict(fname, secrets=None): + raise yaml_error + + with ( + patch( + "homeassistant.helpers.trigger.load_yaml_dict", + side_effect=_load_yaml_dict, + ), + patch.object(Integration, "has_triggers", return_value=True), + ): + descriptions = await trigger.async_get_all_descriptions(hass) + + assert descriptions == {DOMAIN_SUN: None} + + assert expected_message in caplog.text + + +async def test_async_get_all_descriptions_with_bad_description( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test async_get_all_descriptions.""" + sun_service_descriptions = """ + sun: + fields: not_a_dict + """ + + assert await async_setup_component(hass, DOMAIN_SUN, {}) + await hass.async_block_till_done() + + def _load_yaml(fname, secrets=None): + with io.StringIO(sun_service_descriptions) as file: + return parse_yaml(file) + + with ( + patch( + "annotatedyaml.loader.load_yaml", + side_effect=_load_yaml, + ), + patch.object(Integration, "has_triggers", return_value=True), + ): + descriptions = await trigger.async_get_all_descriptions(hass) + + assert descriptions == {DOMAIN_SUN: None} + + assert ( + "Unable to parse triggers.yaml for the sun integration: " + "expected a dictionary for dictionary value @ data['sun']['fields']" + ) in caplog.text + + +async def test_invalid_trigger_platform( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test invalid trigger platform.""" + mock_integration(hass, MockModule("test", async_setup=AsyncMock(return_value=True))) + mock_platform(hass, "test.trigger", MockPlatform()) + + await async_setup_component(hass, "test", {}) + + assert "Integration test does not provide trigger support, skipping" in caplog.text