mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 02:49:40 +00:00
Allow core integrations to describe their triggers (#147075)
Co-authored-by: Abílio Costa <abmantis@users.noreply.github.com>
This commit is contained in:
@@ -5,11 +5,11 @@ from __future__ import annotations
|
||||
import abc
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Protocol, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -29,13 +29,24 @@ from homeassistant.core import (
|
||||
is_callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
from homeassistant.loader import (
|
||||
Integration,
|
||||
IntegrationNotFound,
|
||||
async_get_integration,
|
||||
async_get_integrations,
|
||||
)
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.yaml import load_yaml_dict
|
||||
from homeassistant.util.yaml.loader import JSON_TYPE
|
||||
|
||||
from . import config_validation as cv
|
||||
from .integration_platform import async_process_integration_platforms
|
||||
from .template import Template
|
||||
from .typing import ConfigType, TemplateVarsType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PLATFORM_ALIASES = {
|
||||
"device": "device_automation",
|
||||
"event": "homeassistant",
|
||||
@@ -49,6 +60,99 @@ DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = Has
|
||||
"pluggable_actions"
|
||||
)
|
||||
|
||||
# Cache of trigger descriptions keyed by trigger key; an entry of None marks a
# trigger for which no description could be loaded (see async_get_all_descriptions).
TRIGGER_DESCRIPTION_CACHE: HassKey[dict[str, dict[str, Any] | None]] = HassKey(
    "trigger_description_cache"
)
# Coroutine listeners awaited whenever a trigger platform registers new trigger keys.
TRIGGER_PLATFORM_SUBSCRIPTIONS: HassKey[
    list[Callable[[set[str]], Coroutine[Any, Any, None]]]
] = HassKey("trigger_platform_subscriptions")
# Maps trigger key -> domain of the integration that provides it.
TRIGGERS: HassKey[dict[str, str]] = HassKey("triggers")
|
||||
|
||||
|
||||
# Basic schemas to sanity check the trigger descriptions,
# full validation is done by hassfest.triggers

# A single field entry: any mapping is accepted at this layer.
_FIELD_SCHEMA = vol.Schema(
    {},
    extra=vol.ALLOW_EXTRA,
)

# A single trigger description: optional "fields" mapping of field name -> field.
_TRIGGER_SCHEMA = vol.Schema(
    {
        vol.Optional("fields"): vol.Schema({str: _FIELD_SCHEMA}),
    },
    extra=vol.ALLOW_EXTRA,
)
|
||||
|
||||
|
||||
def starts_with_dot(key: str) -> str:
    """Validate that *key* starts with a dot and return it unchanged."""
    if key.startswith("."):
        return key
    raise vol.Invalid("Key does not start with .")
|
||||
|
||||
|
||||
# Schema for a whole triggers.yaml file: keys starting with "." are silently
# dropped (presumably reusable YAML anchors — confirm), every other key must
# be a slug mapping to either None or a trigger description.
_TRIGGERS_SCHEMA = vol.Schema(
    {
        vol.Remove(vol.All(str, starts_with_dot)): object,
        cv.slug: vol.Any(None, _TRIGGER_SCHEMA),
    }
)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant) -> None:
    """Set up the trigger helper.

    Initializes the per-instance storage used by this module and waits for
    all "trigger" integration platforms to be processed.
    """
    data = hass.data
    data[TRIGGER_DESCRIPTION_CACHE] = {}
    data[TRIGGER_PLATFORM_SUBSCRIPTIONS] = []
    data[TRIGGERS] = {}
    await async_process_integration_platforms(
        hass, "trigger", _register_trigger_platform, wait_for_platforms=True
    )
|
||||
|
||||
|
||||
@callback
def async_subscribe_platform_events(
    hass: HomeAssistant,
    on_event: Callable[[set[str]], Coroutine[Any, Any, None]],
) -> Callable[[], None]:
    """Subscribe to trigger platform events.

    Registers *on_event* to be awaited whenever a trigger platform adds new
    trigger keys. Returns a callable that removes the subscription.
    """
    subscriptions = hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]
    subscriptions.append(on_event)

    def _unsubscribe() -> None:
        subscriptions.remove(on_event)

    return _unsubscribe
|
||||
|
||||
|
||||
async def _register_trigger_platform(
    hass: HomeAssistant, integration_domain: str, platform: TriggerProtocol
) -> None:
    """Register a trigger platform and notify subscribers of its trigger keys.

    Platforms that expose ``async_get_triggers`` may register several trigger
    keys; legacy platforms (with ``async_validate_trigger_config`` or
    ``TRIGGER_SCHEMA``) register a single key equal to their domain.
    """
    registry = hass.data[TRIGGERS]
    added: set[str] = set()

    if hasattr(platform, "async_get_triggers"):
        # Multi-trigger platform: one registry entry per advertised key.
        for key in await platform.async_get_triggers(hass):
            registry[key] = integration_domain
            added.add(key)
    elif hasattr(platform, "async_validate_trigger_config") or hasattr(
        platform, "TRIGGER_SCHEMA"
    ):
        # Legacy single-trigger platform keyed by the integration domain.
        registry[integration_domain] = integration_domain
        added.add(integration_domain)
    else:
        _LOGGER.debug(
            "Integration %s does not provide trigger support, skipping",
            integration_domain,
        )
        return

    # Fan the new keys out to every subscriber concurrently.
    await asyncio.gather(
        *(
            create_eager_task(listener(added))
            for listener in hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]
        )
    )
