diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 5cbc8a1e678..75613b3d118 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -7,14 +7,14 @@ from enum import Enum from functools import wraps import logging from types import ModuleType -from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Union, overload import voluptuous as vol import voluptuous_serialize from homeassistant.components import websocket_api from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM -from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant +from homeassistant.core import HomeAssistant from homeassistant.helpers import ( config_validation as cv, device_registry as dr, @@ -28,11 +28,16 @@ from homeassistant.requirements import async_get_integration_with_requirements from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig if TYPE_CHECKING: - from homeassistant.components.automation import ( - AutomationActionType, - AutomationTriggerInfo, - ) - from homeassistant.helpers import condition + from .action import DeviceAutomationActionProtocol + from .condition import DeviceAutomationConditionProtocol + from .trigger import DeviceAutomationTriggerProtocol + + DeviceAutomationPlatformType = Union[ + ModuleType, + DeviceAutomationTriggerProtocol, + DeviceAutomationConditionProtocol, + DeviceAutomationActionProtocol, + ] # mypy: allow-untyped-calls, allow-untyped-defs @@ -83,77 +88,6 @@ TYPES = { } -class DeviceAutomationTriggerProtocol(Protocol): - """Define the format of device_trigger modules. - - Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config. - """ - - TRIGGER_SCHEMA: vol.Schema - - async def async_validate_trigger_config( - self, hass: HomeAssistant, config: ConfigType - ) -> ConfigType: - """Validate config.""" - raise NotImplementedError - - async def async_attach_trigger( - self, - hass: HomeAssistant, - config: ConfigType, - action: AutomationActionType, - automation_info: AutomationTriggerInfo, - ) -> CALLBACK_TYPE: - """Attach a trigger.""" - raise NotImplementedError - - -class DeviceAutomationConditionProtocol(Protocol): - """Define the format of device_condition modules. - - Each module must define either CONDITION_SCHEMA or async_validate_condition_config. - """ - - CONDITION_SCHEMA: vol.Schema - - async def async_validate_condition_config( - self, hass: HomeAssistant, config: ConfigType - ) -> ConfigType: - """Validate config.""" - raise NotImplementedError - - def async_condition_from_config( - self, hass: HomeAssistant, config: ConfigType - ) -> condition.ConditionCheckerType: - """Evaluate state based on configuration.""" - raise NotImplementedError - - -class DeviceAutomationActionProtocol(Protocol): - """Define the format of device_action modules. - - Each module must define either ACTION_SCHEMA or async_validate_action_config. - """ - - ACTION_SCHEMA: vol.Schema - - async def async_validate_action_config( - self, hass: HomeAssistant, config: ConfigType - ) -> ConfigType: - """Validate config.""" - raise NotImplementedError - - async def async_call_action_from_config( - self, - hass: HomeAssistant, - config: ConfigType, - variables: dict[str, Any], - context: Context | None, - ) -> None: - """Execute a device action.""" - raise NotImplementedError - - @bind_hass async def async_get_device_automations( hass: HomeAssistant, @@ -193,14 +127,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -DeviceAutomationPlatformType = Union[ - ModuleType, - DeviceAutomationTriggerProtocol, - DeviceAutomationConditionProtocol, - DeviceAutomationActionProtocol, -] - - @overload async def async_get_device_automation_platform( # noqa: D103 hass: HomeAssistant, @@ -231,13 +157,13 @@ async def async_get_device_automation_platform( # noqa: D103 @overload async def async_get_device_automation_platform( # noqa: D103 hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str -) -> DeviceAutomationPlatformType: +) -> "DeviceAutomationPlatformType": ... async def async_get_device_automation_platform( hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str -) -> DeviceAutomationPlatformType: +) -> "DeviceAutomationPlatformType": """Load device automation platform for integration. Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation. diff --git a/homeassistant/components/device_automation/action.py b/homeassistant/components/device_automation/action.py new file mode 100644 index 00000000000..5261757c645 --- /dev/null +++ b/homeassistant/components/device_automation/action.py @@ -0,0 +1,68 @@ +"""Device action validator.""" +from __future__ import annotations + +from typing import Any, Protocol, cast + +import voluptuous as vol + +from homeassistant.const import CONF_DOMAIN +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers.typing import ConfigType + +from . import DeviceAutomationType, async_get_device_automation_platform +from .exceptions import InvalidDeviceAutomationConfig + + +class DeviceAutomationActionProtocol(Protocol): + """Define the format of device_action modules. + + Each module must define either ACTION_SCHEMA or async_validate_action_config. + """ + + ACTION_SCHEMA: vol.Schema + + async def async_validate_action_config( + self, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + raise NotImplementedError + + async def async_call_action_from_config( + self, + hass: HomeAssistant, + config: ConfigType, + variables: dict[str, Any], + context: Context | None, + ) -> None: + """Execute a device action.""" + raise NotImplementedError + + +async def async_validate_action_config( + hass: HomeAssistant, config: ConfigType +) -> ConfigType: + """Validate config.""" + try: + platform = await async_get_device_automation_platform( + hass, config[CONF_DOMAIN], DeviceAutomationType.ACTION + ) + if hasattr(platform, "async_validate_action_config"): + return await platform.async_validate_action_config(hass, config) + return cast(ConfigType, platform.ACTION_SCHEMA(config)) + except InvalidDeviceAutomationConfig as err: + raise vol.Invalid(str(err) or "Invalid action configuration") from err + + +async def async_call_action_from_config( + hass: HomeAssistant, + config: ConfigType, + variables: dict[str, Any], + context: Context | None, +) -> None: + """Execute a device action.""" + platform = await async_get_device_automation_platform( + hass, + config[CONF_DOMAIN], + DeviceAutomationType.ACTION, + ) + await platform.async_call_action_from_config(hass, config, variables, context) diff --git a/homeassistant/components/device_automation/condition.py b/homeassistant/components/device_automation/condition.py new file mode 100644 index 00000000000..1c226ee8c29 --- /dev/null +++ b/homeassistant/components/device_automation/condition.py @@ -0,0 +1,64 @@ +"""Validate device conditions.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, cast + +import voluptuous as vol + +from homeassistant.const import CONF_DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.typing import ConfigType + +from . import DeviceAutomationType, async_get_device_automation_platform +from .exceptions import InvalidDeviceAutomationConfig + +if TYPE_CHECKING: + from homeassistant.helpers import condition + + +class DeviceAutomationConditionProtocol(Protocol): + """Define the format of device_condition modules. + + Each module must define either CONDITION_SCHEMA or async_validate_condition_config. + """ + + CONDITION_SCHEMA: vol.Schema + + async def async_validate_condition_config( + self, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + raise NotImplementedError + + def async_condition_from_config( + self, hass: HomeAssistant, config: ConfigType + ) -> condition.ConditionCheckerType: + """Evaluate state based on configuration.""" + raise NotImplementedError + + +async def async_validate_condition_config( + hass: HomeAssistant, config: ConfigType +) -> ConfigType: + """Validate device condition config.""" + try: + config = cv.DEVICE_CONDITION_SCHEMA(config) + platform = await async_get_device_automation_platform( + hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION + ) + if hasattr(platform, "async_validate_condition_config"): + return await platform.async_validate_condition_config(hass, config) + return cast(ConfigType, platform.CONDITION_SCHEMA(config)) + except InvalidDeviceAutomationConfig as err: + raise vol.Invalid(str(err) or "Invalid condition configuration") from err + + +async def async_condition_from_config( + hass: HomeAssistant, config: ConfigType +) -> condition.ConditionCheckerType: + """Test a device condition.""" + platform = await async_get_device_automation_platform( + hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION + ) + return platform.async_condition_from_config(hass, config) diff --git a/homeassistant/components/device_automation/trigger.py b/homeassistant/components/device_automation/trigger.py index f2962d6544e..933c5c4c60a 100644 --- a/homeassistant/components/device_automation/trigger.py +++ b/homeassistant/components/device_automation/trigger.py @@ -1,5 +1,5 @@ """Offer device oriented automation.""" -from typing import cast +from typing import Protocol, cast import voluptuous as vol @@ -21,17 +21,41 @@ from .exceptions import InvalidDeviceAutomationConfig TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) +class DeviceAutomationTriggerProtocol(Protocol): + """Define the format of device_trigger modules. + + Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config. + """ + + TRIGGER_SCHEMA: vol.Schema + + async def async_validate_trigger_config( + self, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + raise NotImplementedError + + async def async_attach_trigger( + self, + hass: HomeAssistant, + config: ConfigType, + action: AutomationActionType, + automation_info: AutomationTriggerInfo, + ) -> CALLBACK_TYPE: + """Attach a trigger.""" + raise NotImplementedError + + async def async_validate_trigger_config( hass: HomeAssistant, config: ConfigType ) -> ConfigType: """Validate config.""" - platform = await async_get_device_automation_platform( - hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER - ) - if not hasattr(platform, "async_validate_trigger_config"): - return cast(ConfigType, platform.TRIGGER_SCHEMA(config)) - try: + platform = await async_get_device_automation_platform( + hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER + ) + if not hasattr(platform, "async_validate_trigger_config"): + return cast(ConfigType, platform.TRIGGER_SCHEMA(config)) return await platform.async_validate_trigger_config(hass, config) except InvalidDeviceAutomationConfig as err: raise vol.Invalid(str(err) or "Invalid trigger configuration") from err diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 06853dd9450..3355424b710 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -13,10 +13,7 @@ import sys from typing import Any, cast from homeassistant.components import zone as zone_cmp -from homeassistant.components.device_automation import ( - DeviceAutomationType, - async_get_device_automation_platform, -) +from homeassistant.components.device_automation import condition as device_condition from homeassistant.components.sensor import SensorDeviceClass from homeassistant.const import ( ATTR_DEVICE_CLASS, @@ -30,7 +27,6 @@ from homeassistant.const import ( CONF_BELOW, CONF_CONDITION, CONF_DEVICE_ID, - CONF_DOMAIN, CONF_ENTITY_ID, CONF_ID, CONF_STATE, @@ -872,10 +868,8 @@ async def async_device_from_config( hass: HomeAssistant, config: ConfigType ) -> ConditionCheckerType: """Test a device condition.""" - platform = await async_get_device_automation_platform( - hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION - ) - return trace_condition_function(platform.async_condition_from_config(hass, config)) + checker = await device_condition.async_condition_from_config(hass, config) + return trace_condition_function(checker) async def async_trigger_from_config( @@ -931,15 +925,10 @@ async def async_validate_condition_config( sub_cond = await async_validate_condition_config(hass, sub_cond) conditions.append(sub_cond) config["conditions"] = conditions + return config if condition == "device": - config = cv.DEVICE_CONDITION_SCHEMA(config) - platform = await async_get_device_automation_platform( - hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION - ) - if hasattr(platform, "async_validate_condition_config"): - return await platform.async_validate_condition_config(hass, config) - return cast(ConfigType, platform.CONDITION_SCHEMA(config)) + return await device_condition.async_validate_condition_config(hass, config) if condition in ("numeric_state", "state"): validator = cast( diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 5a80691fa46..1eabc33b89d 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -15,7 +15,8 @@ import async_timeout import voluptuous as vol from homeassistant import exceptions -from homeassistant.components import device_automation, scene +from homeassistant.components import scene +from homeassistant.components.device_automation import action as device_action from homeassistant.components.logger import LOGSEVERITY from homeassistant.const import ( ATTR_AREA_ID, @@ -244,13 +245,7 @@ async def async_validate_action_config( pass elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: - platform = await device_automation.async_get_device_automation_platform( - hass, config[CONF_DOMAIN], device_automation.DeviceAutomationType.ACTION - ) - if hasattr(platform, "async_validate_action_config"): - config = await platform.async_validate_action_config(hass, config) - else: - config = platform.ACTION_SCHEMA(config) + config = await device_action.async_validate_action_config(hass, config) elif action_type == cv.SCRIPT_ACTION_CHECK_CONDITION: config = await condition.async_validate_condition_config(hass, config) @@ -580,12 +575,7 @@ class _ScriptRun: async def _async_device_step(self): """Perform the device automation specified in the action.""" self._step_log("device automation") - platform = await device_automation.async_get_device_automation_platform( - self._hass, - self._action[CONF_DOMAIN], - device_automation.DeviceAutomationType.ACTION, - ) - await platform.async_call_action_from_config( + await device_action.async_call_action_from_config( self._hass, self._action, self._variables, self._context ) diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index b365c114b03..ffbc2130f3b 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -2977,7 +2977,7 @@ async def test_platform_async_validate_condition_config(hass): config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test", CONF_CONDITION: "device"} platform = AsyncMock() with patch( - "homeassistant.helpers.condition.async_get_device_automation_platform", + "homeassistant.components.device_automation.condition.async_get_device_automation_platform", return_value=platform, ): platform.async_validate_condition_config.return_value = config diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 5bb4833a796..11ba9810b9d 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -3721,7 +3721,7 @@ async def test_platform_async_validate_action_config(hass): config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test"} platform = AsyncMock() with patch( - "homeassistant.helpers.script.device_automation.async_get_device_automation_platform", + "homeassistant.components.device_automation.action.async_get_device_automation_platform", return_value=platform, ): platform.async_validate_action_config.return_value = config