From 5601fbdc7a94230c54b438272b527531a53dc595 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 8 Feb 2018 03:16:51 -0800 Subject: [PATCH] Entity layer cleanup (#12237) * Simplify entity update * Split entity platform from entity component * Decouple entity platform from entity component * Always include unit of measurement again * Lint * Fix test --- homeassistant/helpers/entity.py | 63 +-- homeassistant/helpers/entity_component.py | 356 ++------------ homeassistant/helpers/entity_platform.py | 317 +++++++++++++ tests/common.py | 44 +- tests/helpers/test_entity_component.py | 535 ++-------------------- tests/helpers/test_entity_platform.py | 435 ++++++++++++++++++ tests/test_config.py | 12 - 7 files changed, 905 insertions(+), 857 deletions(-) create mode 100644 homeassistant/helpers/entity_platform.py create mode 100644 tests/helpers/test_entity_platform.py diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index d1e5c0d82a0..c7653d5d5b9 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -152,7 +152,7 @@ class Entity(object): @property def assumed_state(self) -> bool: """Return True if unable to access real state of the entity.""" - return None + return False @property def force_update(self) -> bool: @@ -221,21 +221,41 @@ class Entity(object): if device_attr is not None: attr.update(device_attr) - self._attr_setter('unit_of_measurement', str, ATTR_UNIT_OF_MEASUREMENT, - attr) + unit_of_measurement = self.unit_of_measurement + if unit_of_measurement is not None: + attr[ATTR_UNIT_OF_MEASUREMENT] = unit_of_measurement - self._attr_setter('name', str, ATTR_FRIENDLY_NAME, attr) - self._attr_setter('icon', str, ATTR_ICON, attr) - self._attr_setter('entity_picture', str, ATTR_ENTITY_PICTURE, attr) - self._attr_setter('hidden', bool, ATTR_HIDDEN, attr) - self._attr_setter('assumed_state', bool, ATTR_ASSUMED_STATE, attr) - self._attr_setter('supported_features', int, ATTR_SUPPORTED_FEATURES, - attr) - self._attr_setter('device_class', str, ATTR_DEVICE_CLASS, attr) + name = self.name + if name is not None: + attr[ATTR_FRIENDLY_NAME] = name + + icon = self.icon + if icon is not None: + attr[ATTR_ICON] = icon + + entity_picture = self.entity_picture + if entity_picture is not None: + attr[ATTR_ENTITY_PICTURE] = entity_picture + + hidden = self.hidden + if hidden: + attr[ATTR_HIDDEN] = hidden + + assumed_state = self.assumed_state + if assumed_state: + attr[ATTR_ASSUMED_STATE] = assumed_state + + supported_features = self.supported_features + if supported_features is not None: + attr[ATTR_SUPPORTED_FEATURES] = supported_features + + device_class = self.device_class + if device_class is not None: + attr[ATTR_DEVICE_CLASS] = str(device_class) end = timer() - if not self._slow_reported and end - start > 0.4: + if end - start > 0.4 and not self._slow_reported: self._slow_reported = True _LOGGER.warning("Updating state for %s (%s) took %.3f seconds. " "Please report platform to the developers at " @@ -246,10 +266,6 @@ class Entity(object): if DATA_CUSTOMIZE in self.hass.data: attr.update(self.hass.data[DATA_CUSTOMIZE].get(self.entity_id)) - # Remove hidden property if false so it won't show up. - if not attr.get(ATTR_HIDDEN, True): - attr.pop(ATTR_HIDDEN) - # Convert temperature if we detect one try: unit_of_measure = attr.get(ATTR_UNIT_OF_MEASUREMENT) @@ -321,21 +337,6 @@ class Entity(object): else: self.hass.states.async_remove(self.entity_id) - def _attr_setter(self, name, typ, attr, attrs): - """Populate attributes based on properties.""" - if attr in attrs: - return - - value = getattr(self, name) - - if value is None: - return - - try: - attrs[attr] = typ(value) - except (TypeError, ValueError): - pass - def __eq__(self, other): """Return the comparison.""" if not isinstance(other, self.__class__): diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 2c928f184e8..9dfbe580c16 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -6,25 +6,15 @@ from itertools import chain from homeassistant import config as conf_util from homeassistant.setup import async_prepare_setup_platform from homeassistant.const import ( - ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE, - DEVICE_DEFAULT_NAME) -from homeassistant.core import callback, valid_entity_id, split_entity_id -from homeassistant.exceptions import HomeAssistantError, PlatformNotReady + ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE) +from homeassistant.core import callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_per_platform, discovery -from homeassistant.helpers.event import ( - async_track_time_interval, async_track_point_in_time) from homeassistant.helpers.service import extract_entity_ids from homeassistant.util import slugify -from homeassistant.util.async import ( - run_callback_threadsafe, run_coroutine_threadsafe) -import homeassistant.util.dt as dt_util -from .entity_registry import EntityRegistry +from .entity_platform import EntityPlatform DEFAULT_SCAN_INTERVAL = timedelta(seconds=15) -SLOW_SETUP_WARNING = 10 -SLOW_SETUP_MAX_WAIT = 60 -PLATFORM_NOT_READY_RETRIES = 10 -DATA_REGISTRY = 'entity_registry' class EntityComponent(object): @@ -43,16 +33,23 @@ class EntityComponent(object): """Initialize an entity component.""" self.logger = logger self.hass = hass - self.domain = domain - self.entity_id_format = domain + '.{}' self.scan_interval = scan_interval self.group_name = group_name self.config = None self._platforms = { - 'core': EntityPlatform(self, domain, self.scan_interval, 0, None), + 'core': EntityPlatform( + hass=hass, + logger=logger, + domain=domain, + platform_name='core', + scan_interval=self.scan_interval, + parallel_updates=0, + entity_namespace=None, + async_entities_added_callback=self._async_update_group, + ) } self.async_add_entities = self._platforms['core'].async_add_entities self.add_entities = self._platforms['core'].add_entities @@ -107,17 +104,6 @@ class EntityComponent(object): discovery.async_listen_platform( self.hass, self.domain, component_platform_discovered) - def extract_from_service(self, service, expand_group=True): - """Extract all known entities from a service call. - - Will return all entities if no entities specified in call. - Will return an empty list if entities specified but unknown. - """ - return run_callback_threadsafe( - self.hass.loop, self.async_extract_from_service, service, - expand_group - ).result() - @callback def async_extract_from_service(self, service, expand_group=True): """Extract all known and available entities from a service call. @@ -136,11 +122,8 @@ class EntityComponent(object): @asyncio.coroutine def _async_setup_platform(self, platform_type, platform_config, - discovery_info=None, tries=0): - """Set up a platform for this component. - - This method must be run in the event loop. - """ + discovery_info=None): + """Set up a platform for this component.""" platform = yield from async_prepare_setup_platform( self.hass, self.config, self.domain, platform_type) @@ -161,59 +144,23 @@ class EntityComponent(object): if key not in self._platforms: entity_platform = self._platforms[key] = EntityPlatform( - self, platform_type, scan_interval, parallel_updates, - entity_namespace) + hass=self.hass, + logger=self.logger, + domain=self.domain, + platform_name=platform_type, + scan_interval=scan_interval, + parallel_updates=parallel_updates, + entity_namespace=entity_namespace, + async_entities_added_callback=self._async_update_group, + ) else: entity_platform = self._platforms[key] - self.logger.info("Setting up %s.%s", self.domain, platform_type) - warn_task = self.hass.loop.call_later( - SLOW_SETUP_WARNING, self.logger.warning, - "Setup of platform %s is taking over %s seconds.", platform_type, - SLOW_SETUP_WARNING) - - try: - if getattr(platform, 'async_setup_platform', None): - task = platform.async_setup_platform( - self.hass, platform_config, - entity_platform.async_schedule_add_entities, discovery_info - ) - else: - # This should not be replaced with hass.async_add_job because - # we don't want to track this task in case it blocks startup. - task = self.hass.loop.run_in_executor( - None, platform.setup_platform, self.hass, platform_config, - entity_platform.schedule_add_entities, discovery_info - ) - yield from asyncio.wait_for( - asyncio.shield(task, loop=self.hass.loop), - SLOW_SETUP_MAX_WAIT, loop=self.hass.loop) - yield from entity_platform.async_block_entities_done() - self.hass.config.components.add( - '{}.{}'.format(self.domain, platform_type)) - except PlatformNotReady: - tries += 1 - wait_time = min(tries, 6) * 30 - self.logger.warning( - 'Platform %s not ready yet. Retrying in %d seconds.', - platform_type, wait_time) - async_track_point_in_time( - self.hass, self._async_setup_platform( - platform_type, platform_config, discovery_info, tries), - dt_util.utcnow() + timedelta(seconds=wait_time)) - except asyncio.TimeoutError: - self.logger.error( - "Setup of platform %s is taking longer than %s seconds." - " Startup will proceed without waiting any longer.", - platform_type, SLOW_SETUP_MAX_WAIT) - except Exception: # pylint: disable=broad-except - self.logger.exception( - "Error while setting up platform %s", platform_type) - finally: - warn_task.cancel() + yield from entity_platform.async_setup( + platform, platform_config, discovery_info) @callback - def async_update_group(self): + def _async_update_group(self): """Set up and/or update component group. This method must be run in the event loop. @@ -230,12 +177,8 @@ class EntityComponent(object): visible=False, entity_ids=ids ) - def reset(self): - """Remove entities and reset the entity component to initial values.""" - run_coroutine_threadsafe(self.async_reset(), self.hass.loop).result() - @asyncio.coroutine - def async_reset(self): + def _async_reset(self): """Remove entities and reset the entity component to initial values. This method must be run in the event loop. @@ -261,11 +204,6 @@ class EntityComponent(object): if entity_id in platform.entities: yield from platform.async_remove_entity(entity_id) - def prepare_reload(self): - """Prepare reloading this entity component.""" - return run_coroutine_threadsafe( - self.async_prepare_reload(), loop=self.hass.loop).result() - @asyncio.coroutine def async_prepare_reload(self): """Prepare reloading this entity component. @@ -285,239 +223,5 @@ class EntityComponent(object): if conf is None: return None - yield from self.async_reset() + yield from self._async_reset() return conf - - -class EntityPlatform(object): - """Manage the entities for a single platform.""" - - def __init__(self, component, platform, scan_interval, parallel_updates, - entity_namespace): - """Initialize the entity platform.""" - self.component = component - self.platform = platform - self.scan_interval = scan_interval - self.parallel_updates = None - self.entity_namespace = entity_namespace - self.entities = {} - self._tasks = [] - self._async_unsub_polling = None - self._process_updates = asyncio.Lock(loop=component.hass.loop) - - if parallel_updates: - self.parallel_updates = asyncio.Semaphore( - parallel_updates, loop=component.hass.loop) - - @asyncio.coroutine - def async_block_entities_done(self): - """Wait until all entities add to hass.""" - if self._tasks: - pending = [task for task in self._tasks if not task.done()] - self._tasks.clear() - - if pending: - yield from asyncio.wait(pending, loop=self.component.hass.loop) - - def schedule_add_entities(self, new_entities, update_before_add=False): - """Add entities for a single platform.""" - run_callback_threadsafe( - self.component.hass.loop, - self.async_schedule_add_entities, list(new_entities), - update_before_add - ).result() - - @callback - def async_schedule_add_entities(self, new_entities, - update_before_add=False): - """Add entities for a single platform async.""" - self._tasks.append(self.component.hass.async_add_job( - self.async_add_entities( - new_entities, update_before_add=update_before_add) - )) - - def add_entities(self, new_entities, update_before_add=False): - """Add entities for a single platform.""" - # That avoid deadlocks - if update_before_add: - self.component.logger.warning( - "Call 'add_entities' with update_before_add=True " - "only inside tests or you can run into a deadlock!") - - run_coroutine_threadsafe( - self.async_add_entities(list(new_entities), update_before_add), - self.component.hass.loop).result() - - @asyncio.coroutine - def async_add_entities(self, new_entities, update_before_add=False): - """Add entities for a single platform async. - - This method must be run in the event loop. - """ - # handle empty list from component/platform - if not new_entities: - return - - hass = self.component.hass - component_entities = set(entity.entity_id for entity - in self.component.entities) - - registry = hass.data.get(DATA_REGISTRY) - - if registry is None: - registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass) - - yield from registry.async_ensure_loaded() - - tasks = [ - self._async_add_entity(entity, update_before_add, - component_entities, registry) - for entity in new_entities] - - yield from asyncio.wait(tasks, loop=self.component.hass.loop) - self.component.async_update_group() - - if self._async_unsub_polling is not None or \ - not any(entity.should_poll for entity - in self.entities.values()): - return - - self._async_unsub_polling = async_track_time_interval( - self.component.hass, self._update_entity_states, self.scan_interval - ) - - @asyncio.coroutine - def _async_add_entity(self, entity, update_before_add, component_entities, - registry): - """Helper method to add an entity to the platform.""" - if entity is None: - raise ValueError('Entity cannot be None') - - entity.hass = self.component.hass - entity.platform = self - entity.parallel_updates = self.parallel_updates - - # Update properties before we generate the entity_id - if update_before_add: - try: - yield from entity.async_device_update(warning=False) - except Exception: # pylint: disable=broad-except - self.component.logger.exception( - "%s: Error on device update!", self.platform) - return - - suggested_object_id = None - - # Get entity_id from unique ID registration - if entity.unique_id is not None: - if entity.entity_id is not None: - suggested_object_id = split_entity_id(entity.entity_id)[1] - else: - suggested_object_id = entity.name - - entry = registry.async_get_or_create( - self.component.domain, self.platform, entity.unique_id, - suggested_object_id=suggested_object_id) - entity.entity_id = entry.entity_id - - # We won't generate an entity ID if the platform has already set one - # We will however make sure that platform cannot pick a registered ID - elif (entity.entity_id is not None and - registry.async_is_registered(entity.entity_id)): - # If entity already registered, convert entity id to suggestion - suggested_object_id = split_entity_id(entity.entity_id)[1] - entity.entity_id = None - - # Generate entity ID - if entity.entity_id is None: - suggested_object_id = \ - suggested_object_id or entity.name or DEVICE_DEFAULT_NAME - - if self.entity_namespace is not None: - suggested_object_id = '{} {}'.format(self.entity_namespace, - suggested_object_id) - - entity.entity_id = registry.async_generate_entity_id( - self.component.domain, suggested_object_id) - - # Make sure it is valid in case an entity set the value themselves - if not valid_entity_id(entity.entity_id): - raise HomeAssistantError( - 'Invalid entity id: {}'.format(entity.entity_id)) - elif entity.entity_id in component_entities: - raise HomeAssistantError( - 'Entity id already exists: {}'.format(entity.entity_id)) - - self.entities[entity.entity_id] = entity - component_entities.add(entity.entity_id) - - if hasattr(entity, 'async_added_to_hass'): - yield from entity.async_added_to_hass() - - yield from entity.async_update_ha_state() - - @asyncio.coroutine - def async_reset(self): - """Remove all entities and reset data. - - This method must be run in the event loop. - """ - if not self.entities: - return - - tasks = [self._async_remove_entity(entity_id) - for entity_id in self.entities] - - yield from asyncio.wait(tasks, loop=self.component.hass.loop) - - if self._async_unsub_polling is not None: - self._async_unsub_polling() - self._async_unsub_polling = None - - @asyncio.coroutine - def async_remove_entity(self, entity_id): - """Remove entity id from platform.""" - yield from self._async_remove_entity(entity_id) - - # Clean up polling job if no longer needed - if (self._async_unsub_polling is not None and - not any(entity.should_poll for entity - in self.entities.values())): - self._async_unsub_polling() - self._async_unsub_polling = None - - @asyncio.coroutine - def _async_remove_entity(self, entity_id): - """Remove entity id from platform.""" - entity = self.entities.pop(entity_id) - - if hasattr(entity, 'async_will_remove_from_hass'): - yield from entity.async_will_remove_from_hass() - - self.component.hass.states.async_remove(entity_id) - - @asyncio.coroutine - def _update_entity_states(self, now): - """Update the states of all the polling entities. - - To protect from flooding the executor, we will update async entities - in parallel and other entities sequential. - - This method must be run in the event loop. - """ - if self._process_updates.locked(): - self.component.logger.warning( - "Updating %s %s took longer than the scheduled update " - "interval %s", self.platform, self.component.domain, - self.scan_interval) - return - - with (yield from self._process_updates): - tasks = [] - for entity in self.entities.values(): - if not entity.should_poll: - continue - tasks.append(entity.async_update_ha_state(True)) - - if tasks: - yield from asyncio.wait(tasks, loop=self.component.hass.loop) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py new file mode 100644 index 00000000000..3362f1e3b3f --- /dev/null +++ b/homeassistant/helpers/entity_platform.py @@ -0,0 +1,317 @@ +"""Class to manage the entities for a single platform.""" +import asyncio +from datetime import timedelta + +from homeassistant.const import DEVICE_DEFAULT_NAME +from homeassistant.core import callback, valid_entity_id, split_entity_id +from homeassistant.exceptions import HomeAssistantError, PlatformNotReady +from homeassistant.util.async import ( + run_callback_threadsafe, run_coroutine_threadsafe) +import homeassistant.util.dt as dt_util + +from .event import async_track_time_interval, async_track_point_in_time +from .entity_registry import EntityRegistry + +SLOW_SETUP_WARNING = 10 +SLOW_SETUP_MAX_WAIT = 60 +PLATFORM_NOT_READY_RETRIES = 10 +DATA_REGISTRY = 'entity_registry' + + +class EntityPlatform(object): + """Manage the entities for a single platform.""" + + def __init__(self, *, hass, logger, domain, platform_name, scan_interval, + parallel_updates, entity_namespace, + async_entities_added_callback): + """Initialize the entity platform. + + hass: HomeAssistant + logger: Logger + domain: str + platform_name: str + scan_interval: timedelta + parallel_updates: int + entity_namespace: str + async_entities_added_callback: @callback method + """ + self.hass = hass + self.logger = logger + self.domain = domain + self.platform_name = platform_name + self.scan_interval = scan_interval + self.parallel_updates = None + self.entity_namespace = entity_namespace + self.async_entities_added_callback = async_entities_added_callback + self.entities = {} + self._tasks = [] + self._async_unsub_polling = None + self._process_updates = asyncio.Lock(loop=hass.loop) + + if parallel_updates: + self.parallel_updates = asyncio.Semaphore( + parallel_updates, loop=hass.loop) + + @asyncio.coroutine + def async_setup(self, platform, platform_config, discovery_info=None, + tries=0): + """Setup the platform.""" + logger = self.logger + hass = self.hass + full_name = '{}.{}'.format(self.domain, self.platform_name) + + logger.info("Setting up %s", full_name) + warn_task = hass.loop.call_later( + SLOW_SETUP_WARNING, logger.warning, + "Setup of platform %s is taking over %s seconds.", + self.platform_name, SLOW_SETUP_WARNING) + + try: + if getattr(platform, 'async_setup_platform', None): + task = platform.async_setup_platform( + hass, platform_config, + self._async_schedule_add_entities, discovery_info + ) + else: + # This should not be replaced with hass.async_add_job because + # we don't want to track this task in case it blocks startup. + task = hass.loop.run_in_executor( + None, platform.setup_platform, hass, platform_config, + self._schedule_add_entities, discovery_info + ) + yield from asyncio.wait_for( + asyncio.shield(task, loop=hass.loop), + SLOW_SETUP_MAX_WAIT, loop=hass.loop) + + # Block till all entities are done + if self._tasks: + pending = [task for task in self._tasks if not task.done()] + self._tasks.clear() + + if pending: + yield from asyncio.wait( + pending, loop=self.hass.loop) + + hass.config.components.add(full_name) + except PlatformNotReady: + tries += 1 + wait_time = min(tries, 6) * 30 + logger.warning( + 'Platform %s not ready yet. Retrying in %d seconds.', + self.platform_name, wait_time) + async_track_point_in_time( + hass, self.async_setup( + platform, platform_config, discovery_info, tries), + dt_util.utcnow() + timedelta(seconds=wait_time)) + except asyncio.TimeoutError: + logger.error( + "Setup of platform %s is taking longer than %s seconds." + " Startup will proceed without waiting any longer.", + self.platform_name, SLOW_SETUP_MAX_WAIT) + except Exception: # pylint: disable=broad-except + logger.exception( + "Error while setting up platform %s", self.platform_name) + finally: + warn_task.cancel() + + def _schedule_add_entities(self, new_entities, update_before_add=False): + """Synchronously schedule adding entities for a single platform.""" + run_callback_threadsafe( + self.hass.loop, + self._async_schedule_add_entities, list(new_entities), + update_before_add + ).result() + + @callback + def _async_schedule_add_entities(self, new_entities, + update_before_add=False): + """Schedule adding entities for a single platform async.""" + self._tasks.append(self.hass.async_add_job( + self.async_add_entities( + new_entities, update_before_add=update_before_add) + )) + + def add_entities(self, new_entities, update_before_add=False): + """Add entities for a single platform.""" + # That avoid deadlocks + if update_before_add: + self.logger.warning( + "Call 'add_entities' with update_before_add=True " + "only inside tests or you can run into a deadlock!") + + run_coroutine_threadsafe( + self.async_add_entities(list(new_entities), update_before_add), + self.hass.loop).result() + + @asyncio.coroutine + def async_add_entities(self, new_entities, update_before_add=False): + """Add entities for a single platform async. + + This method must be run in the event loop. + """ + # handle empty list from component/platform + if not new_entities: + return + + hass = self.hass + component_entities = set(hass.states.async_entity_ids(self.domain)) + + registry = hass.data.get(DATA_REGISTRY) + + if registry is None: + registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass) + + yield from registry.async_ensure_loaded() + + tasks = [ + self._async_add_entity(entity, update_before_add, + component_entities, registry) + for entity in new_entities] + + yield from asyncio.wait(tasks, loop=self.hass.loop) + self.async_entities_added_callback() + + if self._async_unsub_polling is not None or \ + not any(entity.should_poll for entity + in self.entities.values()): + return + + self._async_unsub_polling = async_track_time_interval( + self.hass, self._update_entity_states, self.scan_interval + ) + + @asyncio.coroutine + def _async_add_entity(self, entity, update_before_add, component_entities, + registry): + """Helper method to add an entity to the platform.""" + if entity is None: + raise ValueError('Entity cannot be None') + + entity.hass = self.hass + entity.platform = self + entity.parallel_updates = self.parallel_updates + + # Update properties before we generate the entity_id + if update_before_add: + try: + yield from entity.async_device_update(warning=False) + except Exception: # pylint: disable=broad-except + self.logger.exception( + "%s: Error on device update!", self.platform_name) + return + + suggested_object_id = None + + # Get entity_id from unique ID registration + if entity.unique_id is not None: + if entity.entity_id is not None: + suggested_object_id = split_entity_id(entity.entity_id)[1] + else: + suggested_object_id = entity.name + + entry = registry.async_get_or_create( + self.domain, self.platform_name, entity.unique_id, + suggested_object_id=suggested_object_id) + entity.entity_id = entry.entity_id + + # We won't generate an entity ID if the platform has already set one + # We will however make sure that platform cannot pick a registered ID + elif (entity.entity_id is not None and + registry.async_is_registered(entity.entity_id)): + # If entity already registered, convert entity id to suggestion + suggested_object_id = split_entity_id(entity.entity_id)[1] + entity.entity_id = None + + # Generate entity ID + if entity.entity_id is None: + suggested_object_id = \ + suggested_object_id or entity.name or DEVICE_DEFAULT_NAME + + if self.entity_namespace is not None: + suggested_object_id = '{} {}'.format(self.entity_namespace, + suggested_object_id) + + entity.entity_id = registry.async_generate_entity_id( + self.domain, suggested_object_id) + + # Make sure it is valid in case an entity set the value themselves + if not valid_entity_id(entity.entity_id): + raise HomeAssistantError( + 'Invalid entity id: {}'.format(entity.entity_id)) + elif entity.entity_id in component_entities: + raise HomeAssistantError( + 'Entity id already exists: {}'.format(entity.entity_id)) + + self.entities[entity.entity_id] = entity + component_entities.add(entity.entity_id) + + if hasattr(entity, 'async_added_to_hass'): + yield from entity.async_added_to_hass() + + yield from entity.async_update_ha_state() + + @asyncio.coroutine + def async_reset(self): + """Remove all entities and reset data. + + This method must be run in the event loop. + """ + if not self.entities: + return + + tasks = [self._async_remove_entity(entity_id) + for entity_id in self.entities] + + yield from asyncio.wait(tasks, loop=self.hass.loop) + + if self._async_unsub_polling is not None: + self._async_unsub_polling() + self._async_unsub_polling = None + + @asyncio.coroutine + def async_remove_entity(self, entity_id): + """Remove entity id from platform.""" + yield from self._async_remove_entity(entity_id) + + # Clean up polling job if no longer needed + if (self._async_unsub_polling is not None and + not any(entity.should_poll for entity + in self.entities.values())): + self._async_unsub_polling() + self._async_unsub_polling = None + + @asyncio.coroutine + def _async_remove_entity(self, entity_id): + """Remove entity id from platform.""" + entity = self.entities.pop(entity_id) + + if hasattr(entity, 'async_will_remove_from_hass'): + yield from entity.async_will_remove_from_hass() + + self.hass.states.async_remove(entity_id) + + @asyncio.coroutine + def _update_entity_states(self, now): + """Update the states of all the polling entities. + + To protect from flooding the executor, we will update async entities + in parallel and other entities sequential. + + This method must be run in the event loop. + """ + if self._process_updates.locked(): + self.logger.warning( + "Updating %s %s took longer than the scheduled update " + "interval %s", self.platform_name, self.domain, + self.scan_interval) + return + + with (yield from self._process_updates): + tasks = [] + for entity in self.entities.values(): + if not entity.should_poll: + continue + tasks.append(entity.async_update_ha_state(True)) + + if tasks: + yield from asyncio.wait(tasks, loop=self.hass.loop) diff --git a/tests/common.py b/tests/common.py index ed4439c1c49..22af8ecb8a3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -14,7 +14,9 @@ from aiohttp import web from homeassistant import core as ha, loader from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config -from homeassistant.helpers import intent, dispatcher, entity, restore_state +from homeassistant.helpers import ( + intent, dispatcher, entity, restore_state, entity_registry, + entity_platform) from homeassistant.util.unit_system import METRIC_SYSTEM import homeassistant.util.dt as date_util import homeassistant.util.yaml as yaml @@ -22,7 +24,6 @@ from homeassistant.const import ( STATE_ON, STATE_OFF, DEVICE_DEFAULT_NAME, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE, ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE) -from homeassistant.helpers import entity_component, entity_registry from homeassistant.components import mqtt, recorder from homeassistant.components.http.auth import auth_middleware from homeassistant.components.http.const import ( @@ -320,7 +321,7 @@ def mock_registry(hass): """Mock the Entity Registry.""" registry = entity_registry.EntityRegistry(hass) registry.entities = {} - hass.data[entity_component.DATA_REGISTRY] = registry + hass.data[entity_platform.DATA_REGISTRY] = registry return registry @@ -585,3 +586,40 @@ class MockDependency: func(*args, **kwargs) return run_mocked + + +class MockEntity(entity.Entity): + """Mock Entity class.""" + + def __init__(self, **values): + """Initialize an entity.""" + self._values = values + + if 'entity_id' in values: + self.entity_id = values['entity_id'] + + @property + def name(self): + """Return the name of the entity.""" + return self._handle('name') + + @property + def should_poll(self): + """Return the ste of the polling.""" + return self._handle('should_poll') + + @property + def unique_id(self): + """Return the unique ID of the entity.""" + return self._handle('unique_id') + + @property + def available(self): + """Return True if entity is available.""" + return self._handle('available') + + def _handle(self, attr): + """Helper for the attributes.""" + if attr in self._values: + return self._values[attr] + return getattr(super(), attr) diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 349766d025e..ef92da3172b 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -4,67 +4,27 @@ import asyncio from collections import OrderedDict import logging import unittest -from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, Mock from datetime import timedelta import homeassistant.core as ha import homeassistant.loader as loader from homeassistant.exceptions import PlatformNotReady from homeassistant.components import group -from homeassistant.helpers.entity import Entity, generate_entity_id -from homeassistant.helpers.entity_component import ( - EntityComponent, DEFAULT_SCAN_INTERVAL, SLOW_SETUP_WARNING) -from homeassistant.helpers import entity_component +from homeassistant.helpers.entity_component import EntityComponent from homeassistant.setup import setup_component from homeassistant.helpers import discovery import homeassistant.util.dt as dt_util from tests.common import ( - get_test_home_assistant, MockPlatform, MockModule, fire_time_changed, - mock_coro, async_fire_time_changed, mock_registry) + get_test_home_assistant, MockPlatform, MockModule, mock_coro, + async_fire_time_changed, MockEntity) _LOGGER = logging.getLogger(__name__) DOMAIN = "test_domain" -class EntityTest(Entity): - """Test for the Entity component.""" - - def __init__(self, **values): - """Initialize an entity.""" - self._values = values - - if 'entity_id' in values: - self.entity_id = values['entity_id'] - - @property - def name(self): - """Return the name of the entity.""" - return self._handle('name') - - @property - def should_poll(self): - """Return the ste of the polling.""" - return self._handle('should_poll') - - @property - def unique_id(self): - """Return the unique ID of the entity.""" - return self._handle('unique_id') - - @property - def available(self): - """Return True if entity is available.""" - return self._handle('available') - - def _handle(self, attr): - """Helper for the attributes.""" - if attr in self._values: - return self._values[attr] - return getattr(super(), attr) - - class TestHelpersEntityComponent(unittest.TestCase): """Test homeassistant.helpers.entity_component module.""" @@ -85,7 +45,7 @@ class TestHelpersEntityComponent(unittest.TestCase): # No group after setup assert len(self.hass.states.entity_ids()) == 0 - component.add_entities([EntityTest()]) + component.add_entities([MockEntity()]) self.hass.block_till_done() # group exists @@ -98,7 +58,7 @@ class TestHelpersEntityComponent(unittest.TestCase): ('test_domain.unnamed_device',) # group extended - component.add_entities([EntityTest(name='goodbye')]) + component.add_entities([MockEntity(name='goodbye')]) self.hass.block_till_done() assert len(self.hass.states.entity_ids()) == 3 @@ -108,151 +68,6 @@ class TestHelpersEntityComponent(unittest.TestCase): assert group.attributes.get('entity_id') == \ ('test_domain.goodbye', 'test_domain.unnamed_device') - def test_polling_only_updates_entities_it_should_poll(self): - """Test the polling of only updated entities.""" - component = EntityComponent( - _LOGGER, DOMAIN, self.hass, timedelta(seconds=20)) - - no_poll_ent = EntityTest(should_poll=False) - no_poll_ent.async_update = Mock() - poll_ent = EntityTest(should_poll=True) - poll_ent.async_update = Mock() - - component.add_entities([no_poll_ent, poll_ent]) - - no_poll_ent.async_update.reset_mock() - poll_ent.async_update.reset_mock() - - fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20)) - self.hass.block_till_done() - - assert not no_poll_ent.async_update.called - assert poll_ent.async_update.called - - def test_polling_updates_entities_with_exception(self): - """Test the updated entities that not break with an exception.""" - component = EntityComponent( - _LOGGER, DOMAIN, self.hass, timedelta(seconds=20)) - - update_ok = [] - update_err = [] - - def update_mock(): - """Mock normal update.""" - update_ok.append(None) - - def update_mock_err(): - """Mock error update.""" - update_err.append(None) - raise AssertionError("Fake error update") - - ent1 = EntityTest(should_poll=True) - ent1.update = update_mock_err - ent2 = EntityTest(should_poll=True) - ent2.update = update_mock - ent3 = EntityTest(should_poll=True) - ent3.update = update_mock - ent4 = EntityTest(should_poll=True) - ent4.update = update_mock - - component.add_entities([ent1, ent2, ent3, ent4]) - - update_ok.clear() - update_err.clear() - - fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20)) - self.hass.block_till_done() - - assert len(update_ok) == 3 - assert len(update_err) == 1 - - def test_update_state_adds_entities(self): - """Test if updating poll entities cause an entity to be added works.""" - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - - ent1 = EntityTest() - ent2 = EntityTest(should_poll=True) - - component.add_entities([ent2]) - assert 1 == len(self.hass.states.entity_ids()) - ent2.update = lambda *_: component.add_entities([ent1]) - - fire_time_changed( - self.hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL - ) - self.hass.block_till_done() - - assert 2 == len(self.hass.states.entity_ids()) - - def test_update_state_adds_entities_with_update_before_add_true(self): - """Test if call update before add to state machine.""" - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - - ent = EntityTest() - ent.update = Mock(spec_set=True) - - component.add_entities([ent], True) - self.hass.block_till_done() - - assert 1 == len(self.hass.states.entity_ids()) - assert ent.update.called - - def test_update_state_adds_entities_with_update_before_add_false(self): - """Test if not call update before add to state machine.""" - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - - ent = EntityTest() - ent.update = Mock(spec_set=True) - - component.add_entities([ent], False) - self.hass.block_till_done() - - assert 1 == len(self.hass.states.entity_ids()) - assert not ent.update.called - - def test_extract_from_service_returns_all_if_no_entity_id(self): - """Test the extraction of everything from service.""" - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - component.add_entities([ - EntityTest(name='test_1'), - EntityTest(name='test_2'), - ]) - - call = ha.ServiceCall('test', 'service') - - assert ['test_domain.test_1', 'test_domain.test_2'] == \ - sorted(ent.entity_id for ent in - component.extract_from_service(call)) - - def test_extract_from_service_filter_out_non_existing_entities(self): - """Test the extraction of non existing entities from service.""" - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - component.add_entities([ - EntityTest(name='test_1'), - EntityTest(name='test_2'), - ]) - - call = ha.ServiceCall('test', 'service', { - 'entity_id': ['test_domain.test_2', 'test_domain.non_exist'] - }) - - assert ['test_domain.test_2'] == \ - [ent.entity_id for ent in component.extract_from_service(call)] - - def test_extract_from_service_no_group_expand(self): - """Test not expanding a group.""" - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - test_group = group.Group.create_group( - self.hass, 'test_group', ['light.Ceiling', 'light.Kitchen']) - component.add_entities([test_group]) - - call = ha.ServiceCall('test', 'service', { - 'entity_id': ['group.test_group'] - }) - - extracted = component.extract_from_service(call, expand_group=False) - self.assertEqual([test_group], extracted) - def test_setup_loads_platforms(self): """Test the loading of the platforms.""" component_setup = Mock(return_value=True) @@ -320,13 +135,13 @@ class TestHelpersEntityComponent(unittest.TestCase): assert ('platform_test', {}, {'msg': 'discovery_info'}) == \ mock_setup.call_args[0] - @patch('homeassistant.helpers.entity_component.' + @patch('homeassistant.helpers.entity_platform.' 'async_track_time_interval') def test_set_scan_interval_via_config(self, mock_track): """Test the setting of the scan interval via configuration.""" def platform_setup(hass, config, add_devices, discovery_info=None): """Test the platform setup.""" - add_devices([EntityTest(should_poll=True)]) + add_devices([MockEntity(should_poll=True)]) loader.set_component('test_domain.platform', MockPlatform(platform_setup)) @@ -344,38 +159,13 @@ class TestHelpersEntityComponent(unittest.TestCase): assert mock_track.called assert timedelta(seconds=30) == mock_track.call_args[0][2] - @patch('homeassistant.helpers.entity_component.' - 'async_track_time_interval') - def test_set_scan_interval_via_platform(self, mock_track): - """Test the setting of the scan interval via platform.""" - def platform_setup(hass, config, add_devices, discovery_info=None): - """Test the platform setup.""" - add_devices([EntityTest(should_poll=True)]) - - platform = MockPlatform(platform_setup) - platform.SCAN_INTERVAL = timedelta(seconds=30) - - loader.set_component('test_domain.platform', platform) - - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - - component.setup({ - DOMAIN: { - 'platform': 'platform', - } - }) - - self.hass.block_till_done() - assert mock_track.called - assert timedelta(seconds=30) == mock_track.call_args[0][2] - def test_set_entity_namespace_via_config(self): """Test setting an entity namespace.""" def platform_setup(hass, config, add_devices, discovery_info=None): """Test the platform setup.""" add_devices([ - EntityTest(name='beer'), - EntityTest(name=None), + MockEntity(name='beer'), + MockEntity(name=None), ]) platform = MockPlatform(platform_setup) @@ -396,83 +186,16 @@ class TestHelpersEntityComponent(unittest.TestCase): assert sorted(self.hass.states.entity_ids()) == \ ['test_domain.yummy_beer', 'test_domain.yummy_unnamed_device'] - def test_adding_entities_with_generator_and_thread_callback(self): - """Test generator in add_entities that calls thread method. - - We should make sure we resolve the generator to a list before passing - it into an async context. - """ - component = EntityComponent(_LOGGER, DOMAIN, self.hass) - - def create_entity(number): - """Create entity helper.""" - entity = EntityTest() - entity.entity_id = generate_entity_id(component.entity_id_format, - 'Number', hass=self.hass) - return entity - - component.add_entities(create_entity(i) for i in range(2)) - - -@asyncio.coroutine -def test_platform_warn_slow_setup(hass): - """Warn we log when platform setup takes a long time.""" - platform = MockPlatform() - - loader.set_component('test_domain.platform', platform) - - component = EntityComponent(_LOGGER, DOMAIN, hass) - - with patch.object(hass.loop, 'call_later', MagicMock()) \ - as mock_call: - yield from component.async_setup({ - DOMAIN: { - 'platform': 'platform', - } - }) - assert mock_call.called - - timeout, logger_method = mock_call.mock_calls[0][1][:2] - - assert timeout == SLOW_SETUP_WARNING - assert logger_method == _LOGGER.warning - - assert mock_call().cancel.called - - -@asyncio.coroutine -def test_platform_error_slow_setup(hass, caplog): - """Don't block startup more than SLOW_SETUP_MAX_WAIT.""" - with patch.object(entity_component, 'SLOW_SETUP_MAX_WAIT', 0): - called = [] - - @asyncio.coroutine - def setup_platform(*args): - called.append(1) - yield from asyncio.sleep(1, loop=hass.loop) - - platform = MockPlatform(async_setup_platform=setup_platform) - component = EntityComponent(_LOGGER, DOMAIN, hass) - loader.set_component('test_domain.test_platform', platform) - yield from component.async_setup({ - DOMAIN: { - 'platform': 'test_platform', - } - }) - assert len(called) == 1 - assert 'test_domain.test_platform' not in hass.config.components - assert 'test_platform is taking longer than 0 seconds' in caplog.text - @asyncio.coroutine def test_extract_from_service_available_device(hass): """Test the extraction of entity from service and device is available.""" component = EntityComponent(_LOGGER, DOMAIN, hass) yield from component.async_add_entities([ - EntityTest(name='test_1'), - EntityTest(name='test_2', available=False), - EntityTest(name='test_3'), - EntityTest(name='test_4', available=False), + MockEntity(name='test_1'), + MockEntity(name='test_2', available=False), + MockEntity(name='test_3'), + MockEntity(name='test_4', available=False), ]) call_1 = ha.ServiceCall('test', 'service') @@ -490,26 +213,6 @@ def test_extract_from_service_available_device(hass): component.async_extract_from_service(call_2)) -@asyncio.coroutine -def test_updated_state_used_for_entity_id(hass): - """Test that first update results used for entity ID generation.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - - class EntityTestNameFetcher(EntityTest): - """Mock entity that fetches a friendly name.""" - - @asyncio.coroutine - def async_update(self): - """Mock update that assigns a name.""" - self._values['name'] = "Living Room" - - yield from component.async_add_entities([EntityTestNameFetcher()], True) - - entity_ids = hass.states.async_entity_ids() - assert 1 == len(entity_ids) - assert entity_ids[0] == "test_domain.living_room" - - @asyncio.coroutine def test_platform_not_ready(hass): """Test that we retry when platform not ready.""" @@ -555,188 +258,50 @@ def test_platform_not_ready(hass): @asyncio.coroutine -def test_parallel_updates_async_platform(hass): - """Warn we log when platform setup takes a long time.""" - platform = MockPlatform() - - @asyncio.coroutine - def mock_update(*args, **kwargs): - pass - - platform.async_setup_platform = mock_update - - loader.set_component('test_domain.platform', platform) - +def test_extract_from_service_returns_all_if_no_entity_id(hass): + """Test the extraction of everything from service.""" component = EntityComponent(_LOGGER, DOMAIN, hass) - component._platforms = {} + yield from component.async_add_entities([ + MockEntity(name='test_1'), + MockEntity(name='test_2'), + ]) - yield from component.async_setup({ - DOMAIN: { - 'platform': 'platform', - } + call = ha.ServiceCall('test', 'service') + + assert ['test_domain.test_1', 'test_domain.test_2'] == \ + sorted(ent.entity_id for ent in + component.async_extract_from_service(call)) + + +@asyncio.coroutine +def test_extract_from_service_filter_out_non_existing_entities(hass): + """Test the extraction of non existing entities from service.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + yield from component.async_add_entities([ + MockEntity(name='test_1'), + MockEntity(name='test_2'), + ]) + + call = ha.ServiceCall('test', 'service', { + 'entity_id': ['test_domain.test_2', 'test_domain.non_exist'] }) - handle = list(component._platforms.values())[-1] - - assert handle.parallel_updates is None + assert ['test_domain.test_2'] == \ + [ent.entity_id for ent + in component.async_extract_from_service(call)] @asyncio.coroutine -def test_parallel_updates_async_platform_with_constant(hass): - """Warn we log when platform setup takes a long time.""" - platform = MockPlatform() - - @asyncio.coroutine - def mock_update(*args, **kwargs): - pass - - platform.async_setup_platform = mock_update - platform.PARALLEL_UPDATES = 1 - - loader.set_component('test_domain.platform', platform) - +def test_extract_from_service_no_group_expand(hass): + """Test not expanding a group.""" component = EntityComponent(_LOGGER, DOMAIN, hass) - component._platforms = {} + test_group = yield from group.Group.async_create_group( + hass, 'test_group', ['light.Ceiling', 'light.Kitchen']) + yield from component.async_add_entities([test_group]) - yield from component.async_setup({ - DOMAIN: { - 'platform': 'platform', - } + call = ha.ServiceCall('test', 'service', { + 'entity_id': ['group.test_group'] }) - handle = list(component._platforms.values())[-1] - - assert handle.parallel_updates is not None - - -@asyncio.coroutine -def test_parallel_updates_sync_platform(hass): - """Warn we log when platform setup takes a long time.""" - platform = MockPlatform() - - loader.set_component('test_domain.platform', platform) - - component = EntityComponent(_LOGGER, DOMAIN, hass) - component._platforms = {} - - yield from component.async_setup({ - DOMAIN: { - 'platform': 'platform', - } - }) - - handle = list(component._platforms.values())[-1] - - assert handle.parallel_updates is not None - - -@asyncio.coroutine -def test_raise_error_on_update(hass): - """Test the add entity if they raise an error on update.""" - updates = [] - component = EntityComponent(_LOGGER, DOMAIN, hass) - entity1 = EntityTest(name='test_1') - entity2 = EntityTest(name='test_2') - - def _raise(): - """Helper to raise an exception.""" - raise AssertionError - - entity1.update = _raise - entity2.update = lambda: updates.append(1) - - yield from component.async_add_entities([entity1, entity2], True) - - assert len(updates) == 1 - assert 1 in updates - - -@asyncio.coroutine -def test_async_remove_with_platform(hass): - """Remove an entity from a platform.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - entity1 = EntityTest(name='test_1') - yield from component.async_add_entities([entity1]) - assert len(hass.states.async_entity_ids()) == 1 - yield from entity1.async_remove() - assert len(hass.states.async_entity_ids()) == 0 - - -@asyncio.coroutine -def test_not_adding_duplicate_entities_with_unique_id(hass): - """Test for not adding duplicate entities.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - - yield from component.async_add_entities([ - EntityTest(name='test1', unique_id='not_very_unique')]) - - assert len(hass.states.async_entity_ids()) == 1 - - yield from component.async_add_entities([ - EntityTest(name='test2', unique_id='not_very_unique')]) - - assert len(hass.states.async_entity_ids()) == 1 - - -@asyncio.coroutine -def test_using_prescribed_entity_id(hass): - """Test for using predefined entity ID.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - yield from component.async_add_entities([ - EntityTest(name='bla', entity_id='hello.world')]) - assert 'hello.world' in hass.states.async_entity_ids() - - -@asyncio.coroutine -def test_using_prescribed_entity_id_with_unique_id(hass): - """Test for ammending predefined entity ID because currently exists.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - - yield from component.async_add_entities([ - EntityTest(entity_id='test_domain.world')]) - yield from component.async_add_entities([ - EntityTest(entity_id='test_domain.world', unique_id='bla')]) - - assert 'test_domain.world_2' in hass.states.async_entity_ids() - - -@asyncio.coroutine -def test_using_prescribed_entity_id_which_is_registered(hass): - """Test not allowing predefined entity ID that already registered.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - registry = mock_registry(hass) - # Register test_domain.world - registry.async_get_or_create( - DOMAIN, 'test', '1234', suggested_object_id='world') - - # This entity_id will be rewritten - yield from component.async_add_entities([ - EntityTest(entity_id='test_domain.world')]) - - assert 'test_domain.world_2' in hass.states.async_entity_ids() - - -@asyncio.coroutine -def test_name_which_conflict_with_registered(hass): - """Test not generating conflicting entity ID based on name.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - registry = mock_registry(hass) - - # Register test_domain.world - registry.async_get_or_create( - DOMAIN, 'test', '1234', suggested_object_id='world') - - yield from component.async_add_entities([ - EntityTest(name='world')]) - - assert 'test_domain.world_2' in hass.states.async_entity_ids() - - -@asyncio.coroutine -def test_entity_with_name_and_entity_id_getting_registered(hass): - """Ensure that entity ID is used for registration.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - yield from component.async_add_entities([ - EntityTest(unique_id='1234', name='bla', - entity_id='test_domain.world')]) - assert 'test_domain.world' in hass.states.async_entity_ids() + extracted = component.async_extract_from_service(call, expand_group=False) + assert extracted == [test_group] diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py new file mode 100644 index 00000000000..4c27cc45a00 --- /dev/null +++ b/tests/helpers/test_entity_platform.py @@ -0,0 +1,435 @@ +"""Tests for the EntityPlatform helper.""" +import asyncio +import logging +import unittest +from unittest.mock import patch, Mock, MagicMock +from datetime import timedelta + +import homeassistant.loader as loader +from homeassistant.helpers.entity import generate_entity_id +from homeassistant.helpers.entity_component import ( + EntityComponent, DEFAULT_SCAN_INTERVAL) +from homeassistant.helpers import entity_platform + +import homeassistant.util.dt as dt_util + +from tests.common import ( + get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry, + MockEntity) + +_LOGGER = logging.getLogger(__name__) +DOMAIN = "test_domain" + + +class TestHelpersEntityPlatform(unittest.TestCase): + """Test homeassistant.helpers.entity_component module.""" + + def setUp(self): # pylint: disable=invalid-name + """Initialize a test Home Assistant instance.""" + self.hass = get_test_home_assistant() + + def tearDown(self): # pylint: disable=invalid-name + """Clean up the test Home Assistant instance.""" + self.hass.stop() + + def test_polling_only_updates_entities_it_should_poll(self): + """Test the polling of only updated entities.""" + component = EntityComponent( + _LOGGER, DOMAIN, self.hass, timedelta(seconds=20)) + + no_poll_ent = MockEntity(should_poll=False) + no_poll_ent.async_update = Mock() + poll_ent = MockEntity(should_poll=True) + poll_ent.async_update = Mock() + + component.add_entities([no_poll_ent, poll_ent]) + + no_poll_ent.async_update.reset_mock() + poll_ent.async_update.reset_mock() + + fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20)) + self.hass.block_till_done() + + assert not no_poll_ent.async_update.called + assert poll_ent.async_update.called + + def test_polling_updates_entities_with_exception(self): + """Test the updated entities that not break with an exception.""" + component = EntityComponent( + _LOGGER, DOMAIN, self.hass, timedelta(seconds=20)) + + update_ok = [] + update_err = [] + + def update_mock(): + """Mock normal update.""" + update_ok.append(None) + + def update_mock_err(): + """Mock error update.""" + update_err.append(None) + raise AssertionError("Fake error update") + + ent1 = MockEntity(should_poll=True) + ent1.update = update_mock_err + ent2 = MockEntity(should_poll=True) + ent2.update = update_mock + ent3 = MockEntity(should_poll=True) + ent3.update = update_mock + ent4 = MockEntity(should_poll=True) + ent4.update = update_mock + + component.add_entities([ent1, ent2, ent3, ent4]) + + update_ok.clear() + update_err.clear() + + fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20)) + self.hass.block_till_done() + + assert len(update_ok) == 3 + assert len(update_err) == 1 + + def test_update_state_adds_entities(self): + """Test if updating poll entities cause an entity to be added works.""" + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + ent1 = MockEntity() + ent2 = MockEntity(should_poll=True) + + component.add_entities([ent2]) + assert 1 == len(self.hass.states.entity_ids()) + ent2.update = lambda *_: component.add_entities([ent1]) + + fire_time_changed( + self.hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL + ) + self.hass.block_till_done() + + assert 2 == len(self.hass.states.entity_ids()) + + def test_update_state_adds_entities_with_update_before_add_true(self): + """Test if call update before add to state machine.""" + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + ent = MockEntity() + ent.update = Mock(spec_set=True) + + component.add_entities([ent], True) + self.hass.block_till_done() + + assert 1 == len(self.hass.states.entity_ids()) + assert ent.update.called + + def test_update_state_adds_entities_with_update_before_add_false(self): + """Test if not call update before add to state machine.""" + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + ent = MockEntity() + ent.update = Mock(spec_set=True) + + component.add_entities([ent], False) + self.hass.block_till_done() + + assert 1 == len(self.hass.states.entity_ids()) + assert not ent.update.called + + @patch('homeassistant.helpers.entity_platform.' + 'async_track_time_interval') + def test_set_scan_interval_via_platform(self, mock_track): + """Test the setting of the scan interval via platform.""" + def platform_setup(hass, config, add_devices, discovery_info=None): + """Test the platform setup.""" + add_devices([MockEntity(should_poll=True)]) + + platform = MockPlatform(platform_setup) + platform.SCAN_INTERVAL = timedelta(seconds=30) + + loader.set_component('test_domain.platform', platform) + + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + component.setup({ + DOMAIN: { + 'platform': 'platform', + } + }) + + self.hass.block_till_done() + assert mock_track.called + assert timedelta(seconds=30) == mock_track.call_args[0][2] + + def test_adding_entities_with_generator_and_thread_callback(self): + """Test generator in add_entities that calls thread method. + + We should make sure we resolve the generator to a list before passing + it into an async context. + """ + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + def create_entity(number): + """Create entity helper.""" + entity = MockEntity() + entity.entity_id = generate_entity_id(DOMAIN + '.{}', + 'Number', hass=self.hass) + return entity + + component.add_entities(create_entity(i) for i in range(2)) + + +@asyncio.coroutine +def test_platform_warn_slow_setup(hass): + """Warn we log when platform setup takes a long time.""" + platform = MockPlatform() + + loader.set_component('test_domain.platform', platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + + with patch.object(hass.loop, 'call_later', MagicMock()) \ + as mock_call: + yield from component.async_setup({ + DOMAIN: { + 'platform': 'platform', + } + }) + assert mock_call.called + + timeout, logger_method = mock_call.mock_calls[0][1][:2] + + assert timeout == entity_platform.SLOW_SETUP_WARNING + assert logger_method == _LOGGER.warning + + assert mock_call().cancel.called + + +@asyncio.coroutine +def test_platform_error_slow_setup(hass, caplog): + """Don't block startup more than SLOW_SETUP_MAX_WAIT.""" + with patch.object(entity_platform, 'SLOW_SETUP_MAX_WAIT', 0): + called = [] + + @asyncio.coroutine + def setup_platform(*args): + called.append(1) + yield from asyncio.sleep(1, loop=hass.loop) + + platform = MockPlatform(async_setup_platform=setup_platform) + component = EntityComponent(_LOGGER, DOMAIN, hass) + loader.set_component('test_domain.test_platform', platform) + yield from component.async_setup({ + DOMAIN: { + 'platform': 'test_platform', + } + }) + assert len(called) == 1 + assert 'test_domain.test_platform' not in hass.config.components + assert 'test_platform is taking longer than 0 seconds' in caplog.text + + +@asyncio.coroutine +def test_updated_state_used_for_entity_id(hass): + """Test that first update results used for entity ID generation.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + + class MockEntityNameFetcher(MockEntity): + """Mock entity that fetches a friendly name.""" + + @asyncio.coroutine + def async_update(self): + """Mock update that assigns a name.""" + self._values['name'] = "Living Room" + + yield from component.async_add_entities([MockEntityNameFetcher()], True) + + entity_ids = hass.states.async_entity_ids() + assert 1 == len(entity_ids) + assert entity_ids[0] == "test_domain.living_room" + + +@asyncio.coroutine +def test_parallel_updates_async_platform(hass): + """Warn we log when platform setup takes a long time.""" + platform = MockPlatform() + + @asyncio.coroutine + def mock_update(*args, **kwargs): + pass + + platform.async_setup_platform = mock_update + + loader.set_component('test_domain.platform', platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + component._platforms = {} + + yield from component.async_setup({ + DOMAIN: { + 'platform': 'platform', + } + }) + + handle = list(component._platforms.values())[-1] + + assert handle.parallel_updates is None + + +@asyncio.coroutine +def test_parallel_updates_async_platform_with_constant(hass): + """Warn we log when platform setup takes a long time.""" + platform = MockPlatform() + + @asyncio.coroutine + def mock_update(*args, **kwargs): + pass + + platform.async_setup_platform = mock_update + platform.PARALLEL_UPDATES = 1 + + loader.set_component('test_domain.platform', platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + component._platforms = {} + + yield from component.async_setup({ + DOMAIN: { + 'platform': 'platform', + } + }) + + handle = list(component._platforms.values())[-1] + + assert handle.parallel_updates is not None + + +@asyncio.coroutine +def test_parallel_updates_sync_platform(hass): + """Warn we log when platform setup takes a long time.""" + platform = MockPlatform() + + loader.set_component('test_domain.platform', platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + component._platforms = {} + + yield from component.async_setup({ + DOMAIN: { + 'platform': 'platform', + } + }) + + handle = list(component._platforms.values())[-1] + + assert handle.parallel_updates is not None + + +@asyncio.coroutine +def test_raise_error_on_update(hass): + """Test the add entity if they raise an error on update.""" + updates = [] + component = EntityComponent(_LOGGER, DOMAIN, hass) + entity1 = MockEntity(name='test_1') + entity2 = MockEntity(name='test_2') + + def _raise(): + """Helper to raise an exception.""" + raise AssertionError + + entity1.update = _raise + entity2.update = lambda: updates.append(1) + + yield from component.async_add_entities([entity1, entity2], True) + + assert len(updates) == 1 + assert 1 in updates + + +@asyncio.coroutine +def test_async_remove_with_platform(hass): + """Remove an entity from a platform.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + entity1 = MockEntity(name='test_1') + yield from component.async_add_entities([entity1]) + assert len(hass.states.async_entity_ids()) == 1 + yield from entity1.async_remove() + assert len(hass.states.async_entity_ids()) == 0 + + +@asyncio.coroutine +def test_not_adding_duplicate_entities_with_unique_id(hass): + """Test for not adding duplicate entities.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + + yield from component.async_add_entities([ + MockEntity(name='test1', unique_id='not_very_unique')]) + + assert len(hass.states.async_entity_ids()) == 1 + + yield from component.async_add_entities([ + MockEntity(name='test2', unique_id='not_very_unique')]) + + assert len(hass.states.async_entity_ids()) == 1 + + +@asyncio.coroutine +def test_using_prescribed_entity_id(hass): + """Test for using predefined entity ID.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + yield from component.async_add_entities([ + MockEntity(name='bla', entity_id='hello.world')]) + assert 'hello.world' in hass.states.async_entity_ids() + + +@asyncio.coroutine +def test_using_prescribed_entity_id_with_unique_id(hass): + """Test for ammending predefined entity ID because currently exists.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + + yield from component.async_add_entities([ + MockEntity(entity_id='test_domain.world')]) + yield from component.async_add_entities([ + MockEntity(entity_id='test_domain.world', unique_id='bla')]) + + assert 'test_domain.world_2' in hass.states.async_entity_ids() + + +@asyncio.coroutine +def test_using_prescribed_entity_id_which_is_registered(hass): + """Test not allowing predefined entity ID that already registered.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + registry = mock_registry(hass) + # Register test_domain.world + registry.async_get_or_create( + DOMAIN, 'test', '1234', suggested_object_id='world') + + # This entity_id will be rewritten + yield from component.async_add_entities([ + MockEntity(entity_id='test_domain.world')]) + + assert 'test_domain.world_2' in hass.states.async_entity_ids() + + +@asyncio.coroutine +def test_name_which_conflict_with_registered(hass): + """Test not generating conflicting entity ID based on name.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + registry = mock_registry(hass) + + # Register test_domain.world + registry.async_get_or_create( + DOMAIN, 'test', '1234', suggested_object_id='world') + + yield from component.async_add_entities([ + MockEntity(name='world')]) + + assert 'test_domain.world_2' in hass.states.async_entity_ids() + + +@asyncio.coroutine +def test_entity_with_name_and_entity_id_getting_registered(hass): + """Ensure that entity ID is used for registration.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + yield from component.async_add_entities([ + MockEntity(unique_id='1234', name='bla', + entity_id='test_domain.world')]) + assert 'test_domain.world' in hass.states.async_entity_ids() diff --git a/tests/test_config.py b/tests/test_config.py index 377c650e91f..541eaf4f79e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -255,18 +255,6 @@ class TestConfig(unittest.TestCase): return self.hass.states.get('test.test') - def test_entity_customization_false(self): - """Test entity customization through configuration.""" - config = {CONF_LATITUDE: 50, - CONF_LONGITUDE: 50, - CONF_NAME: 'Test', - CONF_CUSTOMIZE: { - 'test.test': {'hidden': False}}} - - state = self._compute_state(config) - - assert 'hidden' not in state.attributes - def test_entity_customization(self): """Test entity customization through configuration.""" config = {CONF_LATITUDE: 50,