Clean up entity component (#11691)

* Clean up entity component

* Lint

* List -> Tuple

* Add Entity.async_remove back

* Unflake setting up group test
This commit is contained in:
Paulus Schoutsen 2018-01-22 22:54:41 -08:00 committed by GitHub
parent d478517c51
commit 183e0543b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 230 additions and 191 deletions

View File

@ -338,10 +338,9 @@ class AutomationEntity(ToggleEntity):
yield from self.async_update_ha_state() yield from self.async_update_ha_state()
@asyncio.coroutine @asyncio.coroutine
def async_remove(self): def async_will_remove_from_hass(self):
"""Remove automation from HASS.""" """Remove listeners when removing automation from HASS."""
yield from self.async_turn_off() yield from self.async_turn_off()
yield from super().async_remove()
@asyncio.coroutine @asyncio.coroutine
def async_enable(self): def async_enable(self):

View File

@ -238,6 +238,5 @@ class FlicButton(BinarySensorDevice):
import pyflic import pyflic
if connection_status == pyflic.ConnectionStatus.Disconnected: if connection_status == pyflic.ConnectionStatus.Disconnected:
_LOGGER.info("Button (%s) disconnected. Reason: %s", _LOGGER.warning("Button (%s) disconnected. Reason: %s",
self.address, disconnect_reason) self.address, disconnect_reason)
self.remove()

View File

@ -124,15 +124,15 @@ def async_setup(hass, config):
"""Set up the camera component.""" """Set up the camera component."""
component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL) component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(CameraImageView(component.entities)) hass.http.register_view(CameraImageView(component))
hass.http.register_view(CameraMjpegStream(component.entities)) hass.http.register_view(CameraMjpegStream(component))
yield from component.async_setup(config) yield from component.async_setup(config)
@callback @callback
def update_tokens(time): def update_tokens(time):
"""Update tokens of the entities.""" """Update tokens of the entities."""
for entity in component.entities.values(): for entity in component.entities:
entity.async_update_token() entity.async_update_token()
hass.async_add_job(entity.async_update_ha_state()) hass.async_add_job(entity.async_update_ha_state())
@ -358,14 +358,14 @@ class CameraView(HomeAssistantView):
requires_auth = False requires_auth = False
def __init__(self, entities): def __init__(self, component):
"""Initialize a basic camera view.""" """Initialize a basic camera view."""
self.entities = entities self.component = component
@asyncio.coroutine @asyncio.coroutine
def get(self, request, entity_id): def get(self, request, entity_id):
"""Start a GET request.""" """Start a GET request."""
camera = self.entities.get(entity_id) camera = self.component.get_entity(entity_id)
if camera is None: if camera is None:
status = 404 if request[KEY_AUTHENTICATED] else 401 status = 404 if request[KEY_AUTHENTICATED] else 401

View File

