diff --git a/homeassistant/components/__init__.py b/homeassistant/components/__init__.py index 8715f0baa96..f3045df6a12 100644 --- a/homeassistant/components/__init__.py +++ b/homeassistant/components/__init__.py @@ -17,7 +17,7 @@ import voluptuous as vol import homeassistant.core as ha import homeassistant.config as conf_util from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers.service import extract_entity_ids +from homeassistant.helpers.service import async_extract_entity_ids from homeassistant.helpers import intent from homeassistant.const import ( ATTR_ENTITY_ID, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, @@ -70,7 +70,7 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> Awaitable[bool]: """Set up general services related to Home Assistant.""" async def async_handle_turn_service(service): """Handle calls to homeassistant.turn_on/off.""" - entity_ids = extract_entity_ids(hass, service) + entity_ids = await async_extract_entity_ids(hass, service) # Generic turn on/off method requires entity id if not entity_ids: diff --git a/homeassistant/components/alert/__init__.py b/homeassistant/components/alert/__init__.py index f92fd6b187b..4c990d62d4b 100644 --- a/homeassistant/components/alert/__init__.py +++ b/homeassistant/components/alert/__init__.py @@ -89,7 +89,7 @@ async def async_setup(hass, config): async def async_handle_alert_service(service_call): """Handle calls to alert services.""" - alert_ids = service.extract_entity_ids(hass, service_call) + alert_ids = await service.async_extract_entity_ids(hass, service_call) for alert_id in alert_ids: for alert in entities: diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index ad231a2a348..5a7b19ce4e3 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -120,7 +120,7 @@ async def async_setup(hass, config): async def trigger_service_handler(service_call): """Handle automation triggers.""" tasks = [] - for entity in component.async_extract_from_service(service_call): + for entity in await component.async_extract_from_service(service_call): tasks.append(entity.async_trigger( service_call.data.get(ATTR_VARIABLES), skip_condition=True, @@ -133,7 +133,7 @@ async def async_setup(hass, config): """Handle automation turn on/off service calls.""" tasks = [] method = 'async_{}'.format(service_call.service) - for entity in component.async_extract_from_service(service_call): + for entity in await component.async_extract_from_service(service_call): tasks.append(getattr(entity, method)()) if tasks: @@ -142,7 +142,7 @@ async def async_setup(hass, config): async def toggle_service_handler(service_call): """Handle automation toggle service calls.""" tasks = [] - for entity in component.async_extract_from_service(service_call): + for entity in await component.async_extract_from_service(service_call): if entity.is_on: tasks.append(entity.async_turn_off()) else: diff --git a/homeassistant/components/group/__init__.py b/homeassistant/components/group/__init__.py index e0315209ba1..80ac01a78ac 100644 --- a/homeassistant/components/group/__init__.py +++ b/homeassistant/components/group/__init__.py @@ -300,8 +300,8 @@ async def async_setup(hass, config): visible = service.data.get(ATTR_VISIBLE) tasks = [] - for group in component.async_extract_from_service(service, - expand_group=False): + for group in await component.async_extract_from_service( + service, expand_group=False): group.visible = visible tasks.append(group.async_update_ha_state()) diff --git a/homeassistant/components/image_processing/__init__.py b/homeassistant/components/image_processing/__init__.py index f854384bb03..aa3b2db7369 100644 --- a/homeassistant/components/image_processing/__init__.py +++ b/homeassistant/components/image_processing/__init__.py @@ -75,7 +75,7 @@ async def async_setup(hass, config): async def async_scan_service(service): """Service handler for scan.""" - image_entities = component.async_extract_from_service(service) + image_entities = await component.async_extract_from_service(service) update_tasks = [] for entity in image_entities: diff --git a/homeassistant/components/light/__init__.py b/homeassistant/components/light/__init__.py index 93d7a67c6f0..ef82167b222 100644 --- a/homeassistant/components/light/__init__.py +++ b/homeassistant/components/light/__init__.py @@ -256,7 +256,7 @@ async def async_setup(hass, config): params = service.data.copy() # Convert the entity ids to valid light ids - target_lights = component.async_extract_from_service(service) + target_lights = await component.async_extract_from_service(service) params.pop(ATTR_ENTITY_ID, None) if service.context.user_id: diff --git a/homeassistant/components/scene/__init__.py b/homeassistant/components/scene/__init__.py index 8a7934bd694..35eedabd58a 100644 --- a/homeassistant/components/scene/__init__.py +++ b/homeassistant/components/scene/__init__.py @@ -68,7 +68,7 @@ async def async_setup(hass, config): async def async_handle_scene_service(service): """Handle calls to the switch services.""" - target_scenes = component.async_extract_from_service(service) + target_scenes = await component.async_extract_from_service(service) tasks = [scene.async_activate() for scene in target_scenes] if tasks: diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index fceedb57428..873a18120ac 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -74,20 +74,21 @@ async def async_setup(hass, config): # We could turn on script directly here, but we only want to offer # one way to do it. Otherwise no easy way to detect invocations. var = service.data.get(ATTR_VARIABLES) - for script in component.async_extract_from_service(service): + for script in await component.async_extract_from_service(service): await hass.services.async_call(DOMAIN, script.object_id, var, context=service.context) async def turn_off_service(service): """Cancel a script.""" # Stopping a script is ok to be done in parallel - await asyncio.wait( - [script.async_turn_off() for script - in component.async_extract_from_service(service)], loop=hass.loop) + await asyncio.wait([ + script.async_turn_off() for script + in await component.async_extract_from_service(service) + ], loop=hass.loop) async def toggle_service(service): """Toggle a script.""" - for script in component.async_extract_from_service(service): + for script in await component.async_extract_from_service(service): await script.async_toggle(context=service.context) hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service, diff --git a/homeassistant/const.py b/homeassistant/const.py index 49194c06c17..f24fbcc97ac 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -245,6 +245,9 @@ ATTR_NAME = 'name' # Contains one string or a list of strings, each being an entity id ATTR_ENTITY_ID = 'entity_id' +# Contains one string or a list of strings, each being an area id +ATTR_AREA_ID = 'area_id' + # String with a friendly name for the entity ATTR_FRIENDLY_NAME = 'friendly_name' diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 3fa820f8350..644d14cf869 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -37,6 +37,11 @@ class AreaRegistry: self.areas = {} # type: MutableMapping[str, AreaEntry] self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + @callback + def async_get_area(self, area_id: str) -> Optional[AreaEntry]: + """Get all areas.""" + return self.areas.get(area_id) + @callback def async_list_areas(self) -> Iterable[AreaEntry]: """Get all areas.""" diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 21c3b0d0209..9c8ee27d0d2 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -1,6 +1,7 @@ """Provide a way to connect entities belonging to one device.""" import logging import uuid +from typing import List from collections import OrderedDict @@ -280,3 +281,11 @@ async def async_get_registry(hass) -> DeviceRegistry: task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg()) return await task + + +@callback +def async_entries_for_area(registry: DeviceRegistry, area_id: str) \ + -> List[DeviceEntry]: + """Return entries that match an area.""" + return [device for device in registry.devices.values() + if device.area_id == area_id] diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 44213e6d7c8..744cf36ea66 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -12,7 +12,7 @@ from homeassistant.const import ( from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_per_platform, discovery -from homeassistant.helpers.service import extract_entity_ids +from homeassistant.helpers.service import async_extract_entity_ids from homeassistant.loader import bind_hass from homeassistant.util import slugify from .entity_platform import EntityPlatform @@ -153,8 +153,7 @@ class EntityComponent: await platform.async_reset() return True - @callback - def async_extract_from_service(self, service, expand_group=True): + async def async_extract_from_service(self, service, expand_group=True): """Extract all known and available entities from a service call. Will return all entities if no entities specified in call. @@ -174,7 +173,8 @@ class EntityComponent: return [entity for entity in self.entities if entity.available] - entity_ids = set(extract_entity_ids(self.hass, service, expand_group)) + entity_ids = await async_extract_entity_ids( + self.hass, service, expand_group) return [entity for entity in self.entities if entity.available and entity.entity_id in entity_ids] diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 6ee32f642bc..c0a0dfaa7d9 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -10,7 +10,7 @@ timer. from collections import OrderedDict from itertools import chain import logging -from typing import Optional +from typing import Optional, List import weakref import attr @@ -292,6 +292,14 @@ async def async_get_registry(hass) -> EntityRegistry: return await task +@callback +def async_entries_for_device(registry: EntityRegistry, device_id: str) \ + -> List[RegistryEntry]: + """Return entries that match a device.""" + return [entry for entry in registry.entities.values() + if entry.device_id == device_id] + + async def _async_migrate(entities): """Migrate the YAML config file to storage helper format.""" return { diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 22138d7c2aa..b685e0d67c7 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -6,7 +6,8 @@ from os import path import voluptuous as vol from homeassistant.auth.permissions.const import POLICY_CONTROL -from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL +from homeassistant.const import ( + ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID) import homeassistant.core as ha from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser from homeassistant.helpers import template @@ -89,30 +90,64 @@ async def async_call_from_config(hass, config, blocking=False, variables=None, def extract_entity_ids(hass, service_call, expand_group=True): """Extract a list of entity ids from a service call. + Will convert group entity ids to the entity ids it represents. + """ + return run_coroutine_threadsafe( + async_extract_entity_ids(hass, service_call, expand_group), hass.loop + ).result() + + +@bind_hass +async def async_extract_entity_ids(hass, service_call, expand_group=True): + """Extract a list of entity ids from a service call. + Will convert group entity ids to the entity ids it represents. Async friendly. """ - if not (service_call.data and ATTR_ENTITY_ID in service_call.data): + entity_ids = service_call.data.get(ATTR_ENTITY_ID) + area_ids = service_call.data.get(ATTR_AREA_ID) + + if not entity_ids and not area_ids: return [] - group = hass.components.group + extracted = set() - # Entity ID attr can be a list or a string - service_ent_id = service_call.data[ATTR_ENTITY_ID] + if entity_ids: + # Entity ID attr can be a list or a string + if isinstance(entity_ids, str): + entity_ids = [entity_ids] - if expand_group: + if expand_group: + entity_ids = \ + hass.components.group.expand_entity_ids(entity_ids) - if isinstance(service_ent_id, str): - return group.expand_entity_ids([service_ent_id]) + extracted.update(entity_ids) - return [ent_id for ent_id in - group.expand_entity_ids(service_ent_id)] + if area_ids: + if isinstance(area_ids, str): + area_ids = [area_ids] - if isinstance(service_ent_id, str): - return [service_ent_id] + dev_reg, ent_reg = await asyncio.gather( + hass.helpers.device_registry.async_get_registry(), + hass.helpers.entity_registry.async_get_registry(), + ) + devices = [ + device + for area_id in area_ids + for device in + hass.helpers.device_registry.async_entries_for_area( + dev_reg, area_id) + ] + extracted.update( + entry.entity_id + for device in devices + for entry in + hass.helpers.entity_registry.async_entries_for_device( + ent_reg, device.id) + ) - return service_ent_id + return extracted @bind_hass @@ -213,8 +248,7 @@ async def entity_service_call(hass, platforms, func, call, service_name=''): if not target_all_entities: # A set of entities we're trying to target. - entity_ids = set( - extract_entity_ids(hass, call, True)) + entity_ids = await async_extract_entity_ids(hass, call, True) # If the service function is a string, we'll pass it the service call data if isinstance(func, str): diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 27e33a4fe7d..163261a4b81 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -206,7 +206,7 @@ def test_extract_from_service_available_device(hass): assert ['test_domain.test_1', 'test_domain.test_3'] == \ sorted(ent.entity_id for ent in - component.async_extract_from_service(call_1)) + (yield from component.async_extract_from_service(call_1))) call_2 = ha.ServiceCall('test', 'service', data={ 'entity_id': ['test_domain.test_3', 'test_domain.test_4'], @@ -214,7 +214,7 @@ def test_extract_from_service_available_device(hass): assert ['test_domain.test_3'] == \ sorted(ent.entity_id for ent in - component.async_extract_from_service(call_2)) + (yield from component.async_extract_from_service(call_2))) @asyncio.coroutine @@ -275,7 +275,7 @@ def test_extract_from_service_returns_all_if_no_entity_id(hass): assert ['test_domain.test_1', 'test_domain.test_2'] == \ sorted(ent.entity_id for ent in - component.async_extract_from_service(call)) + (yield from component.async_extract_from_service(call))) @asyncio.coroutine @@ -293,7 +293,7 @@ def test_extract_from_service_filter_out_non_existing_entities(hass): assert ['test_domain.test_2'] == \ [ent.entity_id for ent - in component.async_extract_from_service(call)] + in (yield from component.async_extract_from_service(call))] @asyncio.coroutine @@ -308,7 +308,8 @@ def test_extract_from_service_no_group_expand(hass): 'entity_id': ['group.test_group'] }) - extracted = component.async_extract_from_service(call, expand_group=False) + extracted = yield from component.async_extract_from_service( + call, expand_group=False) assert extracted == [test_group] @@ -466,7 +467,7 @@ async def test_extract_all_omit_entity_id(hass, caplog): assert ['test_domain.test_1', 'test_domain.test_2'] == \ sorted(ent.entity_id for ent in - component.async_extract_from_service(call)) + await component.async_extract_from_service(call)) assert ('Not passing an entity ID to a service to target all entities is ' 'deprecated') in caplog.text @@ -483,6 +484,6 @@ async def test_extract_all_use_match_all(hass, caplog): assert ['test_domain.test_1', 'test_domain.test_2'] == \ sorted(ent.entity_id for ent in - component.async_extract_from_service(call)) + await component.async_extract_from_service(call)) assert ('Not passing an entity ID to a service to target all entities is ' 'deprecated') not in caplog.text diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 35e89fc5218..854ee9c74f6 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -15,8 +15,11 @@ from homeassistant.helpers import service, template from homeassistant.setup import async_setup_component import homeassistant.helpers.config_validation as cv from homeassistant.auth.permissions import PolicyPermissions - -from tests.common import get_test_home_assistant, mock_service, mock_coro +from homeassistant.helpers import ( + device_registry as dev_reg, entity_registry as ent_reg) +from tests.common import ( + get_test_home_assistant, mock_service, mock_coro, mock_registry, + mock_device_registry) @pytest.fixture @@ -163,29 +166,83 @@ class TestServiceHelpers(unittest.TestCase): }) assert 3 == mock_log.call_count - def test_extract_entity_ids(self): - """Test extract_entity_ids method.""" - self.hass.states.set('light.Bowl', STATE_ON) - self.hass.states.set('light.Ceiling', STATE_OFF) - self.hass.states.set('light.Kitchen', STATE_OFF) - loader.get_component(self.hass, 'group').Group.create_group( - self.hass, 'test', ['light.Ceiling', 'light.Kitchen']) +async def test_extract_entity_ids(hass): + """Test extract_entity_ids method.""" + hass.states.async_set('light.Bowl', STATE_ON) + hass.states.async_set('light.Ceiling', STATE_OFF) + hass.states.async_set('light.Kitchen', STATE_OFF) - call = ha.ServiceCall('light', 'turn_on', - {ATTR_ENTITY_ID: 'light.Bowl'}) + await loader.get_component(hass, 'group').Group.async_create_group( + hass, 'test', ['light.Ceiling', 'light.Kitchen']) - assert ['light.bowl'] == \ - service.extract_entity_ids(self.hass, call) + call = ha.ServiceCall('light', 'turn_on', + {ATTR_ENTITY_ID: 'light.Bowl'}) - call = ha.ServiceCall('light', 'turn_on', - {ATTR_ENTITY_ID: 'group.test'}) + assert {'light.bowl'} == \ + await service.async_extract_entity_ids(hass, call) - assert ['light.ceiling', 'light.kitchen'] == \ - service.extract_entity_ids(self.hass, call) + call = ha.ServiceCall('light', 'turn_on', + {ATTR_ENTITY_ID: 'group.test'}) - assert ['group.test'] == service.extract_entity_ids( - self.hass, call, expand_group=False) + assert {'light.ceiling', 'light.kitchen'} == \ + await service.async_extract_entity_ids(hass, call) + + assert {'group.test'} == await service.async_extract_entity_ids( + hass, call, expand_group=False) + + +async def test_extract_entity_ids_from_area(hass): + """Test extract_entity_ids method with areas.""" + hass.states.async_set('light.Bowl', STATE_ON) + hass.states.async_set('light.Ceiling', STATE_OFF) + hass.states.async_set('light.Kitchen', STATE_OFF) + + device_in_area = dev_reg.DeviceEntry(area_id='test-area') + device_no_area = dev_reg.DeviceEntry() + device_diff_area = dev_reg.DeviceEntry(area_id='diff-area') + + mock_device_registry(hass, { + device_in_area.id: device_in_area, + device_no_area.id: device_no_area, + device_diff_area.id: device_diff_area, + }) + + entity_in_area = ent_reg.RegistryEntry( + entity_id='light.in_area', + unique_id='in-area-id', + platform='test', + device_id=device_in_area.id, + ) + entity_no_area = ent_reg.RegistryEntry( + entity_id='light.no_area', + unique_id='no-area-id', + platform='test', + device_id=device_no_area.id, + ) + entity_diff_area = ent_reg.RegistryEntry( + entity_id='light.diff_area', + unique_id='diff-area-id', + platform='test', + device_id=device_diff_area.id, + ) + mock_registry(hass, { + entity_in_area.entity_id: entity_in_area, + entity_no_area.entity_id: entity_no_area, + entity_diff_area.entity_id: entity_diff_area, + }) + + call = ha.ServiceCall('light', 'turn_on', + {'area_id': 'test-area'}) + + assert {'light.in_area'} == \ + await service.async_extract_entity_ids(hass, call) + + call = ha.ServiceCall('light', 'turn_on', + {'area_id': ['test-area', 'diff-area']}) + + assert {'light.in_area', 'light.diff_area'} == \ + await service.async_extract_entity_ids(hass, call) @asyncio.coroutine