Make device automation type an enum (#62354)

This commit is contained in:
Ville Skyttä 2021-12-20 20:16:30 +02:00 committed by GitHub
parent 2ddd45afd5
commit 334c6c5c02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 120 additions and 43 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from enum import Enum
from functools import wraps from functools import wraps
import logging import logging
from types import ModuleType from types import ModuleType
@ -19,6 +20,7 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
) )
from homeassistant.helpers.frame import report
from homeassistant.loader import IntegrationNotFound, bind_hass from homeassistant.loader import IntegrationNotFound, bind_hass
from homeassistant.requirements import async_get_integration_with_requirements from homeassistant.requirements import async_get_integration_with_requirements
@ -45,32 +47,49 @@ class DeviceAutomationDetails(NamedTuple):
get_capabilities_func: str get_capabilities_func: str
TYPES = { class DeviceAutomationType(Enum):
"trigger": DeviceAutomationDetails( """Device automation type."""
TRIGGER = DeviceAutomationDetails(
"device_trigger", "device_trigger",
"async_get_triggers", "async_get_triggers",
"async_get_trigger_capabilities", "async_get_trigger_capabilities",
), )
"condition": DeviceAutomationDetails( CONDITION = DeviceAutomationDetails(
"device_condition", "device_condition",
"async_get_conditions", "async_get_conditions",
"async_get_condition_capabilities", "async_get_condition_capabilities",
), )
"action": DeviceAutomationDetails( ACTION = DeviceAutomationDetails(
"device_action", "device_action",
"async_get_actions", "async_get_actions",
"async_get_action_capabilities", "async_get_action_capabilities",
), )
# TYPES is deprecated as of Home Assistant 2022.2, use DeviceAutomationType instead
TYPES = {
"trigger": DeviceAutomationType.TRIGGER.value,
"condition": DeviceAutomationType.CONDITION.value,
"action": DeviceAutomationType.ACTION.value,
} }
@bind_hass @bind_hass
async def async_get_device_automations( async def async_get_device_automations(
hass: HomeAssistant, hass: HomeAssistant,
automation_type: str, automation_type: DeviceAutomationType | str,
device_ids: Iterable[str] | None = None, device_ids: Iterable[str] | None = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
"""Return all the device automations for a type optionally limited to specific device ids.""" """Return all the device automations for a type optionally limited to specific device ids."""
if isinstance(automation_type, str):
report(
"uses str for async_get_device_automations automation_type. This is "
"deprecated and will stop working in Home Assistant 2022.4, it should be "
"updated to use DeviceAutomationType instead",
error_if_core=False,
)
automation_type = DeviceAutomationType[automation_type.upper()]
return await _async_get_device_automations(hass, automation_type, device_ids) return await _async_get_device_automations(hass, automation_type, device_ids)
@ -98,13 +117,21 @@ async def async_setup(hass, config):
async def async_get_device_automation_platform( async def async_get_device_automation_platform(
hass: HomeAssistant, domain: str, automation_type: str hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> ModuleType: ) -> ModuleType:
"""Load device automation platform for integration. """Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation. Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
""" """
platform_name = TYPES[automation_type].section if isinstance(automation_type, str):
report(
"uses str for async_get_device_automation_platform automation_type. This "
"is deprecated and will stop working in Home Assistant 2022.4, it should "
"be updated to use DeviceAutomationType instead",
error_if_core=False,
)
automation_type = DeviceAutomationType[automation_type.upper()]
platform_name = automation_type.value.section
try: try:
integration = await async_get_integration_with_requirements(hass, domain) integration = await async_get_integration_with_requirements(hass, domain)
platform = integration.get_platform(platform_name) platform = integration.get_platform(platform_name)
@ -114,7 +141,8 @@ async def async_get_device_automation_platform(
) from err ) from err
except ImportError as err: except ImportError as err:
raise InvalidDeviceAutomationConfig( raise InvalidDeviceAutomationConfig(
f"Integration '{domain}' does not support device automation {automation_type}s" f"Integration '{domain}' does not support device automation "
f"{automation_type.name.lower()}s"
) from err ) from err
return platform return platform
@ -131,7 +159,7 @@ async def _async_get_device_automations_from_domain(
except InvalidDeviceAutomationConfig: except InvalidDeviceAutomationConfig:
return {} return {}
function_name = TYPES[automation_type].get_automations_func function_name = automation_type.value.get_automations_func
return await asyncio.gather( return await asyncio.gather(
*( *(
@ -143,7 +171,9 @@ async def _async_get_device_automations_from_domain(
async def _async_get_device_automations( async def _async_get_device_automations(
hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None hass: HomeAssistant,
automation_type: DeviceAutomationType,
device_ids: Iterable[str] | None,
) -> Mapping[str, list[dict[str, Any]]]: ) -> Mapping[str, list[dict[str, Any]]]:
"""List device automations.""" """List device automations."""
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
@ -188,7 +218,7 @@ async def _async_get_device_automations(
if isinstance(device_results, Exception): if isinstance(device_results, Exception):
logging.getLogger(__name__).error( logging.getLogger(__name__).error(
"Unexpected error fetching device %ss", "Unexpected error fetching device %ss",
automation_type, automation_type.name.lower(),
exc_info=device_results, exc_info=device_results,
) )
continue continue
@ -207,7 +237,9 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom
except InvalidDeviceAutomationConfig: except InvalidDeviceAutomationConfig:
return {} return {}
function_name = TYPES[automation_type].get_capabilities_func if isinstance(automation_type, str): # until tests pass DeviceAutomationType
automation_type = DeviceAutomationType[automation_type.upper()]
function_name = automation_type.value.get_capabilities_func
if not hasattr(platform, function_name): if not hasattr(platform, function_name):
# The device automation has no capabilities # The device automation has no capabilities
@ -256,9 +288,11 @@ def handle_device_errors(func):
async def websocket_device_automation_list_actions(hass, connection, msg): async def websocket_device_automation_list_actions(hass, connection, msg):
"""Handle request for device actions.""" """Handle request for device actions."""
device_id = msg["device_id"] device_id = msg["device_id"]
actions = (await _async_get_device_automations(hass, "action", [device_id])).get( actions = (
device_id await _async_get_device_automations(
hass, DeviceAutomationType.ACTION, [device_id]
) )
).get(device_id)
connection.send_result(msg["id"], actions) connection.send_result(msg["id"], actions)
@ -274,7 +308,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
"""Handle request for device conditions.""" """Handle request for device conditions."""
device_id = msg["device_id"] device_id = msg["device_id"]
conditions = ( conditions = (
await _async_get_device_automations(hass, "condition", [device_id]) await _async_get_device_automations(
hass, DeviceAutomationType.CONDITION, [device_id]
)
).get(device_id) ).get(device_id)
connection.send_result(msg["id"], conditions) connection.send_result(msg["id"], conditions)
@ -290,9 +326,11 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
async def websocket_device_automation_list_triggers(hass, connection, msg): async def websocket_device_automation_list_triggers(hass, connection, msg):
"""Handle request for device triggers.""" """Handle request for device triggers."""
device_id = msg["device_id"] device_id = msg["device_id"]
triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get( triggers = (
device_id await _async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, [device_id]
) )
).get(device_id)
connection.send_result(msg["id"], triggers) connection.send_result(msg["id"], triggers)
@ -308,7 +346,7 @@ async def websocket_device_automation_get_action_capabilities(hass, connection,
"""Handle request for device action capabilities.""" """Handle request for device action capabilities."""
action = msg["action"] action = msg["action"]
capabilities = await _async_get_device_automation_capabilities( capabilities = await _async_get_device_automation_capabilities(
hass, "action", action hass, DeviceAutomationType.ACTION, action
) )
connection.send_result(msg["id"], capabilities) connection.send_result(msg["id"], capabilities)
@ -327,7 +365,7 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio
"""Handle request for device condition capabilities.""" """Handle request for device condition capabilities."""
condition = msg["condition"] condition = msg["condition"]
capabilities = await _async_get_device_automation_capabilities( capabilities = await _async_get_device_automation_capabilities(
hass, "condition", condition hass, DeviceAutomationType.CONDITION, condition
) )
connection.send_result(msg["id"], capabilities) connection.send_result(msg["id"], capabilities)
@ -346,6 +384,6 @@ async def websocket_device_automation_get_trigger_capabilities(hass, connection,
"""Handle request for device trigger capabilities.""" """Handle request for device trigger capabilities."""
trigger = msg["trigger"] trigger = msg["trigger"]
capabilities = await _async_get_device_automation_capabilities( capabilities = await _async_get_device_automation_capabilities(
hass, "trigger", trigger hass, DeviceAutomationType.TRIGGER, trigger
) )
connection.send_result(msg["id"], capabilities) connection.send_result(msg["id"], capabilities)

View File

@ -3,7 +3,11 @@ import voluptuous as vol
from homeassistant.const import CONF_DOMAIN from homeassistant.const import CONF_DOMAIN
from . import DEVICE_TRIGGER_BASE_SCHEMA, async_get_device_automation_platform from . import (
DEVICE_TRIGGER_BASE_SCHEMA,
DeviceAutomationType,
async_get_device_automation_platform,
)
from .exceptions import InvalidDeviceAutomationConfig from .exceptions import InvalidDeviceAutomationConfig
# mypy: allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
@ -14,7 +18,7 @@ TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
async def async_validate_trigger_config(hass, config): async def async_validate_trigger_config(hass, config):
"""Validate config.""" """Validate config."""
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "trigger" hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
) )
if not hasattr(platform, "async_validate_trigger_config"): if not hasattr(platform, "async_validate_trigger_config"):
return platform.TRIGGER_SCHEMA(config) return platform.TRIGGER_SCHEMA(config)
@ -28,6 +32,6 @@ async def async_validate_trigger_config(hass, config):
async def async_attach_trigger(hass, config, action, automation_info): async def async_attach_trigger(hass, config, action, automation_info):
"""Listen for trigger.""" """Listen for trigger."""
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "trigger" hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
) )
return await platform.async_attach_trigger(hass, config, action, automation_info) return await platform.async_attach_trigger(hass, config, action, automation_info)

View File

@ -819,7 +819,9 @@ class HomeKit:
valid_device_ids.append(device_id) valid_device_ids.append(device_id)
for device_id, device_triggers in ( for device_id, device_triggers in (
await device_automation.async_get_device_automations( await device_automation.async_get_device_automations(
self.hass, "trigger", valid_device_ids self.hass,
device_automation.DeviceAutomationType.TRIGGER,
valid_device_ids,
) )
).items(): ).items():
self.add_bridge_triggers_accessory( self.add_bridge_triggers_accessory(

View File

@ -512,7 +512,9 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
async def _async_get_supported_devices(hass): async def _async_get_supported_devices(hass):
"""Return all supported devices.""" """Return all supported devices."""
results = await device_automation.async_get_device_automations(hass, "trigger") results = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER
)
dev_reg = device_registry.async_get(hass) dev_reg = device_registry.async_get(hass)
unsorted = { unsorted = {
device_id: dev_reg.async_get(device_id).name or device_id device_id: dev_reg.async_get(device_id).name or device_id

View File

@ -14,6 +14,7 @@ from typing import Any, Callable, cast
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import ( from homeassistant.components.device_automation import (
DeviceAutomationType,
async_get_device_automation_platform, async_get_device_automation_platform,
) )
from homeassistant.components.sensor import DEVICE_CLASS_TIMESTAMP from homeassistant.components.sensor import DEVICE_CLASS_TIMESTAMP
@ -881,7 +882,7 @@ async def async_device_from_config(
) -> ConditionCheckerType: ) -> ConditionCheckerType:
"""Test a device condition.""" """Test a device condition."""
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "condition" hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
) )
return trace_condition_function( return trace_condition_function(
cast( cast(
@ -952,7 +953,7 @@ async def async_validate_condition_config(
config = cv.DEVICE_CONDITION_SCHEMA(config) config = cv.DEVICE_CONDITION_SCHEMA(config)
assert not isinstance(config, Template) assert not isinstance(config, Template)
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "condition" hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
) )
if hasattr(platform, "async_validate_condition_config"): if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config) # type: ignore return await platform.async_validate_condition_config(hass, config) # type: ignore

View File

@ -254,7 +254,7 @@ async def async_validate_action_config(
elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
platform = await device_automation.async_get_device_automation_platform( platform = await device_automation.async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "action" hass, config[CONF_DOMAIN], device_automation.DeviceAutomationType.ACTION
) )
if hasattr(platform, "async_validate_action_config"): if hasattr(platform, "async_validate_action_config"):
config = await platform.async_validate_action_config(hass, config) # type: ignore config = await platform.async_validate_action_config(hass, config) # type: ignore
@ -590,7 +590,9 @@ class _ScriptRun:
"""Perform the device automation specified in the action.""" """Perform the device automation specified in the action."""
self._step_log("device automation") self._step_log("device automation")
platform = await device_automation.async_get_device_automation_platform( platform = await device_automation.async_get_device_automation_platform(
self._hass, self._action[CONF_DOMAIN], "action" self._hass,
self._action[CONF_DOMAIN],
device_automation.DeviceAutomationType.ACTION,
) )
await platform.async_call_action_from_config( await platform.async_call_action_from_config(
self._hass, self._action, self._variables, self._context self._hass, self._action, self._variables, self._context

View File

@ -3,6 +3,7 @@ import pytest
from homeassistant.components import automation from homeassistant.components import automation
from homeassistant.components.NEW_DOMAIN import DOMAIN from homeassistant.components.NEW_DOMAIN import DOMAIN
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers import device_registry, entity_registry
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -56,7 +57,9 @@ async def test_get_actions(
"entity_id": "NEW_DOMAIN.test_5678", "entity_id": "NEW_DOMAIN.test_5678",
}, },
] ]
actions = await async_get_device_automations(hass, "action", device_entry.id) actions = await async_get_device_automations(
hass, DeviceAutomationType.ACTION, device_entry.id
)
assert_lists_same(actions, expected_actions) assert_lists_same(actions, expected_actions)

View File

@ -5,6 +5,7 @@ import pytest
from homeassistant.components import automation from homeassistant.components import automation
from homeassistant.components.NEW_DOMAIN import DOMAIN from homeassistant.components.NEW_DOMAIN import DOMAIN
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers import device_registry, entity_registry
@ -67,7 +68,9 @@ async def test_get_conditions(
"entity_id": f"{DOMAIN}.test_5678", "entity_id": f"{DOMAIN}.test_5678",
}, },
] ]
conditions = await async_get_device_automations(hass, "condition", device_entry.id) conditions = await async_get_device_automations(
hass, DeviceAutomationType.CONDITION, device_entry.id
)
assert_lists_same(conditions, expected_conditions) assert_lists_same(conditions, expected_conditions)

View File

@ -3,6 +3,7 @@ import pytest
from homeassistant.components import automation from homeassistant.components import automation
from homeassistant.components.NEW_DOMAIN import DOMAIN from homeassistant.components.NEW_DOMAIN import DOMAIN
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.helpers import device_registry from homeassistant.helpers import device_registry
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -60,7 +61,9 @@ async def test_get_triggers(hass, device_reg, entity_reg):
"entity_id": f"{DOMAIN}.test_5678", "entity_id": f"{DOMAIN}.test_5678",
}, },
] ]
triggers = await async_get_device_automations(hass, "trigger", device_entry.id) triggers = await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, device_entry.id
)
assert_lists_same(triggers, expected_triggers) assert_lists_same(triggers, expected_triggers)