@ -42,8 +42,6 @@ ATTR_ORDER = 'order'
ATTR_VIEW = 'view' ATTR_VIEW = 'view'
ATTR_VISIBLE = 'visible' ATTR_VISIBLE = 'visible'
DATA_ALL_GROUPS = 'data_all_groups'
SERVICE_SET_VISIBILITY = 'set_visibility' SERVICE_SET_VISIBILITY = 'set_visibility'
SERVICE_SET = 'set' SERVICE_SET = 'set'
SERVICE_REMOVE = 'remove' SERVICE_REMOVE = 'remove'
@ -250,8 +248,10 @@ def get_entity_ids(hass, entity_id, domain_filter=None):
@asyncio.coroutine @asyncio.coroutine
def async_setup(hass, config): def async_setup(hass, config):
"""Set up all groups found definded in the configuration.""" """Set up all groups found definded in the configuration."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = hass.data.get(DOMAIN)
hass.data[DATA_ALL_GROUPS] = {}
if component is None:
component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
yield from _async_process_config(hass, config, component) yield from _async_process_config(hass, config, component)
@ -271,10 +271,11 @@ def async_setup(hass, config):
def groups_service_handler(service): def groups_service_handler(service):
"""Handle dynamic group service functions.""" """Handle dynamic group service functions."""
object_id = service.data[ATTR_OBJECT_ID] object_id = service.data[ATTR_OBJECT_ID]
service_groups = hass.data[DATA_ALL_GROUPS] entity_id = ENTITY_ID_FORMAT.format(object_id)
group = component.get_entity(entity_id)
# new group # new group
if service.service == SERVICE_SET and object_id not in service_groups: if service.service == SERVICE_SET and group is None:
entity_ids = service.data.get(ATTR_ENTITIES) or \ entity_ids = service.data.get(ATTR_ENTITIES) or \
service.data.get(ATTR_ADD_ENTITIES) or None service.data.get(ATTR_ADD_ENTITIES) or None
@ -289,12 +290,15 @@ def async_setup(hass, config):
user_defined=False, user_defined=False,
**extra_arg **extra_arg
) )
return
if group is None:
_LOGGER.warning("%s:Group '%s' doesn't exist!",
service.service, object_id)
return return
# update group # update group
if service.service == SERVICE_SET: if service.service == SERVICE_SET:
group = service_groups[object_id]
need_update = False need_update = False
if ATTR_ADD_ENTITIES in service.data: if ATTR_ADD_ENTITIES in service.data:
@ -333,12 +337,7 @@ def async_setup(hass, config):
# remove group # remove group
if service.service == SERVICE_REMOVE: if service.service == SERVICE_REMOVE:
if object_id not in service_groups: yield from component.async_remove_entity(entity_id)
_LOGGER.warning("Group '%s' doesn't exist!", object_id)
return
del_group = service_groups.pop(object_id)
yield from del_group.async_stop()
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_SET, groups_service_handler, DOMAIN, SERVICE_SET, groups_service_handler,
@ -395,7 +394,7 @@ class Group(Entity):
"""Track a group of entity ids.""" """Track a group of entity ids."""
def __init__(self, hass, name, order=None, visible=True, icon=None, def __init__(self, hass, name, order=None, visible=True, icon=None,
view=False, control=None, user_defined=True): view=False, control=None, user_defined=True, entity_ids=None):
"""Initialize a group. """Initialize a group.
This Object has factory function for creation. This Object has factory function for creation.
@ -405,7 +404,10 @@ class Group(Entity):
self._state = STATE_UNKNOWN self._state = STATE_UNKNOWN
self._icon = icon self._icon = icon
self.view = view self.view = view
self.tracking = [] if entity_ids:
self.tracking = tuple(ent_id.lower() for ent_id in entity_ids)
else:
self.tracking = tuple()
self.group_on = None self.group_on = None
self.group_off = None self.group_off = None
self.visible = visible self.visible = visible
@ -439,23 +441,21 @@ class Group(Entity):
hass, name, hass, name,
order=len(hass.states.async_entity_ids(DOMAIN)), order=len(hass.states.async_entity_ids(DOMAIN)),
visible=visible, icon=icon, view=view, control=control, visible=visible, icon=icon, view=view, control=control,
user_defined=user_defined user_defined=user_defined, entity_ids=entity_ids
) )
group.entity_id = async_generate_entity_id( group.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id or name, hass=hass) ENTITY_ID_FORMAT, object_id or name, hass=hass)
# run other async stuff
if entity_ids is not None:
yield from group.async_update_tracked_entity_ids(entity_ids)
else:
yield from group.async_update_ha_state(True)
# If called before the platform async_setup is called (test cases) # If called before the platform async_setup is called (test cases)
if DATA_ALL_GROUPS not in hass.data: component = hass.data.get(DOMAIN)
hass.data[DATA_ALL_GROUPS] = {}
if component is None:
component = hass.data[DOMAIN] = \
EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([group], True)
hass.data[DATA_ALL_GROUPS][object_id] = group
return group return group
@property @property
@ -534,10 +534,6 @@ class Group(Entity):
yield from self.async_update_ha_state(True) yield from self.async_update_ha_state(True)
self.async_start() self.async_start()
def start(self):
"""Start tracking members."""
self.hass.add_job(self.async_start)
@callback @callback
def async_start(self): def async_start(self):
"""Start tracking members. """Start tracking members.
@ -549,17 +545,15 @@ class Group(Entity):
self.hass, self.tracking, self._async_state_changed_listener self.hass, self.tracking, self._async_state_changed_listener
) )
def stop(self):
"""Unregister the group from Home Assistant."""
run_coroutine_threadsafe(self.async_stop(), self.hass.loop).result()
@asyncio.coroutine @asyncio.coroutine
def async_stop(self): def async_stop(self):
"""Unregister the group from Home Assistant. """Unregister the group from Home Assistant.
This method must be run in the event loop. This method must be run in the event loop.
""" """
yield from self.async_remove() if self._async_unsub_state_changed:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
@asyncio.coroutine @asyncio.coroutine
def async_update(self): def async_update(self):
@ -567,17 +561,19 @@ class Group(Entity):
self._state = STATE_UNKNOWN self._state = STATE_UNKNOWN
self._async_update_group_state() self._async_update_group_state()
def async_remove(self): @asyncio.coroutine
"""Remove group from HASS. def async_added_to_hass(self):
"""Callback when added to HASS."""
if self.tracking:
self.async_start()
This method must be run in the event loop and returns a coroutine. @asyncio.coroutine
""" def async_will_remove_from_hass(self):
"""Callback when removed from HASS."""
if self._async_unsub_state_changed: if self._async_unsub_state_changed:
self._async_unsub_state_changed() self._async_unsub_state_changed()
self._async_unsub_state_changed = None self._async_unsub_state_changed = None
return super().async_remove()
@asyncio.coroutine @asyncio.coroutine
def _async_state_changed_listener(self, entity_id, old_state, new_state): def _async_state_changed_listener(self, entity_id, old_state, new_state):
"""Respond to a member state changing. """Respond to a member state changing.

View File

@ -82,7 +82,7 @@ def async_setup(hass, config):
mailbox_entity = MailboxEntity(hass, mailbox) mailbox_entity = MailboxEntity(hass, mailbox)
component = EntityComponent( component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL) logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
yield from component.async_add_entity(mailbox_entity) yield from component.async_add_entities([mailbox_entity])
setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
in config_per_platform(config, DOMAIN)] in config_per_platform(config, DOMAIN)]

View File

@ -366,7 +366,7 @@ def async_setup(hass, config):
component = EntityComponent( component = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL) logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL)
hass.http.register_view(MediaPlayerImageView(component.entities)) hass.http.register_view(MediaPlayerImageView(component))
yield from component.async_setup(config) yield from component.async_setup(config)
@ -929,14 +929,14 @@ class MediaPlayerImageView(HomeAssistantView):
url = '/api/media_player_proxy/{entity_id}' url = '/api/media_player_proxy/{entity_id}'
name = 'api:media_player:image' name = 'api:media_player:image'
def __init__(self, entities): def __init__(self, component):
"""Initialize a media player view.""" """Initialize a media player view."""
self.entities = entities self.component = component
@asyncio.coroutine @asyncio.coroutine
def get(self, request, entity_id): def get(self, request, entity_id):
"""Start a get request.""" """Start a get request."""
player = self.entities.get(entity_id) player = self.component.get_entity(entity_id)
if player is None: if player is None:
status = 404 if request[KEY_AUTHENTICATED] else 401 status = 404 if request[KEY_AUTHENTICATED] else 401
return web.Response(status=status) return web.Response(status=status)

View File

@ -161,7 +161,7 @@ def async_setup(hass, config):
face.store.pop(g_id) face.store.pop(g_id)
entity = entities.pop(g_id) entity = entities.pop(g_id)
yield from entity.async_remove() hass.states.async_remove(entity.entity_id)
except HomeAssistantError as err: except HomeAssistantError as err:
_LOGGER.error("Can't delete group '%s' with error: %s", g_id, err) _LOGGER.error("Can't delete group '%s' with error: %s", g_id, err)

View File

@ -86,7 +86,7 @@ def _create_instance(hass, account_name, api_key, shared_secret,
token, stored_rtm_config, component): token, stored_rtm_config, component):
entity = RememberTheMilk(account_name, api_key, shared_secret, entity = RememberTheMilk(account_name, api_key, shared_secret,
token, stored_rtm_config) token, stored_rtm_config)
component.add_entity(entity) component.add_entities([entity])
hass.services.register( hass.services.register(
DOMAIN, '{}_create_task'.format(account_name), entity.create_task, DOMAIN, '{}_create_task'.format(account_name), entity.create_task,
schema=SERVICE_SCHEMA_CREATE_TASK) schema=SERVICE_SCHEMA_CREATE_TASK)

View File

@ -156,7 +156,7 @@ def _async_process_config(hass, config, component):
def service_handler(service): def service_handler(service):
"""Execute a service call to script.<script name>.""" """Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service) entity_id = ENTITY_ID_FORMAT.format(service.service)
script = component.entities.get(entity_id) script = component.get_entity(entity_id)
if script.is_on: if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id) _LOGGER.warning("Script %s already running.", entity_id)
return return
@ -219,15 +219,11 @@ class ScriptEntity(ToggleEntity):
"""Turn script off.""" """Turn script off."""
self.script.async_stop() self.script.async_stop()
def async_remove(self): @asyncio.coroutine
"""Remove script from HASS. def async_will_remove_from_hass(self):
"""Stop script and remove service when it will be removed from HASS."""
This method must be run in the event loop and returns a coroutine.
"""
if self.script.is_running: if self.script.is_running:
self.script.async_stop() self.script.async_stop()
# remove service # remove service
self.hass.services.async_remove(DOMAIN, self.object_id) self.hass.services.async_remove(DOMAIN, self.object_id)
return super().async_remove()

