diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 74582f0f77b..9d80ce169a9 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from collections.abc import Iterable, Mapping +from enum import Enum from functools import wraps import logging from types import ModuleType @@ -19,6 +20,7 @@ from homeassistant.helpers import ( device_registry as dr, entity_registry as er, ) +from homeassistant.helpers.frame import report from homeassistant.loader import IntegrationNotFound, bind_hass from homeassistant.requirements import async_get_integration_with_requirements @@ -45,32 +47,49 @@ class DeviceAutomationDetails(NamedTuple): get_capabilities_func: str -TYPES = { - "trigger": DeviceAutomationDetails( +class DeviceAutomationType(Enum): + """Device automation type.""" + + TRIGGER = DeviceAutomationDetails( "device_trigger", "async_get_triggers", "async_get_trigger_capabilities", - ), - "condition": DeviceAutomationDetails( + ) + CONDITION = DeviceAutomationDetails( "device_condition", "async_get_conditions", "async_get_condition_capabilities", - ), - "action": DeviceAutomationDetails( + ) + ACTION = DeviceAutomationDetails( "device_action", "async_get_actions", "async_get_action_capabilities", - ), + ) + + +# TYPES is deprecated as of Home Assistant 2022.2, use DeviceAutomationType instead +TYPES = { + "trigger": DeviceAutomationType.TRIGGER.value, + "condition": DeviceAutomationType.CONDITION.value, + "action": DeviceAutomationType.ACTION.value, } @bind_hass async def async_get_device_automations( hass: HomeAssistant, - automation_type: str, + automation_type: DeviceAutomationType | str, device_ids: Iterable[str] | None = None, ) -> Mapping[str, Any]: """Return all the device automations for a type optionally limited to specific device ids.""" + if isinstance(automation_type, str): + report( + "uses str for async_get_device_automations automation_type. This is " + "deprecated and will stop working in Home Assistant 2022.4, it should be " + "updated to use DeviceAutomationType instead", + error_if_core=False, + ) + automation_type = DeviceAutomationType[automation_type.upper()] return await _async_get_device_automations(hass, automation_type, device_ids) @@ -98,13 +117,21 @@ async def async_setup(hass, config): async def async_get_device_automation_platform( - hass: HomeAssistant, domain: str, automation_type: str + hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str ) -> ModuleType: """Load device automation platform for integration. Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation. """ - platform_name = TYPES[automation_type].section + if isinstance(automation_type, str): + report( + "uses str for async_get_device_automation_platform automation_type. This " + "is deprecated and will stop working in Home Assistant 2022.4, it should " + "be updated to use DeviceAutomationType instead", + error_if_core=False, + ) + automation_type = DeviceAutomationType[automation_type.upper()] + platform_name = automation_type.value.section try: integration = await async_get_integration_with_requirements(hass, domain) platform = integration.get_platform(platform_name) @@ -114,7 +141,8 @@ async def async_get_device_automation_platform( ) from err except ImportError as err: raise InvalidDeviceAutomationConfig( - f"Integration '{domain}' does not support device automation {automation_type}s" + f"Integration '{domain}' does not support device automation " + f"{automation_type.name.lower()}s" ) from err return platform @@ -131,7 +159,7 @@ async def _async_get_device_automations_from_domain( except InvalidDeviceAutomationConfig: return {} - function_name = TYPES[automation_type].get_automations_func + function_name = automation_type.value.get_automations_func return await asyncio.gather( *( @@ -143,7 +171,9 @@ async def _async_get_device_automations_from_domain( async def _async_get_device_automations( - hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None + hass: HomeAssistant, + automation_type: DeviceAutomationType, + device_ids: Iterable[str] | None, ) -> Mapping[str, list[dict[str, Any]]]: """List device automations.""" device_registry = dr.async_get(hass) @@ -188,7 +218,7 @@ async def _async_get_device_automations( if isinstance(device_results, Exception): logging.getLogger(__name__).error( "Unexpected error fetching device %ss", - automation_type, + automation_type.name.lower(), exc_info=device_results, ) continue @@ -207,7 +237,9 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom except InvalidDeviceAutomationConfig: return {} - function_name = TYPES[automation_type].get_capabilities_func + if isinstance(automation_type, str): # until tests pass DeviceAutomationType + automation_type = DeviceAutomationType[automation_type.upper()] + function_name = automation_type.value.get_capabilities_func if not hasattr(platform, function_name): # The device automation has no capabilities @@ -256,9 +288,11 @@ def handle_device_errors(func): async def websocket_device_automation_list_actions(hass, connection, msg): """Handle request for device actions.""" device_id = msg["device_id"] - actions = (await _async_get_device_automations(hass, "action", [device_id])).get( - device_id - ) + actions = ( + await _async_get_device_automations( + hass, DeviceAutomationType.ACTION, [device_id] + ) + ).get(device_id) connection.send_result(msg["id"], actions) @@ -274,7 +308,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg): """Handle request for device conditions.""" device_id = msg["device_id"] conditions = ( - await _async_get_device_automations(hass, "condition", [device_id]) + await _async_get_device_automations( + hass, DeviceAutomationType.CONDITION, [device_id] + ) ).get(device_id) connection.send_result(msg["id"], conditions) @@ -290,9 +326,11 @@ async def websocket_device_automation_list_conditions(hass, connection, msg): async def websocket_device_automation_list_triggers(hass, connection, msg): """Handle request for device triggers.""" device_id = msg["device_id"] - triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get( - device_id - ) + triggers = ( + await _async_get_device_automations( + hass, DeviceAutomationType.TRIGGER, [device_id] + ) + ).get(device_id) connection.send_result(msg["id"], triggers) @@ -308,7 +346,7 @@ async def websocket_device_automation_get_action_capabilities(hass, connection, """Handle request for device action capabilities.""" action = msg["action"] capabilities = await _async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) connection.send_result(msg["id"], capabilities) @@ -327,7 +365,7 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio """Handle request for device condition capabilities.""" condition = msg["condition"] capabilities = await _async_get_device_automation_capabilities( - hass, "condition", condition + hass, DeviceAutomationType.CONDITION, condition ) connection.send_result(msg["id"], capabilities) @@ -346,6 +384,6 @@ async def websocket_device_automation_get_trigger_capabilities(hass, connection, """Handle request for device trigger capabilities.""" trigger = msg["trigger"] capabilities = await _async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) connection.send_result(msg["id"], capabilities) diff --git a/homeassistant/components/device_automation/trigger.py b/homeassistant/components/device_automation/trigger.py index 62bd8d1c808..008a7603dba 100644 --- a/homeassistant/components/device_automation/trigger.py +++ b/homeassistant/components/device_automation/trigger.py @@ -3,7 +3,11 @@ import voluptuous as vol from homeassistant.const import CONF_DOMAIN -from . import DEVICE_TRIGGER_BASE_SCHEMA, async_get_device_automation_platform +from . import ( + DEVICE_TRIGGER_BASE_SCHEMA, + DeviceAutomationType, + async_get_device_automation_platform, +) from .exceptions import InvalidDeviceAutomationConfig # mypy: allow-untyped-defs, no-check-untyped-defs @@ -14,7 +18,7 @@ TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) async def async_validate_trigger_config(hass, config): """Validate config.""" platform = await async_get_device_automation_platform( - hass, config[CONF_DOMAIN], "trigger" + hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER ) if not hasattr(platform, "async_validate_trigger_config"): return platform.TRIGGER_SCHEMA(config) @@ -28,6 +32,6 @@ async def async_validate_trigger_config(hass, config): async def async_attach_trigger(hass, config, action, automation_info): """Listen for trigger.""" platform = await async_get_device_automation_platform( - hass, config[CONF_DOMAIN], "trigger" + hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER ) return await platform.async_attach_trigger(hass, config, action, automation_info) diff --git a/homeassistant/components/homekit/__init__.py b/homeassistant/components/homekit/__init__.py index 2bc7da6d24a..503a76418a9 100644 --- a/homeassistant/components/homekit/__init__.py +++ b/homeassistant/components/homekit/__init__.py @@ -819,7 +819,9 @@ class HomeKit: valid_device_ids.append(device_id) for device_id, device_triggers in ( await device_automation.async_get_device_automations( - self.hass, "trigger", valid_device_ids + self.hass, + device_automation.DeviceAutomationType.TRIGGER, + valid_device_ids, ) ).items(): self.add_bridge_triggers_accessory( diff --git a/homeassistant/components/homekit/config_flow.py b/homeassistant/components/homekit/config_flow.py index 0d8bf967c5b..f47ecdf5dbb 100644 --- a/homeassistant/components/homekit/config_flow.py +++ b/homeassistant/components/homekit/config_flow.py @@ -512,7 +512,9 @@ class OptionsFlowHandler(config_entries.OptionsFlow): async def _async_get_supported_devices(hass): """Return all supported devices.""" - results = await device_automation.async_get_device_automations(hass, "trigger") + results = await device_automation.async_get_device_automations( + hass, device_automation.DeviceAutomationType.TRIGGER + ) dev_reg = device_registry.async_get(hass) unsorted = { device_id: dev_reg.async_get(device_id).name or device_id diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 030e5dacfd5..d8d98f05ccd 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -14,6 +14,7 @@ from typing import Any, Callable, cast from homeassistant.components import zone as zone_cmp from homeassistant.components.device_automation import ( + DeviceAutomationType, async_get_device_automation_platform, ) from homeassistant.components.sensor import DEVICE_CLASS_TIMESTAMP @@ -881,7 +882,7 @@ async def async_device_from_config( ) -> ConditionCheckerType: """Test a device condition.""" platform = await async_get_device_automation_platform( - hass, config[CONF_DOMAIN], "condition" + hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION ) return trace_condition_function( cast( @@ -952,7 +953,7 @@ async def async_validate_condition_config( config = cv.DEVICE_CONDITION_SCHEMA(config) assert not isinstance(config, Template) platform = await async_get_device_automation_platform( - hass, config[CONF_DOMAIN], "condition" + hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION ) if hasattr(platform, "async_validate_condition_config"): return await platform.async_validate_condition_config(hass, config) # type: ignore diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 20a1dbb8aec..3e4432a40eb 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -254,7 +254,7 @@ async def async_validate_action_config( elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: platform = await device_automation.async_get_device_automation_platform( - hass, config[CONF_DOMAIN], "action" + hass, config[CONF_DOMAIN], device_automation.DeviceAutomationType.ACTION ) if hasattr(platform, "async_validate_action_config"): config = await platform.async_validate_action_config(hass, config) # type: ignore @@ -590,7 +590,9 @@ class _ScriptRun: """Perform the device automation specified in the action.""" self._step_log("device automation") platform = await device_automation.async_get_device_automation_platform( - self._hass, self._action[CONF_DOMAIN], "action" + self._hass, + self._action[CONF_DOMAIN], + device_automation.DeviceAutomationType.ACTION, ) await platform.async_call_action_from_config( self._hass, self._action, self._variables, self._context diff --git a/script/scaffold/templates/device_action/tests/test_device_action.py b/script/scaffold/templates/device_action/tests/test_device_action.py index 424fa0a9afd..f300ae55cf7 100644 --- a/script/scaffold/templates/device_action/tests/test_device_action.py +++ b/script/scaffold/templates/device_action/tests/test_device_action.py @@ -3,6 +3,7 @@ import pytest from homeassistant.components import automation from homeassistant.components.NEW_DOMAIN import DOMAIN +from homeassistant.components.device_automation import DeviceAutomationType from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry, entity_registry from homeassistant.setup import async_setup_component @@ -56,7 +57,9 @@ async def test_get_actions( "entity_id": "NEW_DOMAIN.test_5678", }, ] - actions = await async_get_device_automations(hass, "action", device_entry.id) + actions = await async_get_device_automations( + hass, DeviceAutomationType.ACTION, device_entry.id + ) assert_lists_same(actions, expected_actions) diff --git a/script/scaffold/templates/device_condition/tests/test_device_condition.py b/script/scaffold/templates/device_condition/tests/test_device_condition.py index 9a283fa1f5b..539a60ded97 100644 --- a/script/scaffold/templates/device_condition/tests/test_device_condition.py +++ b/script/scaffold/templates/device_condition/tests/test_device_condition.py @@ -5,6 +5,7 @@ import pytest from homeassistant.components import automation from homeassistant.components.NEW_DOMAIN import DOMAIN +from homeassistant.components.device_automation import DeviceAutomationType from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.helpers import device_registry, entity_registry @@ -67,7 +68,9 @@ async def test_get_conditions( "entity_id": f"{DOMAIN}.test_5678", }, ] - conditions = await async_get_device_automations(hass, "condition", device_entry.id) + conditions = await async_get_device_automations( + hass, DeviceAutomationType.CONDITION, device_entry.id + ) assert_lists_same(conditions, expected_conditions) diff --git a/script/scaffold/templates/device_trigger/tests/test_device_trigger.py b/script/scaffold/templates/device_trigger/tests/test_device_trigger.py index 55343abadb1..59ba9654566 100644 --- a/script/scaffold/templates/device_trigger/tests/test_device_trigger.py +++ b/script/scaffold/templates/device_trigger/tests/test_device_trigger.py @@ -3,6 +3,7 @@ import pytest from homeassistant.components import automation from homeassistant.components.NEW_DOMAIN import DOMAIN +from homeassistant.components.device_automation import DeviceAutomationType from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.helpers import device_registry from homeassistant.setup import async_setup_component @@ -60,7 +61,9 @@ async def test_get_triggers(hass, device_reg, entity_reg): "entity_id": f"{DOMAIN}.test_5678", }, ] - triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + triggers = await async_get_device_automations( + hass, DeviceAutomationType.TRIGGER, device_entry.id + ) assert_lists_same(triggers, expected_triggers) diff --git a/tests/common.py b/tests/common.py index 9d4a9cfe366..327427eda6e 100644 --- a/tests/common.py +++ b/tests/common.py @@ -69,7 +69,9 @@ CLIENT_REDIRECT_URI = "https://example.com/app/callback" async def async_get_device_automations( - hass: HomeAssistant, automation_type: str, device_id: str + hass: HomeAssistant, + automation_type: device_automation.DeviceAutomationType | str, + device_id: str, ) -> Any: """Get a device automation for a single device id.""" automations = await device_automation.async_get_device_automations( diff --git a/tests/components/device_automation/test_init.py b/tests/components/device_automation/test_init.py index 563611b99ad..fe656663ca8 100644 --- a/tests/components/device_automation/test_init.py +++ b/tests/components/device_automation/test_init.py @@ -391,6 +391,13 @@ async def test_async_get_device_automations_single_device_trigger( connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) + result = await device_automation.async_get_device_automations( + hass, device_automation.DeviceAutomationType.TRIGGER, [device_entry.id] + ) + assert device_entry.id in result + assert len(result[device_entry.id]) == 2 + + # Test deprecated str automation_type works, to be removed in 2022.4 result = await device_automation.async_get_device_automations( hass, "trigger", [device_entry.id] ) @@ -410,7 +417,9 @@ async def test_async_get_device_automations_all_devices_trigger( connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) - result = await device_automation.async_get_device_automations(hass, "trigger") + result = await device_automation.async_get_device_automations( + hass, device_automation.DeviceAutomationType.TRIGGER + ) assert device_entry.id in result assert len(result[device_entry.id]) == 2 @@ -427,7 +436,9 @@ async def test_async_get_device_automations_all_devices_condition( connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) - result = await device_automation.async_get_device_automations(hass, "condition") + result = await device_automation.async_get_device_automations( + hass, device_automation.DeviceAutomationType.CONDITION + ) assert device_entry.id in result assert len(result[device_entry.id]) == 2 @@ -444,7 +455,9 @@ async def test_async_get_device_automations_all_devices_action( connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) - result = await device_automation.async_get_device_automations(hass, "action") + result = await device_automation.async_get_device_automations( + hass, device_automation.DeviceAutomationType.ACTION + ) assert device_entry.id in result assert len(result[device_entry.id]) == 3 @@ -465,7 +478,9 @@ async def test_async_get_device_automations_all_devices_action_exception_throw( "homeassistant.components.light.device_trigger.async_get_triggers", side_effect=KeyError, ): - result = await device_automation.async_get_device_automations(hass, "trigger") + result = await device_automation.async_get_device_automations( + hass, device_automation.DeviceAutomationType.TRIGGER + ) assert device_entry.id in result assert len(result[device_entry.id]) == 0 assert "KeyError" in caplog.text diff --git a/tests/components/mobile_app/test_device_action.py b/tests/components/mobile_app/test_device_action.py index e5b15412e4d..f7846ecc377 100644 --- a/tests/components/mobile_app/test_device_action.py +++ b/tests/components/mobile_app/test_device_action.py @@ -16,7 +16,9 @@ async def test_get_actions(hass, push_registration): ] capabilitites = await device_automation._async_get_device_automation_capabilities( - hass, "action", {"domain": DOMAIN, "device_id": device_id, "type": "notify"} + hass, + device_automation.DeviceAutomationType.ACTION, + {"domain": DOMAIN, "device_id": device_id, "type": "notify"}, ) assert "extra_fields" in capabilitites