mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Add api to device_automation to return all matching devices (#53361)
This commit is contained in:
parent
ac29571db3
commit
4bde4504ec
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import MutableMapping
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import wraps
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
@ -13,9 +13,12 @@ import voluptuous_serialize
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_registry import async_entries_for_device
|
||||
from homeassistant.loader import IntegrationNotFound
|
||||
from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
)
|
||||
from homeassistant.loader import IntegrationNotFound, bind_hass
|
||||
from homeassistant.requirements import async_get_integration_with_requirements
|
||||
|
||||
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
|
||||
@ -49,6 +52,16 @@ TYPES = {
|
||||
}
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_device_automations(
|
||||
hass: HomeAssistant,
|
||||
automation_type: str,
|
||||
device_ids: Iterable[str] | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""Return all the device automations for a type optionally limited to specific device ids."""
|
||||
return await _async_get_device_automations(hass, automation_type, device_ids)
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up device automation."""
|
||||
hass.components.websocket_api.async_register_command(
|
||||
@ -96,7 +109,7 @@ async def async_get_device_automation_platform(
|
||||
|
||||
|
||||
async def _async_get_device_automations_from_domain(
|
||||
hass, domain, automation_type, device_id
|
||||
hass, domain, automation_type, device_ids, return_exceptions
|
||||
):
|
||||
"""List device automations."""
|
||||
try:
|
||||
@ -104,48 +117,67 @@ async def _async_get_device_automations_from_domain(
|
||||
hass, domain, automation_type
|
||||
)
|
||||
except InvalidDeviceAutomationConfig:
|
||||
return None
|
||||
return {}
|
||||
|
||||
function_name = TYPES[automation_type][1]
|
||||
|
||||
return await getattr(platform, function_name)(hass, device_id)
|
||||
|
||||
|
||||
async def _async_get_device_automations(hass, automation_type, device_id):
|
||||
"""List device automations."""
|
||||
device_registry, entity_registry = await asyncio.gather(
|
||||
hass.helpers.device_registry.async_get_registry(),
|
||||
hass.helpers.entity_registry.async_get_registry(),
|
||||
return await asyncio.gather(
|
||||
*(
|
||||
getattr(platform, function_name)(hass, device_id)
|
||||
for device_id in device_ids
|
||||
),
|
||||
return_exceptions=return_exceptions,
|
||||
)
|
||||
|
||||
domains = set()
|
||||
automations: list[MutableMapping[str, Any]] = []
|
||||
device = device_registry.async_get(device_id)
|
||||
|
||||
if device is None:
|
||||
raise DeviceNotFound
|
||||
async def _async_get_device_automations(
|
||||
hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None
|
||||
) -> Mapping[str, list[dict[str, Any]]]:
|
||||
"""List device automations."""
|
||||
device_registry = dr.async_get(hass)
|
||||
entity_registry = er.async_get(hass)
|
||||
domain_devices: dict[str, set[str]] = {}
|
||||
device_entities_domains: dict[str, set[str]] = {}
|
||||
match_device_ids = set(device_ids or device_registry.devices)
|
||||
combined_results: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
for entry_id in device.config_entries:
|
||||
config_entry = hass.config_entries.async_get_entry(entry_id)
|
||||
domains.add(config_entry.domain)
|
||||
for entry in entity_registry.entities.values():
|
||||
if not entry.disabled_by and entry.device_id in match_device_ids:
|
||||
device_entities_domains.setdefault(entry.device_id, set()).add(entry.domain)
|
||||
|
||||
entity_entries = async_entries_for_device(entity_registry, device_id)
|
||||
for entity_entry in entity_entries:
|
||||
domains.add(entity_entry.domain)
|
||||
for device_id in match_device_ids:
|
||||
combined_results[device_id] = []
|
||||
device = device_registry.async_get(device_id)
|
||||
if device is None:
|
||||
raise DeviceNotFound
|
||||
for entry_id in device.config_entries:
|
||||
if config_entry := hass.config_entries.async_get_entry(entry_id):
|
||||
domain_devices.setdefault(config_entry.domain, set()).add(device_id)
|
||||
for domain in device_entities_domains.get(device_id, []):
|
||||
domain_devices.setdefault(domain, set()).add(device_id)
|
||||
|
||||
device_automations = await asyncio.gather(
|
||||
# If specific device ids were requested, we allow
|
||||
# InvalidDeviceAutomationConfig to be thrown, otherwise we skip
|
||||
# devices that do not have valid triggers
|
||||
return_exceptions = not bool(device_ids)
|
||||
|
||||
for domain_results in await asyncio.gather(
|
||||
*(
|
||||
_async_get_device_automations_from_domain(
|
||||
hass, domain, automation_type, device_id
|
||||
hass, domain, automation_type, domain_device_ids, return_exceptions
|
||||
)
|
||||
for domain in domains
|
||||
for domain, domain_device_ids in domain_devices.items()
|
||||
)
|
||||
)
|
||||
for device_automation in device_automations:
|
||||
if device_automation is not None:
|
||||
automations.extend(device_automation)
|
||||
):
|
||||
for device_results in domain_results:
|
||||
if device_results is None or isinstance(
|
||||
device_results, InvalidDeviceAutomationConfig
|
||||
):
|
||||
continue
|
||||
for automation in device_results:
|
||||
combined_results[automation["device_id"]].append(automation)
|
||||
|
||||
return automations
|
||||
return combined_results
|
||||
|
||||
|
||||
async def _async_get_device_automation_capabilities(hass, automation_type, automation):
|
||||
@ -207,7 +239,9 @@ def handle_device_errors(func):
|
||||
async def websocket_device_automation_list_actions(hass, connection, msg):
|
||||
"""Handle request for device actions."""
|
||||
device_id = msg["device_id"]
|
||||
actions = await _async_get_device_automations(hass, "action", device_id)
|
||||
actions = (await _async_get_device_automations(hass, "action", [device_id])).get(
|
||||
device_id
|
||||
)
|
||||
connection.send_result(msg["id"], actions)
|
||||
|
||||
|
||||
@ -222,7 +256,9 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
|
||||
async def websocket_device_automation_list_conditions(hass, connection, msg):
|
||||
"""Handle request for device conditions."""
|
||||
device_id = msg["device_id"]
|
||||
conditions = await _async_get_device_automations(hass, "condition", device_id)
|
||||
conditions = (
|
||||
await _async_get_device_automations(hass, "condition", [device_id])
|
||||
).get(device_id)
|
||||
connection.send_result(msg["id"], conditions)
|
||||
|
||||
|
||||
@ -237,7 +273,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
|
||||
async def websocket_device_automation_list_triggers(hass, connection, msg):
|
||||
"""Handle request for device triggers."""
|
||||
device_id = msg["device_id"]
|
||||
triggers = await _async_get_device_automations(hass, "trigger", device_id)
|
||||
triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get(
|
||||
device_id
|
||||
)
|
||||
connection.send_result(msg["id"], triggers)
|
||||
|
||||
|
||||
|
@ -29,10 +29,9 @@ from homeassistant.auth import (
|
||||
providers as auth_providers,
|
||||
)
|
||||
from homeassistant.auth.permissions import system_policies
|
||||
from homeassistant.components import recorder
|
||||
from homeassistant.components import device_automation, recorder
|
||||
from homeassistant.components.device_automation import ( # noqa: F401
|
||||
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
|
||||
_async_get_device_automations as async_get_device_automations,
|
||||
)
|
||||
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||
from homeassistant.config import async_process_component_config
|
||||
@ -69,6 +68,16 @@ CLIENT_ID = "https://example.com/app"
|
||||
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
|
||||
|
||||
|
||||
async def async_get_device_automations(
|
||||
hass: HomeAssistant, automation_type: str, device_id: str
|
||||
) -> Any:
|
||||
"""Get a device automation for a single device id."""
|
||||
automations = await device_automation.async_get_device_automations(
|
||||
hass, automation_type, [device_id]
|
||||
)
|
||||
return automations.get(device_id)
|
||||
|
||||
|
||||
def threadsafe_callback_factory(func):
|
||||
"""Create threadsafe functions out of callbacks.
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""The test for light device automation."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import device_automation
|
||||
import homeassistant.components.automation as automation
|
||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
||||
@ -372,6 +373,76 @@ async def test_websocket_get_no_condition_capabilities(
|
||||
assert capabilities == expected_capabilities
|
||||
|
||||
|
||||
async def test_async_get_device_automations_single_device_trigger(
|
||||
hass, device_reg, entity_reg
|
||||
):
|
||||
"""Test we get can fetch the triggers for a device id."""
|
||||
await async_setup_component(hass, "device_automation", {})
|
||||
config_entry = MockConfigEntry(domain="test", data={})
|
||||
config_entry.add_to_hass(hass)
|
||||
device_entry = device_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(
|
||||
hass, "trigger", [device_entry.id]
|
||||
)
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 2
|
||||
|
||||
|
||||
async def test_async_get_device_automations_all_devices_trigger(
|
||||
hass, device_reg, entity_reg
|
||||
):
|
||||
"""Test we get can fetch all the triggers when no device id is passed."""
|
||||
await async_setup_component(hass, "device_automation", {})
|
||||
config_entry = MockConfigEntry(domain="test", data={})
|
||||
config_entry.add_to_hass(hass)
|
||||
device_entry = device_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(hass, "trigger")
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 2
|
||||
|
||||
|
||||
async def test_async_get_device_automations_all_devices_condition(
|
||||
hass, device_reg, entity_reg
|
||||
):
|
||||
"""Test we get can fetch all the conditions when no device id is passed."""
|
||||
await async_setup_component(hass, "device_automation", {})
|
||||
config_entry = MockConfigEntry(domain="test", data={})
|
||||
config_entry.add_to_hass(hass)
|
||||
device_entry = device_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(hass, "condition")
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 2
|
||||
|
||||
|
||||
async def test_async_get_device_automations_all_devices_action(
|
||||
hass, device_reg, entity_reg
|
||||
):
|
||||
"""Test we get can fetch all the actions when no device id is passed."""
|
||||
await async_setup_component(hass, "device_automation", {})
|
||||
config_entry = MockConfigEntry(domain="test", data={})
|
||||
config_entry.add_to_hass(hass)
|
||||
device_entry = device_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(hass, "action")
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 3
|
||||
|
||||
|
||||
async def test_websocket_get_trigger_capabilities(
|
||||
hass, hass_ws_client, device_reg, entity_reg
|
||||
):
|
||||
|
@ -2,9 +2,6 @@
|
||||
import pytest
|
||||
|
||||
import homeassistant.components.automation as automation
|
||||
from homeassistant.components.device_automation import (
|
||||
_async_get_device_automations as async_get_device_automations,
|
||||
)
|
||||
from homeassistant.components.remote import DOMAIN
|
||||
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
||||
from homeassistant.helpers import device_registry
|
||||
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
async_get_device_automations,
|
||||
async_mock_service,
|
||||
mock_device_registry,
|
||||
mock_registry,
|
||||
|
@ -2,9 +2,6 @@
|
||||
import pytest
|
||||
|
||||
import homeassistant.components.automation as automation
|
||||
from homeassistant.components.device_automation import (
|
||||
_async_get_device_automations as async_get_device_automations,
|
||||
)
|
||||
from homeassistant.components.switch import DOMAIN
|
||||
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
||||
from homeassistant.helpers import device_registry
|
||||
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
async_get_device_automations,
|
||||
async_mock_service,
|
||||
mock_device_registry,
|
||||
mock_registry,
|
||||
|
@ -8,14 +8,11 @@ import zigpy.zcl.clusters.security as security
|
||||
import zigpy.zcl.foundation as zcl_f
|
||||
|
||||
import homeassistant.components.automation as automation
|
||||
from homeassistant.components.device_automation import (
|
||||
_async_get_device_automations as async_get_device_automations,
|
||||
)
|
||||
from homeassistant.components.zha import DOMAIN
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service, mock_coro
|
||||
from tests.common import async_get_device_automations, async_mock_service, mock_coro
|
||||
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
|
||||
|
||||
SHORT_PRESS = "remote_button_short_press"
|
||||
|
Loading…
x
Reference in New Issue
Block a user