View File

@ -15,8 +15,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.config import DATA_CUSTOMIZE from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.async import ( from homeassistant.util.async import run_callback_threadsafe
run_coroutine_threadsafe, run_callback_threadsafe)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10 SLOW_UPDATE_WARNING = 10
@ -66,9 +65,12 @@ class Entity(object):
# this class. These may be used to customize the behavior of the entity. # this class. These may be used to customize the behavior of the entity.
entity_id = None # type: str entity_id = None # type: str
# Owning hass instance. Will be set by EntityComponent # Owning hass instance. Will be set by EntityPlatform
hass = None # type: Optional[HomeAssistant] hass = None # type: Optional[HomeAssistant]
# Owning platform instance. Will be set by EntityPlatform
platform = None
# If we reported if this entity was slow # If we reported if this entity was slow
_slow_reported = False _slow_reported = False
@ -311,19 +313,13 @@ class Entity(object):
if self.parallel_updates: if self.parallel_updates:
self.parallel_updates.release() self.parallel_updates.release()
def remove(self) -> None:
"""Remove entity from HASS."""
run_coroutine_threadsafe(
self.async_remove(), self.hass.loop
).result()
@asyncio.coroutine @asyncio.coroutine
def async_remove(self) -> None: def async_remove(self):
"""Remove entity from async HASS. """Remove entity from Home Assistant."""
if self.platform is not None:
This method must be run in the event loop. yield from self.platform.async_remove_entity(self.entity_id)
""" else:
self.hass.states.async_remove(self.entity_id) self.hass.states.async_remove(self.entity_id)
def _attr_setter(self, name, typ, attr, attrs): def _attr_setter(self, name, typ, attr, attrs):
"""Populate attributes based on properties.""" """Populate attributes based on properties."""

