mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Add api to device_automation to return all matching devices (#53361)
This commit is contained in:
parent
ac29571db3
commit
4bde4504ec
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import Iterable, Mapping
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -13,9 +13,12 @@ import voluptuous_serialize
|
|||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
|
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import (
|
||||||
from homeassistant.helpers.entity_registry import async_entries_for_device
|
config_validation as cv,
|
||||||
from homeassistant.loader import IntegrationNotFound
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
|
)
|
||||||
|
from homeassistant.loader import IntegrationNotFound, bind_hass
|
||||||
from homeassistant.requirements import async_get_integration_with_requirements
|
from homeassistant.requirements import async_get_integration_with_requirements
|
||||||
|
|
||||||
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
|
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
|
||||||
@ -49,6 +52,16 @@ TYPES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@bind_hass
|
||||||
|
async def async_get_device_automations(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
automation_type: str,
|
||||||
|
device_ids: Iterable[str] | None = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""Return all the device automations for a type optionally limited to specific device ids."""
|
||||||
|
return await _async_get_device_automations(hass, automation_type, device_ids)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Set up device automation."""
|
"""Set up device automation."""
|
||||||
hass.components.websocket_api.async_register_command(
|
hass.components.websocket_api.async_register_command(
|
||||||
@ -96,7 +109,7 @@ async def async_get_device_automation_platform(
|
|||||||
|
|
||||||
|
|
||||||
async def _async_get_device_automations_from_domain(
|
async def _async_get_device_automations_from_domain(
|
||||||
hass, domain, automation_type, device_id
|
hass, domain, automation_type, device_ids, return_exceptions
|
||||||
):
|
):
|
||||||
"""List device automations."""
|
"""List device automations."""
|
||||||
try:
|
try:
|
||||||
@ -104,48 +117,67 @@ async def _async_get_device_automations_from_domain(
|
|||||||
hass, domain, automation_type
|
hass, domain, automation_type
|
||||||
)
|
)
|
||||||
except InvalidDeviceAutomationConfig:
|
except InvalidDeviceAutomationConfig:
|
||||||
return None
|
return {}
|
||||||
|
|
||||||
function_name = TYPES[automation_type][1]
|
function_name = TYPES[automation_type][1]
|
||||||
|
|
||||||
return await getattr(platform, function_name)(hass, device_id)
|
return await asyncio.gather(
|
||||||
|
*(
|
||||||
|
getattr(platform, function_name)(hass, device_id)
|
||||||
async def _async_get_device_automations(hass, automation_type, device_id):
|
for device_id in device_ids
|
||||||
"""List device automations."""
|
),
|
||||||
device_registry, entity_registry = await asyncio.gather(
|
return_exceptions=return_exceptions,
|
||||||
hass.helpers.device_registry.async_get_registry(),
|
|
||||||
hass.helpers.entity_registry.async_get_registry(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
domains = set()
|
|
||||||
automations: list[MutableMapping[str, Any]] = []
|
|
||||||
device = device_registry.async_get(device_id)
|
|
||||||
|
|
||||||
if device is None:
|
async def _async_get_device_automations(
|
||||||
raise DeviceNotFound
|
hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None
|
||||||
|
) -> Mapping[str, list[dict[str, Any]]]:
|
||||||
|
"""List device automations."""
|
||||||
|
device_registry = dr.async_get(hass)
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
domain_devices: dict[str, set[str]] = {}
|
||||||
|
device_entities_domains: dict[str, set[str]] = {}
|
||||||
|
match_device_ids = set(device_ids or device_registry.devices)
|
||||||
|
combined_results: dict[str, list[dict[str, Any]]] = {}
|
||||||
|
|
||||||
for entry_id in device.config_entries:
|
for entry in entity_registry.entities.values():
|
||||||
config_entry = hass.config_entries.async_get_entry(entry_id)
|
if not entry.disabled_by and entry.device_id in match_device_ids:
|
||||||
domains.add(config_entry.domain)
|
device_entities_domains.setdefault(entry.device_id, set()).add(entry.domain)
|
||||||
|
|
||||||
entity_entries = async_entries_for_device(entity_registry, device_id)
|
for device_id in match_device_ids:
|
||||||
for entity_entry in entity_entries:
|
combined_results[device_id] = []
|
||||||
domains.add(entity_entry.domain)
|
device = device_registry.async_get(device_id)
|
||||||
|
if device is None:
|
||||||
|
raise DeviceNotFound
|
||||||
|
for entry_id in device.config_entries:
|
||||||
|
if config_entry := hass.config_entries.async_get_entry(entry_id):
|
||||||
|
domain_devices.setdefault(config_entry.domain, set()).add(device_id)
|
||||||
|
for domain in device_entities_domains.get(device_id, []):
|
||||||
|
domain_devices.setdefault(domain, set()).add(device_id)
|
||||||
|
|
||||||
device_automations = await asyncio.gather(
|
# If specific device ids were requested, we allow
|
||||||
|
# InvalidDeviceAutomationConfig to be thrown, otherwise we skip
|
||||||
|
# devices that do not have valid triggers
|
||||||
|
return_exceptions = not bool(device_ids)
|
||||||
|
|
||||||
|
for domain_results in await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
_async_get_device_automations_from_domain(
|
_async_get_device_automations_from_domain(
|
||||||
hass, domain, automation_type, device_id
|
hass, domain, automation_type, domain_device_ids, return_exceptions
|
||||||
)
|
)
|
||||||
for domain in domains
|
for domain, domain_device_ids in domain_devices.items()
|
||||||
)
|
)
|
||||||
)
|
):
|
||||||
for device_automation in device_automations:
|
for device_results in domain_results:
|
||||||
if device_automation is not None:
|
if device_results is None or isinstance(
|
||||||
automations.extend(device_automation)
|
device_results, InvalidDeviceAutomationConfig
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
for automation in device_results:
|
||||||
|
combined_results[automation["device_id"]].append(automation)
|
||||||
|
|
||||||
return automations
|
return combined_results
|
||||||
|
|
||||||
|
|
||||||
async def _async_get_device_automation_capabilities(hass, automation_type, automation):
|
async def _async_get_device_automation_capabilities(hass, automation_type, automation):
|
||||||
@ -207,7 +239,9 @@ def handle_device_errors(func):
|
|||||||
async def websocket_device_automation_list_actions(hass, connection, msg):
|
async def websocket_device_automation_list_actions(hass, connection, msg):
|
||||||
"""Handle request for device actions."""
|
"""Handle request for device actions."""
|
||||||
device_id = msg["device_id"]
|
device_id = msg["device_id"]
|
||||||
actions = await _async_get_device_automations(hass, "action", device_id)
|
actions = (await _async_get_device_automations(hass, "action", [device_id])).get(
|
||||||
|
device_id
|
||||||
|
)
|
||||||
connection.send_result(msg["id"], actions)
|
connection.send_result(msg["id"], actions)
|
||||||
|
|
||||||
|
|
||||||
@ -222,7 +256,9 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
|
|||||||
async def websocket_device_automation_list_conditions(hass, connection, msg):
|
async def websocket_device_automation_list_conditions(hass, connection, msg):
|
||||||
"""Handle request for device conditions."""
|
"""Handle request for device conditions."""
|
||||||
device_id = msg["device_id"]
|
device_id = msg["device_id"]
|
||||||
conditions = await _async_get_device_automations(hass, "condition", device_id)
|
conditions = (
|
||||||
|
await _async_get_device_automations(hass, "condition", [device_id])
|
||||||
|
).get(device_id)
|
||||||
connection.send_result(msg["id"], conditions)
|
connection.send_result(msg["id"], conditions)
|
||||||
|
|
||||||
|
|
||||||
@ -237,7 +273,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
|
|||||||
async def websocket_device_automation_list_triggers(hass, connection, msg):
|
async def websocket_device_automation_list_triggers(hass, connection, msg):
|
||||||
"""Handle request for device triggers."""
|
"""Handle request for device triggers."""
|
||||||
device_id = msg["device_id"]
|
device_id = msg["device_id"]
|
||||||
triggers = await _async_get_device_automations(hass, "trigger", device_id)
|
triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get(
|
||||||
|
device_id
|
||||||
|
)
|
||||||
connection.send_result(msg["id"], triggers)
|
connection.send_result(msg["id"], triggers)
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,10 +29,9 @@ from homeassistant.auth import (
|
|||||||
providers as auth_providers,
|
providers as auth_providers,
|
||||||
)
|
)
|
||||||
from homeassistant.auth.permissions import system_policies
|
from homeassistant.auth.permissions import system_policies
|
||||||
from homeassistant.components import recorder
|
from homeassistant.components import device_automation, recorder
|
||||||
from homeassistant.components.device_automation import ( # noqa: F401
|
from homeassistant.components.device_automation import ( # noqa: F401
|
||||||
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
|
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
|
||||||
_async_get_device_automations as async_get_device_automations,
|
|
||||||
)
|
)
|
||||||
from homeassistant.components.mqtt.models import ReceiveMessage
|
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||||
from homeassistant.config import async_process_component_config
|
from homeassistant.config import async_process_component_config
|
||||||
@ -69,6 +68,16 @@ CLIENT_ID = "https://example.com/app"
|
|||||||
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
|
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_device_automations(
|
||||||
|
hass: HomeAssistant, automation_type: str, device_id: str
|
||||||
|
) -> Any:
|
||||||
|
"""Get a device automation for a single device id."""
|
||||||
|
automations = await device_automation.async_get_device_automations(
|
||||||
|
hass, automation_type, [device_id]
|
||||||
|
)
|
||||||
|
return automations.get(device_id)
|
||||||
|
|
||||||
|
|
||||||
def threadsafe_callback_factory(func):
|
def threadsafe_callback_factory(func):
|
||||||
"""Create threadsafe functions out of callbacks.
|
"""Create threadsafe functions out of callbacks.
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""The test for light device automation."""
|
"""The test for light device automation."""
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import device_automation
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||||
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
||||||
@ -372,6 +373,76 @@ async def test_websocket_get_no_condition_capabilities(
|
|||||||
assert capabilities == expected_capabilities
|
assert capabilities == expected_capabilities
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_device_automations_single_device_trigger(
|
||||||
|
hass, device_reg, entity_reg
|
||||||
|
):
|
||||||
|
"""Test we get can fetch the triggers for a device id."""
|
||||||
|
await async_setup_component(hass, "device_automation", {})
|
||||||
|
config_entry = MockConfigEntry(domain="test", data={})
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
device_entry = device_reg.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
)
|
||||||
|
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||||
|
result = await device_automation.async_get_device_automations(
|
||||||
|
hass, "trigger", [device_entry.id]
|
||||||
|
)
|
||||||
|
assert device_entry.id in result
|
||||||
|
assert len(result[device_entry.id]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_device_automations_all_devices_trigger(
|
||||||
|
hass, device_reg, entity_reg
|
||||||
|
):
|
||||||
|
"""Test we get can fetch all the triggers when no device id is passed."""
|
||||||
|
await async_setup_component(hass, "device_automation", {})
|
||||||
|
config_entry = MockConfigEntry(domain="test", data={})
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
device_entry = device_reg.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
)
|
||||||
|
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||||
|
result = await device_automation.async_get_device_automations(hass, "trigger")
|
||||||
|
assert device_entry.id in result
|
||||||
|
assert len(result[device_entry.id]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_device_automations_all_devices_condition(
|
||||||
|
hass, device_reg, entity_reg
|
||||||
|
):
|
||||||
|
"""Test we get can fetch all the conditions when no device id is passed."""
|
||||||
|
await async_setup_component(hass, "device_automation", {})
|
||||||
|
config_entry = MockConfigEntry(domain="test", data={})
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
device_entry = device_reg.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
)
|
||||||
|
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||||
|
result = await device_automation.async_get_device_automations(hass, "condition")
|
||||||
|
assert device_entry.id in result
|
||||||
|
assert len(result[device_entry.id]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_device_automations_all_devices_action(
|
||||||
|
hass, device_reg, entity_reg
|
||||||
|
):
|
||||||
|
"""Test we get can fetch all the actions when no device id is passed."""
|
||||||
|
await async_setup_component(hass, "device_automation", {})
|
||||||
|
config_entry = MockConfigEntry(domain="test", data={})
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
device_entry = device_reg.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||||
|
)
|
||||||
|
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||||
|
result = await device_automation.async_get_device_automations(hass, "action")
|
||||||
|
assert device_entry.id in result
|
||||||
|
assert len(result[device_entry.id]) == 3
|
||||||
|
|
||||||
|
|
||||||
async def test_websocket_get_trigger_capabilities(
|
async def test_websocket_get_trigger_capabilities(
|
||||||
hass, hass_ws_client, device_reg, entity_reg
|
hass, hass_ws_client, device_reg, entity_reg
|
||||||
):
|
):
|
||||||
|
@ -2,9 +2,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
from homeassistant.components.device_automation import (
|
|
||||||
_async_get_device_automations as async_get_device_automations,
|
|
||||||
)
|
|
||||||
from homeassistant.components.remote import DOMAIN
|
from homeassistant.components.remote import DOMAIN
|
||||||
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
||||||
from homeassistant.helpers import device_registry
|
from homeassistant.helpers import device_registry
|
||||||
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
|
|||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockConfigEntry,
|
MockConfigEntry,
|
||||||
|
async_get_device_automations,
|
||||||
async_mock_service,
|
async_mock_service,
|
||||||
mock_device_registry,
|
mock_device_registry,
|
||||||
mock_registry,
|
mock_registry,
|
||||||
|
@ -2,9 +2,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
from homeassistant.components.device_automation import (
|
|
||||||
_async_get_device_automations as async_get_device_automations,
|
|
||||||
)
|
|
||||||
from homeassistant.components.switch import DOMAIN
|
from homeassistant.components.switch import DOMAIN
|
||||||
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
|
||||||
from homeassistant.helpers import device_registry
|
from homeassistant.helpers import device_registry
|
||||||
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
|
|||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockConfigEntry,
|
MockConfigEntry,
|
||||||
|
async_get_device_automations,
|
||||||
async_mock_service,
|
async_mock_service,
|
||||||
mock_device_registry,
|
mock_device_registry,
|
||||||
mock_registry,
|
mock_registry,
|
||||||
|
@ -8,14 +8,11 @@ import zigpy.zcl.clusters.security as security
|
|||||||
import zigpy.zcl.foundation as zcl_f
|
import zigpy.zcl.foundation as zcl_f
|
||||||
|
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
from homeassistant.components.device_automation import (
|
|
||||||
_async_get_device_automations as async_get_device_automations,
|
|
||||||
)
|
|
||||||
from homeassistant.components.zha import DOMAIN
|
from homeassistant.components.zha import DOMAIN
|
||||||
from homeassistant.helpers import device_registry as dr
|
from homeassistant.helpers import device_registry as dr
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import async_mock_service, mock_coro
|
from tests.common import async_get_device_automations, async_mock_service, mock_coro
|
||||||
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
|
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
|
||||||
|
|
||||||
SHORT_PRESS = "remote_button_short_press"
|
SHORT_PRESS = "remote_button_short_press"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user