diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 154c443e799..a5ea30f59d2 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -1,9 +1,9 @@ """Allow to set up simple automation rules via the config file.""" from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Mapping import logging -from typing import Any, cast +from typing import Any, Protocol, cast import voluptuous as vol from voluptuous.humanize import humanize_error @@ -31,9 +31,12 @@ from homeassistant.const import ( STATE_ON, ) from homeassistant.core import ( + CALLBACK_TYPE, Context, CoreState, + Event, HomeAssistant, + ServiceCall, callback, split_entity_id, valid_entity_id, @@ -99,9 +102,6 @@ from .const import ( from .helpers import async_get_blueprints from .trace import trace_automation -# mypy: allow-untyped-calls, allow-untyped-defs -# mypy: no-check-untyped-defs, no-warn-return-any - ENTITY_ID_FORMAT = DOMAIN + ".{}" @@ -120,6 +120,15 @@ SERVICE_TRIGGER = "trigger" _LOGGER = logging.getLogger(__name__) +class IfAction(Protocol): + """Define the format of if_action.""" + + config: list[ConfigType] + + def __call__(self, variables: Mapping[str, Any] | None = None) -> bool: + """AND all conditions.""" + + # AutomationActionType, AutomationTriggerData, # and AutomationTriggerInfo are deprecated as of 2022.9. AutomationActionType = TriggerActionType @@ -128,7 +137,7 @@ AutomationTriggerInfo = TriggerInfo @bind_hass -def is_on(hass, entity_id): +def is_on(hass: HomeAssistant, entity_id: str) -> bool: """ Return true if specified automation entity_id is on. @@ -143,12 +152,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent = hass.data[DOMAIN] return [ automation_entity.entity_id for automation_entity in component.entities - if entity_id in automation_entity.referenced_entities + if entity_id in cast(AutomationEntity, automation_entity).referenced_entities ] @@ -158,12 +167,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent = hass.data[DOMAIN] if (automation_entity := component.get_entity(entity_id)) is None: return [] - return list(automation_entity.referenced_entities) + return list(cast(AutomationEntity, automation_entity).referenced_entities) @callback @@ -172,12 +181,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent = hass.data[DOMAIN] return [ automation_entity.entity_id for automation_entity in component.entities - if device_id in automation_entity.referenced_devices + if device_id in cast(AutomationEntity, automation_entity).referenced_devices ] @@ -187,12 +196,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent = hass.data[DOMAIN] if (automation_entity := component.get_entity(entity_id)) is None: return [] - return list(automation_entity.referenced_devices) + return list(cast(AutomationEntity, automation_entity).referenced_devices) @callback @@ -201,12 +210,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent = hass.data[DOMAIN] return [ automation_entity.entity_id for automation_entity in component.entities - if area_id in automation_entity.referenced_areas + if area_id in cast(AutomationEntity, automation_entity).referenced_areas ] @@ -216,12 +225,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent = hass.data[DOMAIN] if (automation_entity := component.get_entity(entity_id)) is None: return [] - return list(automation_entity.referenced_areas) + return list(cast(AutomationEntity, automation_entity).referenced_areas) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: @@ -238,7 +247,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: if not await _async_process_config(hass, config, component): await async_get_blueprints(hass).async_populate() - async def trigger_service_handler(entity, service_call): + async def trigger_service_handler( + entity: AutomationEntity, service_call: ServiceCall + ) -> None: """Handle forced automation trigger, e.g. from frontend.""" await entity.async_trigger( {**service_call.data[ATTR_VARIABLES], "trigger": {"platform": None}}, @@ -262,7 +273,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: "async_turn_off", ) - async def reload_service_handler(service_call): + async def reload_service_handler(service_call: ServiceCall) -> None: """Remove all automations and load new ones from config.""" if (conf := await component.async_prepare_reload()) is None: return @@ -290,22 +301,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity): def __init__( self, - automation_id, - name, - trigger_config, - cond_func, - action_script, - initial_state, - variables, - trigger_variables, - raw_config, - blueprint_inputs, - trace_config, - ): + automation_id: str | None, + name: str, + trigger_config: list[ConfigType], + cond_func: IfAction | None, + action_script: Script, + initial_state: bool | None, + variables: ScriptVariables | None, + trigger_variables: ScriptVariables | None, + raw_config: ConfigType | None, + blueprint_inputs: ConfigType | None, + trace_config: ConfigType, + ) -> None: """Initialize an automation entity.""" self._attr_name = name self._trigger_config = trigger_config - self._async_detach_triggers = None + self._async_detach_triggers: CALLBACK_TYPE | None = None self._cond_func = cond_func self.action_script = action_script self.action_script.change_listener = self.async_write_ha_state @@ -314,15 +325,15 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._referenced_entities: set[str] | None = None self._referenced_devices: set[str] | None = None self._logger = LOGGER - self._variables: ScriptVariables = variables - self._trigger_variables: ScriptVariables = trigger_variables + self._variables = variables + self._trigger_variables = trigger_variables self._raw_config = raw_config self._blueprint_inputs = blueprint_inputs self._trace_config = trace_config self._attr_unique_id = automation_id @property - def extra_state_attributes(self): + def extra_state_attributes(self) -> dict[str, Any]: """Return the entity state attributes.""" attrs = { ATTR_LAST_TRIGGERED: self.action_script.last_triggered, @@ -341,12 +352,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity): return self._async_detach_triggers is not None or self._is_enabled @property - def referenced_areas(self): + def referenced_areas(self) -> set[str]: """Return a set of referenced areas.""" return self.action_script.referenced_areas @property - def referenced_devices(self): + def referenced_devices(self) -> set[str]: """Return a set of referenced devices.""" if self._referenced_devices is not None: return self._referenced_devices @@ -364,7 +375,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): return referenced @property - def referenced_entities(self): + def referenced_entities(self) -> set[str]: """Return a set of referenced entities.""" if self._referenced_entities is not None: return self._referenced_entities @@ -513,7 +524,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): event_data[ATTR_SOURCE] = variables["trigger"]["description"] @callback - def started_action(): + def started_action() -> None: self.hass.bus.async_fire( EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context ) @@ -555,12 +566,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self._logger.exception("While executing automation %s", self.entity_id) automation_trace.set_error(err) - async def async_will_remove_from_hass(self): + async def async_will_remove_from_hass(self) -> None: """Remove listeners when removing automation from Home Assistant.""" await super().async_will_remove_from_hass() await self.async_disable() - async def async_enable(self): + async def async_enable(self) -> None: """Enable this automation entity. This method is a coroutine. @@ -576,7 +587,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): self.async_write_ha_state() return - async def async_enable_automation(event): + async def async_enable_automation(event: Event) -> None: """Start automation on startup.""" # Don't do anything if no longer enabled or already attached if not self._is_enabled or self._async_detach_triggers is not None: @@ -589,7 +600,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): ) self.async_write_ha_state() - async def async_disable(self, stop_actions=DEFAULT_STOP_ACTIONS): + async def async_disable(self, stop_actions: bool = DEFAULT_STOP_ACTIONS) -> None: """Disable the automation entity.""" if not self._is_enabled and not self.action_script.runs: return @@ -610,7 +621,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): ) -> Callable[[], None] | None: """Set up the triggers.""" - def log_cb(level, msg, **kwargs): + def log_cb(level: int, msg: str, **kwargs: Any) -> None: self._logger.log(level, "%s %s", msg, self.name, **kwargs) this = None @@ -650,7 +661,7 @@ async def _async_process_config( Returns if blueprints were used. """ - entities = [] + entities: list[AutomationEntity] = [] blueprints_used = False for config_key in extract_domain_configs(config, DOMAIN): @@ -681,10 +692,10 @@ async def _async_process_config( else: raw_config = cast(AutomationConfig, config_block).raw_config - automation_id = config_block.get(CONF_ID) - name = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}" + automation_id: str | None = config_block.get(CONF_ID) + name: str = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}" - initial_state = config_block.get(CONF_INITIAL_STATE) + initial_state: bool | None = config_block.get(CONF_INITIAL_STATE) action_script = Script( hass, @@ -743,11 +754,13 @@ async def _async_process_config( return blueprints_used -async def _async_process_if(hass, name, config, p_config): +async def _async_process_if( + hass: HomeAssistant, name: str, config: dict[str, Any], p_config: dict[str, Any] +) -> IfAction | None: """Process if checks.""" if_configs = p_config[CONF_CONDITION] - checks = [] + checks: list[condition.ConditionCheckerType] = [] for if_config in if_configs: try: checks.append(await condition.async_from_config(hass, if_config)) @@ -755,9 +768,9 @@ async def _async_process_if(hass, name, config, p_config): LOGGER.warning("Invalid condition: %s", ex) return None - def if_action(variables=None): + def if_action(variables: Mapping[str, Any] | None = None) -> bool: """AND all conditions.""" - errors = [] + errors: list[ConditionErrorIndex] = [] for index, check in enumerate(checks): try: with trace_path(["condition", str(index)]): @@ -780,9 +793,10 @@ async def _async_process_if(hass, name, config, p_config): return True - if_action.config = if_configs + result: IfAction = if_action # type: ignore[assignment] + result.config = if_configs - return if_action + return result @callback @@ -800,7 +814,7 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]: return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]] if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf: - return trigger_conf[CONF_DEVICE_ID] + return trigger_conf[CONF_DEVICE_ID] # type: ignore[no-any-return] return [] @@ -809,13 +823,13 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]: def _trigger_extract_entities(trigger_conf: dict) -> list[str]: """Extract entities from a trigger config.""" if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"): - return trigger_conf[CONF_ENTITY_ID] + return trigger_conf[CONF_ENTITY_ID] # type: ignore[no-any-return] if trigger_conf[CONF_PLATFORM] == "calendar": return [trigger_conf[CONF_ENTITY_ID]] if trigger_conf[CONF_PLATFORM] == "zone": - return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] + return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] # type: ignore[no-any-return] if trigger_conf[CONF_PLATFORM] == "geo_location": return [trigger_conf[CONF_ZONE]] diff --git a/homeassistant/components/automation/config.py b/homeassistant/components/automation/config.py index 228e78ac446..ec35e617b07 100644 --- a/homeassistant/components/automation/config.py +++ b/homeassistant/components/automation/config.py @@ -1,6 +1,9 @@ """Config validation helper for the automation integration.""" +from __future__ import annotations + import asyncio from contextlib import suppress +from typing import Any import voluptuous as vol @@ -17,10 +20,12 @@ from homeassistant.const import ( CONF_ID, CONF_VARIABLES, ) +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_per_platform, config_validation as cv, script from homeassistant.helpers.condition import async_validate_conditions_config from homeassistant.helpers.trigger import async_validate_trigger_config +from homeassistant.helpers.typing import ConfigType from homeassistant.loader import IntegrationNotFound from .const import ( @@ -34,9 +39,6 @@ from .const import ( ) from .helpers import async_get_blueprints -# mypy: allow-untyped-calls, allow-untyped-defs -# mypy: no-check-untyped-defs, no-warn-return-any - PACKAGE_MERGE_HINT = "list" _CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA]) @@ -63,7 +65,11 @@ PLATFORM_SCHEMA = vol.All( ) -async def async_validate_config_item(hass, config, full_config=None): +async def async_validate_config_item( + hass: HomeAssistant, + config: ConfigType, + full_config: ConfigType | None = None, +) -> blueprint.BlueprintInputs | dict[str, Any]: """Validate config item.""" if blueprint.is_blueprint_instance_config(config): blueprints = async_get_blueprints(hass) @@ -90,17 +96,21 @@ async def async_validate_config_item(hass, config, full_config=None): class AutomationConfig(dict): """Dummy class to allow adding attributes.""" - raw_config = None + raw_config: dict[str, Any] | None = None -async def _try_async_validate_config_item(hass, config, full_config=None): +async def _try_async_validate_config_item( + hass: HomeAssistant, + config: dict[str, Any], + full_config: dict[str, Any] | None = None, +) -> AutomationConfig | blueprint.BlueprintInputs | None: """Validate config item.""" raw_config = None with suppress(ValueError): raw_config = dict(config) try: - config = await async_validate_config_item(hass, config, full_config) + validated_config = await async_validate_config_item(hass, config, full_config) except ( vol.Invalid, HomeAssistantError, @@ -110,15 +120,15 @@ async def _try_async_validate_config_item(hass, config, full_config=None): async_log_exception(ex, DOMAIN, full_config or config, hass) return None - if isinstance(config, blueprint.BlueprintInputs): - return config + if isinstance(validated_config, blueprint.BlueprintInputs): + return validated_config - config = AutomationConfig(config) - config.raw_config = raw_config - return config + automation_config = AutomationConfig(validated_config) + automation_config.raw_config = raw_config + return automation_config -async def async_validate_config(hass, config): +async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> ConfigType: """Validate config.""" automations = list( filter( diff --git a/homeassistant/components/automation/trace.py b/homeassistant/components/automation/trace.py index b302f99d036..ae0d0339bfa 100644 --- a/homeassistant/components/automation/trace.py +++ b/homeassistant/components/automation/trace.py @@ -1,6 +1,7 @@ """Trace support for automation.""" from __future__ import annotations +from collections.abc import Generator from contextlib import contextmanager from typing import Any @@ -9,13 +10,11 @@ from homeassistant.components.trace import ( ActionTrace, async_store_trace, ) -from homeassistant.core import Context +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers.typing import ConfigType from .const import DOMAIN -# mypy: allow-untyped-calls, allow-untyped-defs -# mypy: no-check-untyped-defs, no-warn-return-any - class AutomationTrace(ActionTrace): """Container for automation trace.""" @@ -24,9 +23,9 @@ class AutomationTrace(ActionTrace): def __init__( self, - item_id: str, - config: dict[str, Any], - blueprint_inputs: dict[str, Any], + item_id: str | None, + config: ConfigType | None, + blueprint_inputs: ConfigType | None, context: Context, ) -> None: """Container for automation trace.""" @@ -49,8 +48,13 @@ class AutomationTrace(ActionTrace): @contextmanager def trace_automation( - hass, automation_id, config, blueprint_inputs, context, trace_config -): + hass: HomeAssistant, + automation_id: str | None, + config: ConfigType | None, + blueprint_inputs: ConfigType | None, + context: Context, + trace_config: ConfigType, +) -> Generator[AutomationTrace, None, None]: """Trace action execution of automation with automation_id.""" trace = AutomationTrace(automation_id, config, blueprint_inputs, context) async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])