View File

@ -1,6 +1,7 @@
"""Helpers for components that manage entities.""" """Helpers for components that manage entities."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from itertools import chain
from homeassistant import config as conf_util from homeassistant import config as conf_util
from homeassistant.setup import async_prepare_setup_platform from homeassistant.setup import async_prepare_setup_platform
@ -9,7 +10,6 @@ from homeassistant.const import (
DEVICE_DEFAULT_NAME) DEVICE_DEFAULT_NAME)
from homeassistant.core import callback, valid_entity_id from homeassistant.core import callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.loader import get_component
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.entity import async_generate_entity_id from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
@ -27,7 +27,15 @@ PLATFORM_NOT_READY_RETRIES = 10
class EntityComponent(object): class EntityComponent(object):
"""Helper class that will help a component manage its entities.""" """The EntityComponent manages platforms that manages entities.
This class has the following responsibilities:
- Process the configuration and set up a platform based component.
- Manage the platforms and their entities.
- Help extract the entities from a service call.
- Maintain a group that tracks all platform entities.
- Listen for discovery events for platforms related to the domain.
"""
def __init__(self, logger, domain, hass, def __init__(self, logger, domain, hass,
scan_interval=DEFAULT_SCAN_INTERVAL, group_name=None): scan_interval=DEFAULT_SCAN_INTERVAL, group_name=None):
@ -40,7 +48,6 @@ class EntityComponent(object):
self.scan_interval = scan_interval self.scan_interval = scan_interval
self.group_name = group_name self.group_name = group_name
self.entities = {}
self.config = None self.config = None
self._platforms = { self._platforms = {
@ -49,6 +56,20 @@ class EntityComponent(object):
self.async_add_entities = self._platforms['core'].async_add_entities self.async_add_entities = self._platforms['core'].async_add_entities
self.add_entities = self._platforms['core'].add_entities self.add_entities = self._platforms['core'].add_entities
@property
def entities(self):
"""Return an iterable that returns all entities."""
return chain.from_iterable(platform.entities.values() for platform
in self._platforms.values())
def get_entity(self, entity_id):
"""Helper method to get an entity."""
for platform in self._platforms.values():
entity = platform.entities.get(entity_id)
if entity is not None:
return entity
return None
def setup(self, config): def setup(self, config):
"""Set up a full entity component. """Set up a full entity component.
@ -77,11 +98,10 @@ class EntityComponent(object):
# Generic discovery listener for loading platform dynamically # Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.components.discovery.load_platform() # Refer to: homeassistant.components.discovery.load_platform()
@callback @asyncio.coroutine
def component_platform_discovered(platform, info): def component_platform_discovered(platform, info):
"""Handle the loading of a platform.""" """Handle the loading of a platform."""
self.hass.async_add_job( yield from self._async_setup_platform(platform, {}, info)
self._async_setup_platform(platform, {}, info))
discovery.async_listen_platform( discovery.async_listen_platform(
self.hass, self.domain, component_platform_discovered) self.hass, self.domain, component_platform_discovered)
@ -107,13 +127,11 @@ class EntityComponent(object):
This method must be run in the event loop. This method must be run in the event loop.
""" """
if ATTR_ENTITY_ID not in service.data: if ATTR_ENTITY_ID not in service.data:
return [entity for entity in self.entities.values() return [entity for entity in self.entities if entity.available]
if entity.available]
return [self.entities[entity_id] for entity_id entity_ids = set(extract_entity_ids(self.hass, service, expand_group))
in extract_entity_ids(self.hass, service, expand_group) return [entity for entity in self.entities
if entity_id in self.entities and if entity.available and entity.entity_id in entity_ids]
self.entities[entity_id].available]
@asyncio.coroutine @asyncio.coroutine
def _async_setup_platform(self, platform_type, platform_config, def _async_setup_platform(self, platform_type, platform_config,
@ -193,80 +211,23 @@ class EntityComponent(object):
finally: finally:
warn_task.cancel() warn_task.cancel()
def add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component."""
return run_coroutine_threadsafe(
self.async_add_entity(entity, platform, update_before_add),
self.hass.loop
).result()
@asyncio.coroutine
def async_add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component.
This method must be run in the event loop.
"""
if entity is None or entity in self.entities.values():
return False
entity.hass = self.hass
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.logger.exception("Error on device update!")
return False
# Write entity_id to entity
if getattr(entity, 'entity_id', None) is None:
object_id = entity.name or DEVICE_DEFAULT_NAME
if platform is not None and platform.entity_namespace is not None:
object_id = '{} {}'.format(platform.entity_namespace,
object_id)
entity.entity_id = async_generate_entity_id(
self.entity_id_format, object_id,
self.entities.keys())
# Make sure it is valid in case an entity set the value themselves
if entity.entity_id in self.entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
elif not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
return True
def update_group(self):
"""Set up and/or update component group."""
run_callback_threadsafe(
self.hass.loop, self.async_update_group).result()
@callback @callback
def async_update_group(self): def async_update_group(self):
"""Set up and/or update component group. """Set up and/or update component group.
This method must be run in the event loop. This method must be run in the event loop.
""" """
if self.group_name is not None: if self.group_name is None:
ids = sorted(self.entities, return
key=lambda x: self.entities[x].name or x)
group = get_component('group') ids = [entity.entity_id for entity in
group.async_set_group( sorted(self.entities,
self.hass, slugify(self.group_name), name=self.group_name, key=lambda entity: entity.name or entity.entity_id)]
visible=False, entity_ids=ids
) self.hass.components.group.async_set_group(
slugify(self.group_name), name=self.group_name,
visible=False, entity_ids=ids
)
def reset(self): def reset(self):
"""Remove entities and reset the entity component to initial values.""" """Remove entities and reset the entity component to initial values."""
@ -287,12 +248,17 @@ class EntityComponent(object):
self._platforms = { self._platforms = {
'core': self._platforms['core'] 'core': self._platforms['core']
} }
self.entities = {}
self.config = None self.config = None
if self.group_name is not None: if self.group_name is not None:
group = get_component('group') self.hass.components.group.async_remove(slugify(self.group_name))
group.async_remove(self.hass, slugify(self.group_name))
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove an entity managed by one of the platforms."""
for platform in self._platforms.values():
if entity_id in platform.entities:
yield from platform.async_remove_entity(entity_id)
def prepare_reload(self): def prepare_reload(self):
"""Prepare reloading this entity component.""" """Prepare reloading this entity component."""
@ -323,7 +289,7 @@ class EntityComponent(object):
class EntityPlatform(object): class EntityPlatform(object):
"""Keep track of entities for a single platform and stay in loop.""" """Manage the entities for a single platform."""
def __init__(self, component, platform, scan_interval, parallel_updates, def __init__(self, component, platform, scan_interval, parallel_updates,
entity_namespace): entity_namespace):
@ -333,7 +299,7 @@ class EntityPlatform(object):
self.scan_interval = scan_interval self.scan_interval = scan_interval
self.parallel_updates = None self.parallel_updates = None
self.entity_namespace = entity_namespace self.entity_namespace = entity_namespace
self.platform_entities = [] self.entities = {}
self._tasks = [] self._tasks = []
self._async_unsub_polling = None self._async_unsub_polling = None
self._process_updates = asyncio.Lock(loop=component.hass.loop) self._process_updates = asyncio.Lock(loop=component.hass.loop)
@ -391,40 +357,88 @@ class EntityPlatform(object):
if not new_entities: if not new_entities:
return return
@asyncio.coroutine component_entities = set(entity.entity_id for entity
def async_process_entity(new_entity): in self.component.entities)
"""Add entities to StateMachine."""
new_entity.parallel_updates = self.parallel_updates
ret = yield from self.component.async_add_entity(
new_entity, self, update_before_add=update_before_add
)
if ret:
self.platform_entities.append(new_entity)
tasks = [async_process_entity(entity) for entity in new_entities] tasks = [
self._async_add_entity(entity, update_before_add,
component_entities)
for entity in new_entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop) yield from asyncio.wait(tasks, loop=self.component.hass.loop)
self.component.async_update_group() self.component.async_update_group()
if self._async_unsub_polling is not None or \ if self._async_unsub_polling is not None or \
not any(entity.should_poll for entity not any(entity.should_poll for entity
in self.platform_entities): in self.entities.values()):
return return
self._async_unsub_polling = async_track_time_interval( self._async_unsub_polling = async_track_time_interval(
self.component.hass, self._update_entity_states, self.scan_interval self.component.hass, self._update_entity_states, self.scan_interval
) )
@asyncio.coroutine
def _async_add_entity(self, entity, update_before_add, component_entities):
"""Helper method to add an entity to the platform."""
if entity is None:
raise ValueError('Entity cannot be None')
# Do nothing if entity has already been added based on unique id.
if entity in self.component.entities:
return
entity.hass = self.component.hass
entity.platform = self
entity.parallel_updates = self.parallel_updates
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.component.logger.exception(
"%s: Error on device update!", self.platform)
return
# Write entity_id to entity
if getattr(entity, 'entity_id', None) is None:
object_id = entity.name or DEVICE_DEFAULT_NAME
if self.entity_namespace is not None:
object_id = '{} {}'.format(self.entity_namespace,
object_id)
entity.entity_id = async_generate_entity_id(
self.component.entity_id_format, object_id,
component_entities)
# Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
elif entity.entity_id in component_entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
component_entities.add(entity.entity_id)
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
@asyncio.coroutine @asyncio.coroutine
def async_reset(self): def async_reset(self):
"""Remove all entities and reset data. """Remove all entities and reset data.
This method must be run in the event loop. This method must be run in the event loop.
""" """
if not self.platform_entities: if not self.entities:
return return
tasks = [entity.async_remove() for entity in self.platform_entities] tasks = [self._async_remove_entity(entity_id)
for entity_id in self.entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop) yield from asyncio.wait(tasks, loop=self.component.hass.loop)
@ -432,6 +446,28 @@ class EntityPlatform(object):
self._async_unsub_polling() self._async_unsub_polling()
self._async_unsub_polling = None self._async_unsub_polling = None
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
yield from self._async_remove_entity(entity_id)
# Clean up polling job if no longer needed
if (self._async_unsub_polling is not None and
not any(entity.should_poll for entity
in self.entities.values())):
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def _async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
entity = self.entities.pop(entity_id)
if hasattr(entity, 'async_will_remove_from_hass'):
yield from entity.async_will_remove_from_hass()
self.component.hass.states.async_remove(entity_id)
@asyncio.coroutine @asyncio.coroutine
def _update_entity_states(self, now): def _update_entity_states(self, now):
"""Update the states of all the polling entities. """Update the states of all the polling entities.
@ -450,7 +486,7 @@ class EntityPlatform(object):
with (yield from self._process_updates): with (yield from self._process_updates):
tasks = [] tasks = []
for entity in self.platform_entities: for entity in self.entities.values():
if not entity.should_poll: if not entity.should_poll:
continue continue
tasks.append(entity.async_update_ha_state(True)) tasks.append(entity.async_update_ha_state(True))

