Improve trigger platform typing (#88511)

* Improve trigger platform typing

* Tweak docstring

* Revert "Tweak docstring"

This reverts commit c31f790fc3c1e66a5d802f759f07dfe4049cf529.

* Tweak docstring
This commit is contained in:
Erik Montnemery 2023-02-22 11:59:53 +01:00 committed by GitHub
parent 33b16d20b1
commit 6d9411b8a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 26 deletions

View File

@ -7,7 +7,11 @@ import voluptuous as vol
from homeassistant.const import CONF_DOMAIN from homeassistant.const import CONF_DOMAIN
from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import (
TriggerActionType,
TriggerInfo,
TriggerProtocol,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import ( from . import (
@ -20,28 +24,13 @@ from .helpers import async_validate_device_automation_config
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
class DeviceAutomationTriggerProtocol(Protocol): class DeviceAutomationTriggerProtocol(TriggerProtocol, Protocol):
"""Define the format of device_trigger modules. """Define the format of device_trigger modules.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config. Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config
from TriggerProtocol.
""" """
TRIGGER_SCHEMA: vol.Schema
async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
async def async_get_trigger_capabilities( async def async_get_trigger_capabilities(
self, hass: HomeAssistant, config: ConfigType self, hass: HomeAssistant, config: ConfigType
) -> dict[str, vol.Schema]: ) -> dict[str, vol.Schema]:

View File

@ -7,7 +7,7 @@ from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field from dataclasses import dataclass, field
import functools import functools
import logging import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast from typing import Any, Protocol, TypedDict, cast
import voluptuous as vol import voluptuous as vol
@ -31,11 +31,6 @@ from homeassistant.loader import IntegrationNotFound, async_get_integration
from .typing import ConfigType, TemplateVarsType from .typing import ConfigType, TemplateVarsType
if TYPE_CHECKING:
from homeassistant.components.device_automation.trigger import (
DeviceAutomationTriggerProtocol,
)
_PLATFORM_ALIASES = { _PLATFORM_ALIASES = {
"device": "device_automation", "device": "device_automation",
"event": "homeassistant", "event": "homeassistant",
@ -48,6 +43,29 @@ _PLATFORM_ALIASES = {
DATA_PLUGGABLE_ACTIONS = "pluggable_actions" DATA_PLUGGABLE_ACTIONS = "pluggable_actions"
class TriggerProtocol(Protocol):
"""Define the format of trigger modules.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""
TRIGGER_SCHEMA: vol.Schema
async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
class TriggerActionType(Protocol): class TriggerActionType(Protocol):
"""Protocol type for trigger action callback.""" """Protocol type for trigger action callback."""
@ -195,7 +213,7 @@ class PluggableAction:
async def _async_get_trigger_platform( async def _async_get_trigger_platform(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> DeviceAutomationTriggerProtocol: ) -> TriggerProtocol:
platform_and_sub_type = config[CONF_PLATFORM].split(".") platform_and_sub_type = config[CONF_PLATFORM].split(".")
platform = platform_and_sub_type[0] platform = platform_and_sub_type[0]
platform = _PLATFORM_ALIASES.get(platform, platform) platform = _PLATFORM_ALIASES.get(platform, platform)