Improve type hints in automation (#78368)

* Improve type hints in automation

* Apply suggestion

* Apply suggestion

* Apply suggestion

* Add Protocol for IfAction

* Use ConfigType for IfAction

* Rename variable
This commit is contained in:
epenet 2022-09-14 13:04:09 +02:00 committed by GitHub
parent b7e9fcb9fe
commit 5e338d2166
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 80 deletions

View File

@ -1,9 +1,9 @@
"""Allow to set up simple automation rules via the config file.""" """Allow to set up simple automation rules via the config file."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable, Mapping
import logging import logging
from typing import Any, cast from typing import Any, Protocol, cast
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -31,9 +31,12 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
) )
from homeassistant.core import ( from homeassistant.core import (
CALLBACK_TYPE,
Context, Context,
CoreState, CoreState,
Event,
HomeAssistant, HomeAssistant,
ServiceCall,
callback, callback,
split_entity_id, split_entity_id,
valid_entity_id, valid_entity_id,
@ -99,9 +102,6 @@ from .const import (
from .helpers import async_get_blueprints from .helpers import async_get_blueprints
from .trace import trace_automation from .trace import trace_automation
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
ENTITY_ID_FORMAT = DOMAIN + ".{}" ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -120,6 +120,15 @@ SERVICE_TRIGGER = "trigger"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class IfAction(Protocol):
"""Define the format of if_action."""
config: list[ConfigType]
def __call__(self, variables: Mapping[str, Any] | None = None) -> bool:
"""AND all conditions."""
# AutomationActionType, AutomationTriggerData, # AutomationActionType, AutomationTriggerData,
# and AutomationTriggerInfo are deprecated as of 2022.9. # and AutomationTriggerInfo are deprecated as of 2022.9.
AutomationActionType = TriggerActionType AutomationActionType = TriggerActionType
@ -128,7 +137,7 @@ AutomationTriggerInfo = TriggerInfo
@bind_hass @bind_hass
def is_on(hass, entity_id): def is_on(hass: HomeAssistant, entity_id: str) -> bool:
""" """
Return true if specified automation entity_id is on. Return true if specified automation entity_id is on.
@ -143,12 +152,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component = hass.data[DOMAIN] component: EntityComponent = hass.data[DOMAIN]
return [ return [
automation_entity.entity_id automation_entity.entity_id
for automation_entity in component.entities for automation_entity in component.entities
if entity_id in automation_entity.referenced_entities if entity_id in cast(AutomationEntity, automation_entity).referenced_entities
] ]
@ -158,12 +167,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component = hass.data[DOMAIN] component: EntityComponent = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None: if (automation_entity := component.get_entity(entity_id)) is None:
return [] return []
return list(automation_entity.referenced_entities) return list(cast(AutomationEntity, automation_entity).referenced_entities)
@callback @callback
@ -172,12 +181,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component = hass.data[DOMAIN] component: EntityComponent = hass.data[DOMAIN]
return [ return [
automation_entity.entity_id automation_entity.entity_id
for automation_entity in component.entities for automation_entity in component.entities
if device_id in automation_entity.referenced_devices if device_id in cast(AutomationEntity, automation_entity).referenced_devices
] ]
@ -187,12 +196,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component = hass.data[DOMAIN] component: EntityComponent = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None: if (automation_entity := component.get_entity(entity_id)) is None:
return [] return []
return list(automation_entity.referenced_devices) return list(cast(AutomationEntity, automation_entity).referenced_devices)
@callback @callback
@ -201,12 +210,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component = hass.data[DOMAIN] component: EntityComponent = hass.data[DOMAIN]
return [ return [
automation_entity.entity_id automation_entity.entity_id
for automation_entity in component.entities for automation_entity in component.entities
if area_id in automation_entity.referenced_areas if area_id in cast(AutomationEntity, automation_entity).referenced_areas
] ]
@ -216,12 +225,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component = hass.data[DOMAIN] component: EntityComponent = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None: if (automation_entity := component.get_entity(entity_id)) is None:
return [] return []
return list(automation_entity.referenced_areas) return list(cast(AutomationEntity, automation_entity).referenced_areas)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -238,7 +247,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if not await _async_process_config(hass, config, component): if not await _async_process_config(hass, config, component):
await async_get_blueprints(hass).async_populate() await async_get_blueprints(hass).async_populate()
async def trigger_service_handler(entity, service_call): async def trigger_service_handler(
entity: AutomationEntity, service_call: ServiceCall
) -> None:
"""Handle forced automation trigger, e.g. from frontend.""" """Handle forced automation trigger, e.g. from frontend."""
await entity.async_trigger( await entity.async_trigger(
{**service_call.data[ATTR_VARIABLES], "trigger": {"platform": None}}, {**service_call.data[ATTR_VARIABLES], "trigger": {"platform": None}},
@ -262,7 +273,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_turn_off", "async_turn_off",
) )
async def reload_service_handler(service_call): async def reload_service_handler(service_call: ServiceCall) -> None:
"""Remove all automations and load new ones from config.""" """Remove all automations and load new ones from config."""
if (conf := await component.async_prepare_reload()) is None: if (conf := await component.async_prepare_reload()) is None:
return return
@ -290,22 +301,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
def __init__( def __init__(
self, self,
automation_id, automation_id: str | None,
name, name: str,
trigger_config, trigger_config: list[ConfigType],
cond_func, cond_func: IfAction | None,
action_script, action_script: Script,
initial_state, initial_state: bool | None,
variables, variables: ScriptVariables | None,
trigger_variables, trigger_variables: ScriptVariables | None,
raw_config, raw_config: ConfigType | None,
blueprint_inputs, blueprint_inputs: ConfigType | None,
trace_config, trace_config: ConfigType,
): ) -> None:
"""Initialize an automation entity.""" """Initialize an automation entity."""
self._attr_name = name self._attr_name = name
self._trigger_config = trigger_config self._trigger_config = trigger_config
self._async_detach_triggers = None self._async_detach_triggers: CALLBACK_TYPE | None = None
self._cond_func = cond_func self._cond_func = cond_func
self.action_script = action_script self.action_script = action_script
self.action_script.change_listener = self.async_write_ha_state self.action_script.change_listener = self.async_write_ha_state
@ -314,15 +325,15 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._referenced_entities: set[str] | None = None self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None self._referenced_devices: set[str] | None = None
self._logger = LOGGER self._logger = LOGGER
self._variables: ScriptVariables = variables self._variables = variables
self._trigger_variables: ScriptVariables = trigger_variables self._trigger_variables = trigger_variables
self._raw_config = raw_config self._raw_config = raw_config
self._blueprint_inputs = blueprint_inputs self._blueprint_inputs = blueprint_inputs
self._trace_config = trace_config self._trace_config = trace_config
self._attr_unique_id = automation_id self._attr_unique_id = automation_id
@property @property
def extra_state_attributes(self): def extra_state_attributes(self) -> dict[str, Any]:
"""Return the entity state attributes.""" """Return the entity state attributes."""
attrs = { attrs = {
ATTR_LAST_TRIGGERED: self.action_script.last_triggered, ATTR_LAST_TRIGGERED: self.action_script.last_triggered,
@ -341,12 +352,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
return self._async_detach_triggers is not None or self._is_enabled return self._async_detach_triggers is not None or self._is_enabled
@property @property
def referenced_areas(self): def referenced_areas(self) -> set[str]:
"""Return a set of referenced areas.""" """Return a set of referenced areas."""
return self.action_script.referenced_areas return self.action_script.referenced_areas
@property @property
def referenced_devices(self): def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices.""" """Return a set of referenced devices."""
if self._referenced_devices is not None: if self._referenced_devices is not None:
return self._referenced_devices return self._referenced_devices
@ -364,7 +375,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
return referenced return referenced
@property @property
def referenced_entities(self): def referenced_entities(self) -> set[str]:
"""Return a set of referenced entities.""" """Return a set of referenced entities."""
if self._referenced_entities is not None: if self._referenced_entities is not None:
return self._referenced_entities return self._referenced_entities
@ -513,7 +524,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
event_data[ATTR_SOURCE] = variables["trigger"]["description"] event_data[ATTR_SOURCE] = variables["trigger"]["description"]
@callback @callback
def started_action(): def started_action() -> None:
self.hass.bus.async_fire( self.hass.bus.async_fire(
EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context
) )
@ -555,12 +566,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._logger.exception("While executing automation %s", self.entity_id) self._logger.exception("While executing automation %s", self.entity_id)
automation_trace.set_error(err) automation_trace.set_error(err)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self) -> None:
"""Remove listeners when removing automation from Home Assistant.""" """Remove listeners when removing automation from Home Assistant."""
await super().async_will_remove_from_hass() await super().async_will_remove_from_hass()
await self.async_disable() await self.async_disable()
async def async_enable(self): async def async_enable(self) -> None:
"""Enable this automation entity. """Enable this automation entity.
This method is a coroutine. This method is a coroutine.
@ -576,7 +587,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
return return
async def async_enable_automation(event): async def async_enable_automation(event: Event) -> None:
"""Start automation on startup.""" """Start automation on startup."""
# Don't do anything if no longer enabled or already attached # Don't do anything if no longer enabled or already attached
if not self._is_enabled or self._async_detach_triggers is not None: if not self._is_enabled or self._async_detach_triggers is not None:
@ -589,7 +600,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
) )
self.async_write_ha_state() self.async_write_ha_state()
async def async_disable(self, stop_actions=DEFAULT_STOP_ACTIONS): async def async_disable(self, stop_actions: bool = DEFAULT_STOP_ACTIONS) -> None:
"""Disable the automation entity.""" """Disable the automation entity."""
if not self._is_enabled and not self.action_script.runs: if not self._is_enabled and not self.action_script.runs:
return return
@ -610,7 +621,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
) -> Callable[[], None] | None: ) -> Callable[[], None] | None:
"""Set up the triggers.""" """Set up the triggers."""
def log_cb(level, msg, **kwargs): def log_cb(level: int, msg: str, **kwargs: Any) -> None:
self._logger.log(level, "%s %s", msg, self.name, **kwargs) self._logger.log(level, "%s %s", msg, self.name, **kwargs)
this = None this = None
@ -650,7 +661,7 @@ async def _async_process_config(
Returns if blueprints were used. Returns if blueprints were used.
""" """
entities = [] entities: list[AutomationEntity] = []
blueprints_used = False blueprints_used = False
for config_key in extract_domain_configs(config, DOMAIN): for config_key in extract_domain_configs(config, DOMAIN):
@ -681,10 +692,10 @@ async def _async_process_config(
else: else:
raw_config = cast(AutomationConfig, config_block).raw_config raw_config = cast(AutomationConfig, config_block).raw_config
automation_id = config_block.get(CONF_ID) automation_id: str | None = config_block.get(CONF_ID)
name = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}" name: str = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}"
initial_state = config_block.get(CONF_INITIAL_STATE) initial_state: bool | None = config_block.get(CONF_INITIAL_STATE)
action_script = Script( action_script = Script(
hass, hass,
@ -743,11 +754,13 @@ async def _async_process_config(
return blueprints_used return blueprints_used
async def _async_process_if(hass, name, config, p_config): async def _async_process_if(
hass: HomeAssistant, name: str, config: dict[str, Any], p_config: dict[str, Any]
) -> IfAction | None:
"""Process if checks.""" """Process if checks."""
if_configs = p_config[CONF_CONDITION] if_configs = p_config[CONF_CONDITION]
checks = [] checks: list[condition.ConditionCheckerType] = []
for if_config in if_configs: for if_config in if_configs:
try: try:
checks.append(await condition.async_from_config(hass, if_config)) checks.append(await condition.async_from_config(hass, if_config))
@ -755,9 +768,9 @@ async def _async_process_if(hass, name, config, p_config):
LOGGER.warning("Invalid condition: %s", ex) LOGGER.warning("Invalid condition: %s", ex)
return None return None
def if_action(variables=None): def if_action(variables: Mapping[str, Any] | None = None) -> bool:
"""AND all conditions.""" """AND all conditions."""
errors = [] errors: list[ConditionErrorIndex] = []
for index, check in enumerate(checks): for index, check in enumerate(checks):
try: try:
with trace_path(["condition", str(index)]): with trace_path(["condition", str(index)]):
@ -780,9 +793,10 @@ async def _async_process_if(hass, name, config, p_config):
return True return True
if_action.config = if_configs result: IfAction = if_action # type: ignore[assignment]
result.config = if_configs
return if_action return result
@callback @callback
@ -800,7 +814,7 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]:
return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]] return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]]
if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf: if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf:
return trigger_conf[CONF_DEVICE_ID] return trigger_conf[CONF_DEVICE_ID] # type: ignore[no-any-return]
return [] return []
@ -809,13 +823,13 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]:
def _trigger_extract_entities(trigger_conf: dict) -> list[str]: def _trigger_extract_entities(trigger_conf: dict) -> list[str]:
"""Extract entities from a trigger config.""" """Extract entities from a trigger config."""
if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"): if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"):
return trigger_conf[CONF_ENTITY_ID] return trigger_conf[CONF_ENTITY_ID] # type: ignore[no-any-return]
if trigger_conf[CONF_PLATFORM] == "calendar": if trigger_conf[CONF_PLATFORM] == "calendar":
return [trigger_conf[CONF_ENTITY_ID]] return [trigger_conf[CONF_ENTITY_ID]]
if trigger_conf[CONF_PLATFORM] == "zone": if trigger_conf[CONF_PLATFORM] == "zone":
return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] # type: ignore[no-any-return]
if trigger_conf[CONF_PLATFORM] == "geo_location": if trigger_conf[CONF_PLATFORM] == "geo_location":
return [trigger_conf[CONF_ZONE]] return [trigger_conf[CONF_ZONE]]