View File

@ -350,7 +350,7 @@ class TestComponentsGroup(unittest.TestCase):
assert sorted(self.hass.states.entity_ids()) == \ assert sorted(self.hass.states.entity_ids()) == \
['group.empty_group', 'group.second_group', 'group.test_group'] ['group.empty_group', 'group.second_group', 'group.test_group']
assert self.hass.bus.listeners['state_changed'] == 3 assert self.hass.bus.listeners['state_changed'] == 2
with patch('homeassistant.config.load_yaml_config_file', return_value={ with patch('homeassistant.config.load_yaml_config_file', return_value={
'group': { 'group': {
@ -365,14 +365,6 @@ class TestComponentsGroup(unittest.TestCase):
assert self.hass.states.entity_ids() == ['group.hello'] assert self.hass.states.entity_ids() == ['group.hello']
assert self.hass.bus.listeners['state_changed'] == 1 assert self.hass.bus.listeners['state_changed'] == 1
def test_stopping_a_group(self):
"""Test that a group correctly removes itself."""
grp = group.Group.create_group(
self.hass, 'light', ['light.test_1', 'light.test_2'])
assert self.hass.states.entity_ids() == ['group.light']
grp.stop()
assert self.hass.states.entity_ids() == []
def test_changing_group_visibility(self): def test_changing_group_visibility(self):
"""Test that a group can be hidden and shown.""" """Test that a group can be hidden and shown."""
assert setup_component(self.hass, 'group', { assert setup_component(self.hass, 'group', {

View File

@ -388,3 +388,15 @@ def test_async_pararell_updates_with_two(hass):
test_lock.release() test_lock.release()
yield from asyncio.sleep(0, loop=hass.loop) yield from asyncio.sleep(0, loop=hass.loop)
test_lock.release() test_lock.release()
@asyncio.coroutine
def test_async_remove_no_platform(hass):
"""Test async_remove method when no platform set."""
ent = entity.Entity()
ent.hass = hass
ent.entity_id = 'test.test'
yield from ent.async_update_ha_state()
assert len(hass.states.async_entity_ids()) == 1
yield from ent.async_remove()
assert len(hass.states.async_entity_ids()) == 0

View File

@ -86,6 +86,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert len(self.hass.states.entity_ids()) == 0 assert len(self.hass.states.entity_ids()) == 0
component.add_entities([EntityTest()]) component.add_entities([EntityTest()])
self.hass.block_till_done()
# group exists # group exists
assert len(self.hass.states.entity_ids()) == 2 assert len(self.hass.states.entity_ids()) == 2
@ -98,6 +99,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
# group extended # group extended
component.add_entities([EntityTest(name='goodbye')]) component.add_entities([EntityTest(name='goodbye')])
self.hass.block_till_done()
assert len(self.hass.states.entity_ids()) == 3 assert len(self.hass.states.entity_ids()) == 3
group = self.hass.states.get('group.everyone') group = self.hass.states.get('group.everyone')
@ -214,7 +216,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert 0 == len(self.hass.states.entity_ids()) assert 0 == len(self.hass.states.entity_ids())
component.add_entities([None, EntityTest(unique_id='not_very_unique')]) component.add_entities([EntityTest(unique_id='not_very_unique')])
assert 1 == len(self.hass.states.entity_ids()) assert 1 == len(self.hass.states.entity_ids())
@ -671,3 +673,14 @@ def test_raise_error_on_update(hass):
assert len(updates) == 1 assert len(updates) == 1
assert 1 in updates assert 1 in updates
@asyncio.coroutine
def test_async_remove_with_platform(hass):
"""Remove an entity from a platform."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = EntityTest(name='test_1')
yield from component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1
yield from entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0