View File

@ -69,7 +69,9 @@ CLIENT_REDIRECT_URI = "https://example.com/app/callback"
async def async_get_device_automations( async def async_get_device_automations(
hass: HomeAssistant, automation_type: str, device_id: str hass: HomeAssistant,
automation_type: device_automation.DeviceAutomationType | str,
device_id: str,
) -> Any: ) -> Any:
"""Get a device automation for a single device id.""" """Get a device automation for a single device id."""
automations = await device_automation.async_get_device_automations( automations = await device_automation.async_get_device_automations(

View File

@ -391,6 +391,13 @@ async def test_async_get_device_automations_single_device_trigger(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
) )
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER, [device_entry.id]
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
# Test deprecated str automation_type works, to be removed in 2022.4
result = await device_automation.async_get_device_automations( result = await device_automation.async_get_device_automations(
hass, "trigger", [device_entry.id] hass, "trigger", [device_entry.id]
) )
@ -410,7 +417,9 @@ async def test_async_get_device_automations_all_devices_trigger(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
) )
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "trigger") result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER
)
assert device_entry.id in result assert device_entry.id in result
assert len(result[device_entry.id]) == 2 assert len(result[device_entry.id]) == 2
@ -427,7 +436,9 @@ async def test_async_get_device_automations_all_devices_condition(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
) )
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "condition") result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.CONDITION
)
assert device_entry.id in result assert device_entry.id in result
assert len(result[device_entry.id]) == 2 assert len(result[device_entry.id]) == 2
@ -444,7 +455,9 @@ async def test_async_get_device_automations_all_devices_action(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
) )
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "action") result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.ACTION
)
assert device_entry.id in result assert device_entry.id in result
assert len(result[device_entry.id]) == 3 assert len(result[device_entry.id]) == 3
@ -465,7 +478,9 @@ async def test_async_get_device_automations_all_devices_action_exception_throw(
"homeassistant.components.light.device_trigger.async_get_triggers", "homeassistant.components.light.device_trigger.async_get_triggers",
side_effect=KeyError, side_effect=KeyError,
): ):
result = await device_automation.async_get_device_automations(hass, "trigger") result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER
)
assert device_entry.id in result assert device_entry.id in result
assert len(result[device_entry.id]) == 0 assert len(result[device_entry.id]) == 0
assert "KeyError" in caplog.text assert "KeyError" in caplog.text

View File

@ -16,7 +16,9 @@ async def test_get_actions(hass, push_registration):
] ]
capabilitites = await device_automation._async_get_device_automation_capabilities( capabilitites = await device_automation._async_get_device_automation_capabilities(
hass, "action", {"domain": DOMAIN, "device_id": device_id, "type": "notify"} hass,
device_automation.DeviceAutomationType.ACTION,
{"domain": DOMAIN, "device_id": device_id, "type": "notify"},
) )
assert "extra_fields" in capabilitites assert "extra_fields" in capabilitites