View File

@ -1,6 +1,9 @@
"""Config validation helper for the automation integration.""" """Config validation helper for the automation integration."""
from __future__ import annotations
import asyncio import asyncio
from contextlib import suppress from contextlib import suppress
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -17,10 +20,12 @@ from homeassistant.const import (
CONF_ID, CONF_ID,
CONF_VARIABLES, CONF_VARIABLES,
) )
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, config_validation as cv, script from homeassistant.helpers import config_per_platform, config_validation as cv, script
from homeassistant.helpers.condition import async_validate_conditions_config from homeassistant.helpers.condition import async_validate_conditions_config
from homeassistant.helpers.trigger import async_validate_trigger_config from homeassistant.helpers.trigger import async_validate_trigger_config
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import IntegrationNotFound from homeassistant.loader import IntegrationNotFound
from .const import ( from .const import (
@ -34,9 +39,6 @@ from .const import (
) )
from .helpers import async_get_blueprints from .helpers import async_get_blueprints
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
PACKAGE_MERGE_HINT = "list" PACKAGE_MERGE_HINT = "list"
_CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA]) _CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
@ -63,7 +65,11 @@ PLATFORM_SCHEMA = vol.All(
) )
async def async_validate_config_item(hass, config, full_config=None): async def async_validate_config_item(
hass: HomeAssistant,
config: ConfigType,
full_config: ConfigType | None = None,
) -> blueprint.BlueprintInputs | dict[str, Any]:
"""Validate config item.""" """Validate config item."""
if blueprint.is_blueprint_instance_config(config): if blueprint.is_blueprint_instance_config(config):
blueprints = async_get_blueprints(hass) blueprints = async_get_blueprints(hass)
@ -90,17 +96,21 @@ async def async_validate_config_item(hass, config, full_config=None):
class AutomationConfig(dict): class AutomationConfig(dict):
"""Dummy class to allow adding attributes.""" """Dummy class to allow adding attributes."""
raw_config = None raw_config: dict[str, Any] | None = None
async def _try_async_validate_config_item(hass, config, full_config=None): async def _try_async_validate_config_item(
hass: HomeAssistant,
config: dict[str, Any],
full_config: dict[str, Any] | None = None,
) -> AutomationConfig | blueprint.BlueprintInputs | None:
"""Validate config item.""" """Validate config item."""
raw_config = None raw_config = None
with suppress(ValueError): with suppress(ValueError):
raw_config = dict(config) raw_config = dict(config)
try: try:
config = await async_validate_config_item(hass, config, full_config) validated_config = await async_validate_config_item(hass, config, full_config)
except ( except (
vol.Invalid, vol.Invalid,
HomeAssistantError, HomeAssistantError,
@ -110,15 +120,15 @@ async def _try_async_validate_config_item(hass, config, full_config=None):
async_log_exception(ex, DOMAIN, full_config or config, hass) async_log_exception(ex, DOMAIN, full_config or config, hass)
return None return None
if isinstance(config, blueprint.BlueprintInputs): if isinstance(validated_config, blueprint.BlueprintInputs):
return config return validated_config
config = AutomationConfig(config) automation_config = AutomationConfig(validated_config)
config.raw_config = raw_config automation_config.raw_config = raw_config
return config return automation_config
async def async_validate_config(hass, config): async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> ConfigType:
"""Validate config.""" """Validate config."""
automations = list( automations = list(
filter( filter(

View File

@ -1,6 +1,7 @@
"""Trace support for automation.""" """Trace support for automation."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import Any
@ -9,13 +10,11 @@ from homeassistant.components.trace import (
ActionTrace, ActionTrace,
async_store_trace, async_store_trace,
) )
from homeassistant.core import Context from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN from .const import DOMAIN
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
class AutomationTrace(ActionTrace): class AutomationTrace(ActionTrace):
"""Container for automation trace.""" """Container for automation trace."""
@ -24,9 +23,9 @@ class AutomationTrace(ActionTrace):
def __init__( def __init__(
self, self,
item_id: str, item_id: str | None,
config: dict[str, Any], config: ConfigType | None,
blueprint_inputs: dict[str, Any], blueprint_inputs: ConfigType | None,
context: Context, context: Context,
) -> None: ) -> None:
"""Container for automation trace.""" """Container for automation trace."""
@ -49,8 +48,13 @@ class AutomationTrace(ActionTrace):
@contextmanager @contextmanager
def trace_automation( def trace_automation(
hass, automation_id, config, blueprint_inputs, context, trace_config hass: HomeAssistant,
): automation_id: str | None,
config: ConfigType | None,
blueprint_inputs: ConfigType | None,
context: Context,
trace_config: ConfigType,
) -> Generator[AutomationTrace, None, None]:
"""Trace action execution of automation with automation_id.""" """Trace action execution of automation with automation_id."""
trace = AutomationTrace(automation_id, config, blueprint_inputs, context) trace = AutomationTrace(automation_id, config, blueprint_inputs, context)
async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES]) async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])