Allow targeting areas in service calls (#21472)

* Allow targeting areas in service calls

* Lint + Type

* Address comments
This commit is contained in:
Paulus Schoutsen 2019-03-04 09:51:12 -08:00 committed by GitHub
parent f62eb22ef8
commit 8213016eaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 180 additions and 62 deletions

View File

@ -17,7 +17,7 @@ import voluptuous as vol
import homeassistant.core as ha import homeassistant.core as ha
import homeassistant.config as conf_util import homeassistant.config as conf_util
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.service import extract_entity_ids from homeassistant.helpers.service import async_extract_entity_ids
from homeassistant.helpers import intent from homeassistant.helpers import intent
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, ATTR_ENTITY_ID, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE,
@ -70,7 +70,7 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> Awaitable[bool]:
"""Set up general services related to Home Assistant.""" """Set up general services related to Home Assistant."""
async def async_handle_turn_service(service): async def async_handle_turn_service(service):
"""Handle calls to homeassistant.turn_on/off.""" """Handle calls to homeassistant.turn_on/off."""
entity_ids = extract_entity_ids(hass, service) entity_ids = await async_extract_entity_ids(hass, service)
# Generic turn on/off method requires entity id # Generic turn on/off method requires entity id
if not entity_ids: if not entity_ids:

View File

@ -89,7 +89,7 @@ async def async_setup(hass, config):
async def async_handle_alert_service(service_call): async def async_handle_alert_service(service_call):
"""Handle calls to alert services.""" """Handle calls to alert services."""
alert_ids = service.extract_entity_ids(hass, service_call) alert_ids = await service.async_extract_entity_ids(hass, service_call)
for alert_id in alert_ids: for alert_id in alert_ids:
for alert in entities: for alert in entities:

View File

@ -120,7 +120,7 @@ async def async_setup(hass, config):
async def trigger_service_handler(service_call): async def trigger_service_handler(service_call):
"""Handle automation triggers.""" """Handle automation triggers."""
tasks = [] tasks = []
for entity in component.async_extract_from_service(service_call): for entity in await component.async_extract_from_service(service_call):
tasks.append(entity.async_trigger( tasks.append(entity.async_trigger(
service_call.data.get(ATTR_VARIABLES), service_call.data.get(ATTR_VARIABLES),
skip_condition=True, skip_condition=True,
@ -133,7 +133,7 @@ async def async_setup(hass, config):
"""Handle automation turn on/off service calls.""" """Handle automation turn on/off service calls."""
tasks = [] tasks = []
method = 'async_{}'.format(service_call.service) method = 'async_{}'.format(service_call.service)
for entity in component.async_extract_from_service(service_call): for entity in await component.async_extract_from_service(service_call):
tasks.append(getattr(entity, method)()) tasks.append(getattr(entity, method)())
if tasks: if tasks:
@ -142,7 +142,7 @@ async def async_setup(hass, config):
async def toggle_service_handler(service_call): async def toggle_service_handler(service_call):
"""Handle automation toggle service calls.""" """Handle automation toggle service calls."""
tasks = [] tasks = []
for entity in component.async_extract_from_service(service_call): for entity in await component.async_extract_from_service(service_call):
if entity.is_on: if entity.is_on:
tasks.append(entity.async_turn_off()) tasks.append(entity.async_turn_off())
else: else:

View File

@ -300,8 +300,8 @@ async def async_setup(hass, config):
visible = service.data.get(ATTR_VISIBLE) visible = service.data.get(ATTR_VISIBLE)
tasks = [] tasks = []
for group in component.async_extract_from_service(service, for group in await component.async_extract_from_service(
expand_group=False): service, expand_group=False):
group.visible = visible group.visible = visible
tasks.append(group.async_update_ha_state()) tasks.append(group.async_update_ha_state())

View File

@ -75,7 +75,7 @@ async def async_setup(hass, config):
async def async_scan_service(service): async def async_scan_service(service):
"""Service handler for scan.""" """Service handler for scan."""
image_entities = component.async_extract_from_service(service) image_entities = await component.async_extract_from_service(service)
update_tasks = [] update_tasks = []
for entity in image_entities: for entity in image_entities:

View File

@ -256,7 +256,7 @@ async def async_setup(hass, config):
params = service.data.copy() params = service.data.copy()
# Convert the entity ids to valid light ids # Convert the entity ids to valid light ids
target_lights = component.async_extract_from_service(service) target_lights = await component.async_extract_from_service(service)
params.pop(ATTR_ENTITY_ID, None) params.pop(ATTR_ENTITY_ID, None)
if service.context.user_id: if service.context.user_id:

View File

@ -68,7 +68,7 @@ async def async_setup(hass, config):
async def async_handle_scene_service(service): async def async_handle_scene_service(service):
"""Handle calls to the switch services.""" """Handle calls to the switch services."""
target_scenes = component.async_extract_from_service(service) target_scenes = await component.async_extract_from_service(service)
tasks = [scene.async_activate() for scene in target_scenes] tasks = [scene.async_activate() for scene in target_scenes]
if tasks: if tasks:

View File

@ -74,20 +74,21 @@ async def async_setup(hass, config):
# We could turn on script directly here, but we only want to offer # We could turn on script directly here, but we only want to offer
# one way to do it. Otherwise no easy way to detect invocations. # one way to do it. Otherwise no easy way to detect invocations.
var = service.data.get(ATTR_VARIABLES) var = service.data.get(ATTR_VARIABLES)
for script in component.async_extract_from_service(service): for script in await component.async_extract_from_service(service):
await hass.services.async_call(DOMAIN, script.object_id, var, await hass.services.async_call(DOMAIN, script.object_id, var,
context=service.context) context=service.context)
async def turn_off_service(service): async def turn_off_service(service):
"""Cancel a script.""" """Cancel a script."""
# Stopping a script is ok to be done in parallel # Stopping a script is ok to be done in parallel
await asyncio.wait( await asyncio.wait([
[script.async_turn_off() for script script.async_turn_off() for script
in component.async_extract_from_service(service)], loop=hass.loop) in await component.async_extract_from_service(service)
], loop=hass.loop)
async def toggle_service(service): async def toggle_service(service):
"""Toggle a script.""" """Toggle a script."""
for script in component.async_extract_from_service(service): for script in await component.async_extract_from_service(service):
await script.async_toggle(context=service.context) await script.async_toggle(context=service.context)
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service, hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,

View File

@ -245,6 +245,9 @@ ATTR_NAME = 'name'
# Contains one string or a list of strings, each being an entity id # Contains one string or a list of strings, each being an entity id
ATTR_ENTITY_ID = 'entity_id' ATTR_ENTITY_ID = 'entity_id'
# Contains one string or a list of strings, each being an area id
ATTR_AREA_ID = 'area_id'
# String with a friendly name for the entity # String with a friendly name for the entity
ATTR_FRIENDLY_NAME = 'friendly_name' ATTR_FRIENDLY_NAME = 'friendly_name'

View File

@ -37,6 +37,11 @@ class AreaRegistry:
self.areas = {} # type: MutableMapping[str, AreaEntry] self.areas = {} # type: MutableMapping[str, AreaEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback
def async_get_area(self, area_id: str) -> Optional[AreaEntry]:
"""Get all areas."""
return self.areas.get(area_id)
@callback @callback
def async_list_areas(self) -> Iterable[AreaEntry]: def async_list_areas(self) -> Iterable[AreaEntry]:
"""Get all areas.""" """Get all areas."""

View File

@ -1,6 +1,7 @@
"""Provide a way to connect entities belonging to one device.""" """Provide a way to connect entities belonging to one device."""
import logging import logging
import uuid import uuid
from typing import List
from collections import OrderedDict from collections import OrderedDict
@ -280,3 +281,11 @@ async def async_get_registry(hass) -> DeviceRegistry:
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg()) task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
return await task return await task
@callback
def async_entries_for_area(registry: DeviceRegistry, area_id: str) \
-> List[DeviceEntry]:
"""Return entries that match an area."""
return [device for device in registry.devices.values()
if device.area_id == area_id]

View File

@ -12,7 +12,7 @@ from homeassistant.const import (
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.service import extract_entity_ids from homeassistant.helpers.service import async_extract_entity_ids
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import slugify from homeassistant.util import slugify
from .entity_platform import EntityPlatform from .entity_platform import EntityPlatform
@ -153,8 +153,7 @@ class EntityComponent:
await platform.async_reset() await platform.async_reset()
return True return True
@callback async def async_extract_from_service(self, service, expand_group=True):
def async_extract_from_service(self, service, expand_group=True):
"""Extract all known and available entities from a service call. """Extract all known and available entities from a service call.
Will return all entities if no entities specified in call. Will return all entities if no entities specified in call.
@ -174,7 +173,8 @@ class EntityComponent:
return [entity for entity in self.entities if entity.available] return [entity for entity in self.entities if entity.available]
entity_ids = set(extract_entity_ids(self.hass, service, expand_group)) entity_ids = await async_extract_entity_ids(
self.hass, service, expand_group)
return [entity for entity in self.entities return [entity for entity in self.entities
if entity.available and entity.entity_id in entity_ids] if entity.available and entity.entity_id in entity_ids]

View File

@ -10,7 +10,7 @@ timer.
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
from typing import Optional from typing import Optional, List
import weakref import weakref
import attr import attr
@ -292,6 +292,14 @@ async def async_get_registry(hass) -> EntityRegistry:
return await task return await task
@callback
def async_entries_for_device(registry: EntityRegistry, device_id: str) \
-> List[RegistryEntry]:
"""Return entries that match a device."""
return [entry for entry in registry.entities.values()
if entry.device_id == device_id]
async def _async_migrate(entities): async def _async_migrate(entities):
"""Migrate the YAML config file to storage helper format.""" """Migrate the YAML config file to storage helper format."""
return { return {

View File

@ -6,7 +6,8 @@ from os import path
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL from homeassistant.const import (
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
from homeassistant.helpers import template from homeassistant.helpers import template
@ -89,30 +90,64 @@ async def async_call_from_config(hass, config, blocking=False, variables=None,
def extract_entity_ids(hass, service_call, expand_group=True): def extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call. """Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
"""
return run_coroutine_threadsafe(
async_extract_entity_ids(hass, service_call, expand_group), hass.loop
).result()
@bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
Async friendly. Async friendly.
""" """
if not (service_call.data and ATTR_ENTITY_ID in service_call.data): entity_ids = service_call.data.get(ATTR_ENTITY_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
if not entity_ids and not area_ids:
return [] return []
group = hass.components.group extracted = set()
# Entity ID attr can be a list or a string if entity_ids:
service_ent_id = service_call.data[ATTR_ENTITY_ID] # Entity ID attr can be a list or a string
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
if expand_group: if expand_group:
entity_ids = \
hass.components.group.expand_entity_ids(entity_ids)
if isinstance(service_ent_id, str): extracted.update(entity_ids)
return group.expand_entity_ids([service_ent_id])
return [ent_id for ent_id in if area_ids:
group.expand_entity_ids(service_ent_id)] if isinstance(area_ids, str):
area_ids = [area_ids]
if isinstance(service_ent_id, str): dev_reg, ent_reg = await asyncio.gather(
return [service_ent_id] hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(),
)
devices = [
device
for area_id in area_ids
for device in
hass.helpers.device_registry.async_entries_for_area(
dev_reg, area_id)
]
extracted.update(
entry.entity_id
for device in devices
for entry in
hass.helpers.entity_registry.async_entries_for_device(
ent_reg, device.id)
)
return service_ent_id return extracted
@bind_hass @bind_hass
@ -213,8 +248,7 @@ async def entity_service_call(hass, platforms, func, call, service_name=''):
if not target_all_entities: if not target_all_entities:
# A set of entities we're trying to target. # A set of entities we're trying to target.
entity_ids = set( entity_ids = await async_extract_entity_ids(hass, call, True)
extract_entity_ids(hass, call, True))
# If the service function is a string, we'll pass it the service call data # If the service function is a string, we'll pass it the service call data
if isinstance(func, str): if isinstance(func, str):

View File

@ -206,7 +206,7 @@ def test_extract_from_service_available_device(hass):
assert ['test_domain.test_1', 'test_domain.test_3'] == \ assert ['test_domain.test_1', 'test_domain.test_3'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
component.async_extract_from_service(call_1)) (yield from component.async_extract_from_service(call_1)))
call_2 = ha.ServiceCall('test', 'service', data={ call_2 = ha.ServiceCall('test', 'service', data={
'entity_id': ['test_domain.test_3', 'test_domain.test_4'], 'entity_id': ['test_domain.test_3', 'test_domain.test_4'],
@ -214,7 +214,7 @@ def test_extract_from_service_available_device(hass):
assert ['test_domain.test_3'] == \ assert ['test_domain.test_3'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
component.async_extract_from_service(call_2)) (yield from component.async_extract_from_service(call_2)))
@asyncio.coroutine @asyncio.coroutine
@ -275,7 +275,7 @@ def test_extract_from_service_returns_all_if_no_entity_id(hass):
assert ['test_domain.test_1', 'test_domain.test_2'] == \ assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
component.async_extract_from_service(call)) (yield from component.async_extract_from_service(call)))
@asyncio.coroutine @asyncio.coroutine
@ -293,7 +293,7 @@ def test_extract_from_service_filter_out_non_existing_entities(hass):
assert ['test_domain.test_2'] == \ assert ['test_domain.test_2'] == \
[ent.entity_id for ent [ent.entity_id for ent
in component.async_extract_from_service(call)] in (yield from component.async_extract_from_service(call))]
@asyncio.coroutine @asyncio.coroutine
@ -308,7 +308,8 @@ def test_extract_from_service_no_group_expand(hass):
'entity_id': ['group.test_group'] 'entity_id': ['group.test_group']
}) })
extracted = component.async_extract_from_service(call, expand_group=False) extracted = yield from component.async_extract_from_service(
call, expand_group=False)
assert extracted == [test_group] assert extracted == [test_group]
@ -466,7 +467,7 @@ async def test_extract_all_omit_entity_id(hass, caplog):
assert ['test_domain.test_1', 'test_domain.test_2'] == \ assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
component.async_extract_from_service(call)) await component.async_extract_from_service(call))
assert ('Not passing an entity ID to a service to target all entities is ' assert ('Not passing an entity ID to a service to target all entities is '
'deprecated') in caplog.text 'deprecated') in caplog.text
@ -483,6 +484,6 @@ async def test_extract_all_use_match_all(hass, caplog):
assert ['test_domain.test_1', 'test_domain.test_2'] == \ assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
component.async_extract_from_service(call)) await component.async_extract_from_service(call))
assert ('Not passing an entity ID to a service to target all entities is ' assert ('Not passing an entity ID to a service to target all entities is '
'deprecated') not in caplog.text 'deprecated') not in caplog.text

View File

@ -15,8 +15,11 @@ from homeassistant.helpers import service, template
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.auth.permissions import PolicyPermissions from homeassistant.auth.permissions import PolicyPermissions
from homeassistant.helpers import (
from tests.common import get_test_home_assistant, mock_service, mock_coro device_registry as dev_reg, entity_registry as ent_reg)
from tests.common import (
get_test_home_assistant, mock_service, mock_coro, mock_registry,
mock_device_registry)
@pytest.fixture @pytest.fixture
@ -163,29 +166,83 @@ class TestServiceHelpers(unittest.TestCase):
}) })
assert 3 == mock_log.call_count assert 3 == mock_log.call_count
def test_extract_entity_ids(self):
"""Test extract_entity_ids method."""
self.hass.states.set('light.Bowl', STATE_ON)
self.hass.states.set('light.Ceiling', STATE_OFF)
self.hass.states.set('light.Kitchen', STATE_OFF)
loader.get_component(self.hass, 'group').Group.create_group( async def test_extract_entity_ids(hass):
self.hass, 'test', ['light.Ceiling', 'light.Kitchen']) """Test extract_entity_ids method."""
hass.states.async_set('light.Bowl', STATE_ON)
hass.states.async_set('light.Ceiling', STATE_OFF)
hass.states.async_set('light.Kitchen', STATE_OFF)
call = ha.ServiceCall('light', 'turn_on', await loader.get_component(hass, 'group').Group.async_create_group(
{ATTR_ENTITY_ID: 'light.Bowl'}) hass, 'test', ['light.Ceiling', 'light.Kitchen'])
assert ['light.bowl'] == \ call = ha.ServiceCall('light', 'turn_on',
service.extract_entity_ids(self.hass, call) {ATTR_ENTITY_ID: 'light.Bowl'})
call = ha.ServiceCall('light', 'turn_on', assert {'light.bowl'} == \
{ATTR_ENTITY_ID: 'group.test'}) await service.async_extract_entity_ids(hass, call)
assert ['light.ceiling', 'light.kitchen'] == \ call = ha.ServiceCall('light', 'turn_on',
service.extract_entity_ids(self.hass, call) {ATTR_ENTITY_ID: 'group.test'})
assert ['group.test'] == service.extract_entity_ids( assert {'light.ceiling', 'light.kitchen'} == \
self.hass, call, expand_group=False) await service.async_extract_entity_ids(hass, call)
assert {'group.test'} == await service.async_extract_entity_ids(
hass, call, expand_group=False)
async def test_extract_entity_ids_from_area(hass):
"""Test extract_entity_ids method with areas."""
hass.states.async_set('light.Bowl', STATE_ON)
hass.states.async_set('light.Ceiling', STATE_OFF)
hass.states.async_set('light.Kitchen', STATE_OFF)
device_in_area = dev_reg.DeviceEntry(area_id='test-area')
device_no_area = dev_reg.DeviceEntry()
device_diff_area = dev_reg.DeviceEntry(area_id='diff-area')
mock_device_registry(hass, {
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
})
entity_in_area = ent_reg.RegistryEntry(
entity_id='light.in_area',
unique_id='in-area-id',
platform='test',
device_id=device_in_area.id,
)
entity_no_area = ent_reg.RegistryEntry(
entity_id='light.no_area',
unique_id='no-area-id',
platform='test',
device_id=device_no_area.id,
)
entity_diff_area = ent_reg.RegistryEntry(
entity_id='light.diff_area',
unique_id='diff-area-id',
platform='test',
device_id=device_diff_area.id,
)
mock_registry(hass, {
entity_in_area.entity_id: entity_in_area,
entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
})
call = ha.ServiceCall('light', 'turn_on',
{'area_id': 'test-area'})
assert {'light.in_area'} == \
await service.async_extract_entity_ids(hass, call)
call = ha.ServiceCall('light', 'turn_on',
{'area_id': ['test-area', 'diff-area']})
assert {'light.in_area', 'light.diff_area'} == \
await service.async_extract_entity_ids(hass, call)
@asyncio.coroutine @asyncio.coroutine