Add api to device_automation to return all matching devices (#53361)

This commit is contained in:
J. Nick Koston 2021-08-10 14:21:34 -05:00 committed by GitHub
parent ac29571db3
commit 4bde4504ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 159 additions and 48 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import MutableMapping from collections.abc import Iterable, Mapping
from functools import wraps from functools import wraps
from types import ModuleType from types import ModuleType
from typing import Any from typing import Any
@ -13,9 +13,12 @@ import voluptuous_serialize
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import (
from homeassistant.helpers.entity_registry import async_entries_for_device config_validation as cv,
from homeassistant.loader import IntegrationNotFound device_registry as dr,
entity_registry as er,
)
from homeassistant.loader import IntegrationNotFound, bind_hass
from homeassistant.requirements import async_get_integration_with_requirements from homeassistant.requirements import async_get_integration_with_requirements
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
@ -49,6 +52,16 @@ TYPES = {
} }
@bind_hass
async def async_get_device_automations(
hass: HomeAssistant,
automation_type: str,
device_ids: Iterable[str] | None = None,
) -> Mapping[str, Any]:
"""Return all the device automations for a type optionally limited to specific device ids."""
return await _async_get_device_automations(hass, automation_type, device_ids)
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up device automation.""" """Set up device automation."""
hass.components.websocket_api.async_register_command( hass.components.websocket_api.async_register_command(
@ -96,7 +109,7 @@ async def async_get_device_automation_platform(
async def _async_get_device_automations_from_domain( async def _async_get_device_automations_from_domain(
hass, domain, automation_type, device_id hass, domain, automation_type, device_ids, return_exceptions
): ):
"""List device automations.""" """List device automations."""
try: try:
@ -104,48 +117,67 @@ async def _async_get_device_automations_from_domain(
hass, domain, automation_type hass, domain, automation_type
) )
except InvalidDeviceAutomationConfig: except InvalidDeviceAutomationConfig:
return None return {}
function_name = TYPES[automation_type][1] function_name = TYPES[automation_type][1]
return await getattr(platform, function_name)(hass, device_id) return await asyncio.gather(
*(
getattr(platform, function_name)(hass, device_id)
async def _async_get_device_automations(hass, automation_type, device_id): for device_id in device_ids
"""List device automations.""" ),
device_registry, entity_registry = await asyncio.gather( return_exceptions=return_exceptions,
hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(),
) )
domains = set()
automations: list[MutableMapping[str, Any]] = []
device = device_registry.async_get(device_id)
if device is None: async def _async_get_device_automations(
raise DeviceNotFound hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None
) -> Mapping[str, list[dict[str, Any]]]:
"""List device automations."""
device_registry = dr.async_get(hass)
entity_registry = er.async_get(hass)
domain_devices: dict[str, set[str]] = {}
device_entities_domains: dict[str, set[str]] = {}
match_device_ids = set(device_ids or device_registry.devices)
combined_results: dict[str, list[dict[str, Any]]] = {}
for entry_id in device.config_entries: for entry in entity_registry.entities.values():
config_entry = hass.config_entries.async_get_entry(entry_id) if not entry.disabled_by and entry.device_id in match_device_ids:
domains.add(config_entry.domain) device_entities_domains.setdefault(entry.device_id, set()).add(entry.domain)
entity_entries = async_entries_for_device(entity_registry, device_id) for device_id in match_device_ids:
for entity_entry in entity_entries: combined_results[device_id] = []
domains.add(entity_entry.domain) device = device_registry.async_get(device_id)
if device is None:
raise DeviceNotFound
for entry_id in device.config_entries:
if config_entry := hass.config_entries.async_get_entry(entry_id):
domain_devices.setdefault(config_entry.domain, set()).add(device_id)
for domain in device_entities_domains.get(device_id, []):
domain_devices.setdefault(domain, set()).add(device_id)
device_automations = await asyncio.gather( # If specific device ids were requested, we allow
# InvalidDeviceAutomationConfig to be thrown, otherwise we skip
# devices that do not have valid triggers
return_exceptions = not bool(device_ids)
for domain_results in await asyncio.gather(
*( *(
_async_get_device_automations_from_domain( _async_get_device_automations_from_domain(
hass, domain, automation_type, device_id hass, domain, automation_type, domain_device_ids, return_exceptions
) )
for domain in domains for domain, domain_device_ids in domain_devices.items()
) )
) ):
for device_automation in device_automations: for device_results in domain_results:
if device_automation is not None: if device_results is None or isinstance(
automations.extend(device_automation) device_results, InvalidDeviceAutomationConfig
):
continue
for automation in device_results:
combined_results[automation["device_id"]].append(automation)
return automations return combined_results
async def _async_get_device_automation_capabilities(hass, automation_type, automation): async def _async_get_device_automation_capabilities(hass, automation_type, automation):
@ -207,7 +239,9 @@ def handle_device_errors(func):
async def websocket_device_automation_list_actions(hass, connection, msg): async def websocket_device_automation_list_actions(hass, connection, msg):
"""Handle request for device actions.""" """Handle request for device actions."""
device_id = msg["device_id"] device_id = msg["device_id"]
actions = await _async_get_device_automations(hass, "action", device_id) actions = (await _async_get_device_automations(hass, "action", [device_id])).get(
device_id
)
connection.send_result(msg["id"], actions) connection.send_result(msg["id"], actions)
@ -222,7 +256,9 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
async def websocket_device_automation_list_conditions(hass, connection, msg): async def websocket_device_automation_list_conditions(hass, connection, msg):
"""Handle request for device conditions.""" """Handle request for device conditions."""
device_id = msg["device_id"] device_id = msg["device_id"]
conditions = await _async_get_device_automations(hass, "condition", device_id) conditions = (
await _async_get_device_automations(hass, "condition", [device_id])
).get(device_id)
connection.send_result(msg["id"], conditions) connection.send_result(msg["id"], conditions)
@ -237,7 +273,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
async def websocket_device_automation_list_triggers(hass, connection, msg): async def websocket_device_automation_list_triggers(hass, connection, msg):
"""Handle request for device triggers.""" """Handle request for device triggers."""
device_id = msg["device_id"] device_id = msg["device_id"]
triggers = await _async_get_device_automations(hass, "trigger", device_id) triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get(
device_id
)
connection.send_result(msg["id"], triggers) connection.send_result(msg["id"], triggers)

View File

@ -29,10 +29,9 @@ from homeassistant.auth import (
providers as auth_providers, providers as auth_providers,
) )
from homeassistant.auth.permissions import system_policies from homeassistant.auth.permissions import system_policies
from homeassistant.components import recorder from homeassistant.components import device_automation, recorder
from homeassistant.components.device_automation import ( # noqa: F401 from homeassistant.components.device_automation import ( # noqa: F401
_async_get_device_automation_capabilities as async_get_device_automation_capabilities, _async_get_device_automation_capabilities as async_get_device_automation_capabilities,
_async_get_device_automations as async_get_device_automations,
) )
from homeassistant.components.mqtt.models import ReceiveMessage from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
@ -69,6 +68,16 @@ CLIENT_ID = "https://example.com/app"
CLIENT_REDIRECT_URI = "https://example.com/app/callback" CLIENT_REDIRECT_URI = "https://example.com/app/callback"
async def async_get_device_automations(
hass: HomeAssistant, automation_type: str, device_id: str
) -> Any:
"""Get a device automation for a single device id."""
automations = await device_automation.async_get_device_automations(
hass, automation_type, [device_id]
)
return automations.get(device_id)
def threadsafe_callback_factory(func): def threadsafe_callback_factory(func):
"""Create threadsafe functions out of callbacks. """Create threadsafe functions out of callbacks.

View File

@ -1,6 +1,7 @@
"""The test for light device automation.""" """The test for light device automation."""
import pytest import pytest
from homeassistant.components import device_automation
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
from homeassistant.components.websocket_api.const import TYPE_RESULT from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
@ -372,6 +373,76 @@ async def test_websocket_get_no_condition_capabilities(
assert capabilities == expected_capabilities assert capabilities == expected_capabilities
async def test_async_get_device_automations_single_device_trigger(
hass, device_reg, entity_reg
):
"""Test we get can fetch the triggers for a device id."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(
hass, "trigger", [device_entry.id]
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
async def test_async_get_device_automations_all_devices_trigger(
hass, device_reg, entity_reg
):
"""Test we get can fetch all the triggers when no device id is passed."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "trigger")
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
async def test_async_get_device_automations_all_devices_condition(
hass, device_reg, entity_reg
):
"""Test we get can fetch all the conditions when no device id is passed."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "condition")
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
async def test_async_get_device_automations_all_devices_action(
hass, device_reg, entity_reg
):
"""Test we get can fetch all the actions when no device id is passed."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "action")
assert device_entry.id in result
assert len(result[device_entry.id]) == 3
async def test_websocket_get_trigger_capabilities( async def test_websocket_get_trigger_capabilities(
hass, hass_ws_client, device_reg, entity_reg hass, hass_ws_client, device_reg, entity_reg
): ):

View File

@ -2,9 +2,6 @@
import pytest import pytest
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
from homeassistant.components.device_automation import (
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.remote import DOMAIN from homeassistant.components.remote import DOMAIN
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
from homeassistant.helpers import device_registry from homeassistant.helpers import device_registry
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
async_get_device_automations,
async_mock_service, async_mock_service,
mock_device_registry, mock_device_registry,
mock_registry, mock_registry,

View File

@ -2,9 +2,6 @@
import pytest import pytest
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
from homeassistant.components.device_automation import (
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.switch import DOMAIN from homeassistant.components.switch import DOMAIN
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
from homeassistant.helpers import device_registry from homeassistant.helpers import device_registry
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
async_get_device_automations,
async_mock_service, async_mock_service,
mock_device_registry, mock_device_registry,
mock_registry, mock_registry,

View File

@ -8,14 +8,11 @@ import zigpy.zcl.clusters.security as security
import zigpy.zcl.foundation as zcl_f import zigpy.zcl.foundation as zcl_f
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
from homeassistant.components.device_automation import (
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.zha import DOMAIN from homeassistant.components.zha import DOMAIN
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import async_mock_service, mock_coro from tests.common import async_get_device_automations, async_mock_service, mock_coro
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
SHORT_PRESS = "remote_button_short_press" SHORT_PRESS = "remote_button_short_press"