|
||||
|
||||
|
||||
class Trigger(abc.ABC):
|
||||
"""Trigger class."""
|
||||
@@ -409,3 +513,107 @@ async def async_initialize_triggers(
|
||||
remove()
|
||||
|
||||
return remove_triggers
|
||||
|
||||
|
||||
def _load_triggers_file(hass: HomeAssistant, integration: Integration) -> JSON_TYPE:
    """Load and sanity-check triggers.yaml for an integration.

    Returns the schema-validated mapping, or an empty dict when the file is
    missing or fails to load/validate (a warning is logged in both cases).
    """
    yaml_path = str(integration.file_path / "triggers.yaml")
    try:
        return cast(JSON_TYPE, _TRIGGERS_SCHEMA(load_yaml_dict(yaml_path)))
    except FileNotFoundError:
        _LOGGER.warning(
            "Unable to find triggers.yaml for the %s integration", integration.domain
        )
    except (HomeAssistantError, vol.Invalid) as ex:
        _LOGGER.warning(
            "Unable to parse triggers.yaml for the %s integration: %s",
            integration.domain,
            ex,
        )
    return {}
|
||||
|
||||
|
||||
def _load_triggers_files(
    hass: HomeAssistant, integrations: Iterable[Integration]
) -> dict[str, JSON_TYPE]:
    """Load trigger files for multiple integrations, keyed by domain."""
    descriptions: dict[str, JSON_TYPE] = {}
    for integration in integrations:
        descriptions[integration.domain] = _load_triggers_file(hass, integration)
    return descriptions
|
||||
|
||||
|
||||
async def async_get_all_descriptions(
    hass: HomeAssistant,
) -> dict[str, dict[str, Any] | None]:
    """Return descriptions (i.e. user documentation) for all triggers.

    Results are cached in TRIGGER_DESCRIPTION_CACHE; a cache entry of None
    means the trigger has no description. The cache is replaced wholesale
    (never mutated in place) so callers can safely hold a reference to it.
    """
    descriptions_cache = hass.data[TRIGGER_DESCRIPTION_CACHE]

    triggers = hass.data[TRIGGERS]
    # See if there are new triggers not seen before.
    # Any trigger that we saw before already has an entry in description_cache.
    all_triggers = set(triggers)
    previous_all_triggers = set(descriptions_cache)
    # If the triggers are the same, we can return the cache
    if previous_all_triggers == all_triggers:
        return descriptions_cache

    # Files we loaded for missing descriptions
    new_triggers_descriptions: dict[str, JSON_TYPE] = {}
    # We try to avoid making a copy in the event the cache is good,
    # but now we must make a copy in case new triggers get added
    # while we are loading the missing ones so we do not
    # add the new ones to the cache without their descriptions
    triggers = triggers.copy()

    if missing_triggers := all_triggers.difference(descriptions_cache):
        # Resolve the integrations that provide the triggers we have not
        # described yet; failures to resolve are logged and skipped.
        domains_with_missing_triggers = {
            triggers[missing_trigger] for missing_trigger in missing_triggers
        }
        ints_or_excs = await async_get_integrations(hass, domains_with_missing_triggers)
        integrations: list[Integration] = []
        for domain, int_or_exc in ints_or_excs.items():
            if type(int_or_exc) is Integration and int_or_exc.has_triggers:
                integrations.append(int_or_exc)
                continue
            # Anything that is not a loadable Integration with triggers is an
            # exception from async_get_integrations (or lacks triggers.yaml).
            if TYPE_CHECKING:
                assert isinstance(int_or_exc, Exception)
            _LOGGER.debug(
                "Failed to load triggers.yaml for integration: %s",
                domain,
                exc_info=int_or_exc,
            )

        if integrations:
            # triggers.yaml parsing is blocking I/O — run it in the executor.
            new_triggers_descriptions = await hass.async_add_executor_job(
                _load_triggers_files, hass, integrations
            )

    # Make a copy of the old cache and add missing descriptions to it
    new_descriptions_cache = descriptions_cache.copy()
    for missing_trigger in missing_triggers:
        domain = triggers[missing_trigger]

        if (
            yaml_description := new_triggers_descriptions.get(domain, {}).get(  # type: ignore[union-attr]
                missing_trigger
            )
        ) is None:
            _LOGGER.debug(
                "No trigger descriptions found for trigger %s, skipping",
                missing_trigger,
            )
            # Cache the miss as None so we do not retry the file every call.
            new_descriptions_cache[missing_trigger] = None
            continue

        # Only the "fields" portion of the YAML description is exposed.
        description = {"fields": yaml_description.get("fields", {})}

        new_descriptions_cache[missing_trigger] = description

    hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache
    return new_descriptions_cache
|
||||
|
||||
Reference in New Issue
Block a user