Merge pull request #1059 from balloob/entity-component-enhancements

Add tests and custom interval for entity component
This commit is contained in:
Paulus Schoutsen 2016-01-31 09:03:55 -08:00
commit 16b1529d14
6 changed files with 395 additions and 98 deletions

View File

@ -26,7 +26,7 @@ CONF_PASSWORD = "password"
CONF_API_KEY = "api_key" CONF_API_KEY = "api_key"
CONF_ACCESS_TOKEN = "access_token" CONF_ACCESS_TOKEN = "access_token"
CONF_FILENAME = "filename" CONF_FILENAME = "filename"
CONF_SCAN_INTERVAL = "scan_interval"
CONF_VALUE_TEMPLATE = "value_template" CONF_VALUE_TEMPLATE = "value_template"
# #### EVENTS #### # #### EVENTS ####

View File

@ -50,8 +50,6 @@ class Entity(object):
""" ABC for Home Assistant entities. """ """ ABC for Home Assistant entities. """
# pylint: disable=no-self-use # pylint: disable=no-self-use
_hidden = False
# SAFE TO OVERWRITE # SAFE TO OVERWRITE
# The properties and methods here are safe to overwrite when inherting this # The properties and methods here are safe to overwrite when inherting this
# class. These may be used to customize the behavior of the entity. # class. These may be used to customize the behavior of the entity.
@ -103,13 +101,14 @@ class Entity(object):
""" Retrieve latest state. """ """ Retrieve latest state. """
pass pass
entity_id = None
# DO NOT OVERWRITE # DO NOT OVERWRITE
# These properties and methods are either managed by Home Assistant or they # These properties and methods are either managed by Home Assistant or they
# are used to perform a very specific function. Overwriting these may # are used to perform a very specific function. Overwriting these may
# produce undesirable effects in the entity's operation. # produce undesirable effects in the entity's operation.
hass = None hass = None
entity_id = None
def update_ha_state(self, force_refresh=False): def update_ha_state(self, force_refresh=False):
""" """

View File

@ -1,11 +1,7 @@
""" """Provides helpers for components that manage entities."""
homeassistant.helpers.entity_component
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Provides helpers for components that manage entities.
"""
from threading import Lock from threading import Lock
from homeassistant.const import CONF_SCAN_INTERVAL
from homeassistant.bootstrap import prepare_setup_platform from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.helpers import config_per_platform from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity import generate_entity_id from homeassistant.helpers.entity import generate_entity_id
@ -18,14 +14,14 @@ DEFAULT_SCAN_INTERVAL = 15
class EntityComponent(object): class EntityComponent(object):
"""Helper class that will help a component manage its entities."""
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
"""
Helper class that will help a component manage its entities.
"""
def __init__(self, logger, domain, hass, def __init__(self, logger, domain, hass,
scan_interval=DEFAULT_SCAN_INTERVAL, scan_interval=DEFAULT_SCAN_INTERVAL,
discovery_platforms=None, group_name=None): discovery_platforms=None, group_name=None):
"""Initialize an entity component."""
self.logger = logger self.logger = logger
self.hass = hass self.hass = hass
@ -42,11 +38,15 @@ class EntityComponent(object):
self.config = None self.config = None
self.lock = Lock() self.lock = Lock()
self.add_entities = EntityPlatform(self,
self.scan_interval).add_entities
def setup(self, config): def setup(self, config):
""" """
Sets up a full entity component: Set up a full entity component.
- Loads the platforms from the config
- Will listen for supported discovered platforms Loads the platforms from the config and will listen for supported
discovered platforms.
""" """
self.config = config self.config = config
@ -57,52 +57,18 @@ class EntityComponent(object):
self._setup_platform(p_type, p_config) self._setup_platform(p_type, p_config)
if self.discovery_platforms: if self.discovery_platforms:
discovery.listen(self.hass, self.discovery_platforms.keys(), discovery.listen(
self._entity_discovered) self.hass, self.discovery_platforms.keys(),
lambda service, info:
def add_entities(self, new_entities): self._setup_platform(self.discovery_platforms[service], {},
""" info))
Takes in a list of new entities. For each entity will see if it already
exists. If not, will add it, set it up and push the first state.
"""
with self.lock:
for entity in new_entities:
if entity is None or entity in self.entities.values():
continue
entity.hass = self.hass
if getattr(entity, 'entity_id', None) is None:
entity.entity_id = generate_entity_id(
self.entity_id_format, entity.name,
self.entities.keys())
self.entities[entity.entity_id] = entity
entity.update_ha_state()
if self.group is None and self.group_name is not None:
self.group = group.Group(self.hass, self.group_name,
user_defined=False)
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())
if self.is_polling or \
not any(entity.should_poll for entity
in self.entities.values()):
return
self.is_polling = True
track_utc_time_change(
self.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def extract_from_service(self, service): def extract_from_service(self, service):
""" """
Takes a service and extracts all known entities. Extract all known entities from a service call.
Will return all if no entity IDs given in service.
Will return all entities if no entities specified in call.
Will return an empty list if entities specified but unknown.
""" """
with self.lock: with self.lock:
if ATTR_ENTITY_ID not in service.data: if ATTR_ENTITY_ID not in service.data:
@ -112,29 +78,9 @@ class EntityComponent(object):
in extract_entity_ids(self.hass, service) in extract_entity_ids(self.hass, service)
if entity_id in self.entities] if entity_id in self.entities]
def _update_entity_states(self, now):
""" Update the states of all the entities. """
with self.lock:
# We copy the entities because new entities might be detected
# during state update causing deadlocks.
entities = list(entity for entity in self.entities.values()
if entity.should_poll)
self.logger.info("Updating %s entities", self.domain)
for entity in entities:
entity.update_ha_state(True)
def _entity_discovered(self, service, info):
""" Called when a entity is discovered. """
if service not in self.discovery_platforms:
return
self._setup_platform(self.discovery_platforms[service], {}, info)
def _setup_platform(self, platform_type, platform_config, def _setup_platform(self, platform_type, platform_config,
discovery_info=None): discovery_info=None):
""" Tries to setup a platform for this component. """ """Setup a platform for this component."""
platform = prepare_setup_platform( platform = prepare_setup_platform(
self.hass, self.config, self.domain, platform_type) self.hass, self.config, self.domain, platform_type)
@ -142,12 +88,85 @@ class EntityComponent(object):
return return
try: try:
# Config > Platform > Component
scan_interval = platform_config.get(
CONF_SCAN_INTERVAL,
getattr(platform, 'SCAN_INTERVAL', self.scan_interval))
platform.setup_platform( platform.setup_platform(
self.hass, platform_config, self.add_entities, discovery_info) self.hass, platform_config,
EntityPlatform(self, scan_interval).add_entities,
discovery_info)
platform_name = '{}.{}'.format(self.domain, platform_type)
self.hass.config.components.append(platform_name)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
self.logger.exception( self.logger.exception(
'Error while setting up platform %s', platform_type) 'Error while setting up platform %s', platform_type)
return return
platform_name = '{}.{}'.format(self.domain, platform_type) def add_entity(self, entity):
self.hass.config.components.append(platform_name) """Add entity to component."""
if entity is None or entity in self.entities.values():
return False
entity.hass = self.hass
if getattr(entity, 'entity_id', None) is None:
entity.entity_id = generate_entity_id(
self.entity_id_format, entity.name,
self.entities.keys())
self.entities[entity.entity_id] = entity
entity.update_ha_state()
return True
def update_group(self):
"""Set up and/or update component group."""
if self.group is None and self.group_name is not None:
self.group = group.Group(self.hass, self.group_name,
user_defined=False)
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())
class EntityPlatform(object):
"""Keep track of entities for a single platform."""
# pylint: disable=too-few-public-methods
def __init__(self, component, scan_interval):
self.component = component
self.scan_interval = scan_interval
self.platform_entities = []
self.is_polling = False
def add_entities(self, new_entities):
"""Add entities for a single platform."""
with self.component.lock:
for entity in new_entities:
if self.component.add_entity(entity):
self.platform_entities.append(entity)
self.component.update_group()
if self.is_polling or \
not any(entity.should_poll for entity
in self.platform_entities):
return
self.is_polling = True
track_utc_time_change(
self.component.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def _update_entity_states(self, now):
"""Update the states of all the polling entities."""
with self.component.lock:
# We copy the entities because new entities might be detected
# during state update causing deadlocks.
entities = list(entity for entity in self.platform_entities
if entity.should_poll)
for entity in entities:
entity.update_ha_state(True)

View File

@ -1,13 +0,0 @@
"""
homeassistant.util.environement
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Environement helpers.
"""
import sys
def is_virtual():
""" Return if we run in a virtual environtment. """
# Check supports venv && virtualenv
return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or
hasattr(sys, 'real_prefix'))

View File

@ -145,11 +145,26 @@ class MockHTTP(object):
class MockModule(object): class MockModule(object):
""" Provides a fake module. """ """ Provides a fake module. """
def __init__(self, domain, dependencies=[], setup=None): def __init__(self, domain=None, dependencies=[], setup=None):
self.DOMAIN = domain self.DOMAIN = domain
self.DEPENDENCIES = dependencies self.DEPENDENCIES = dependencies
# Setup a mock setup if none given. # Setup a mock setup if none given.
self.setup = lambda hass, config: False if setup is None else setup if setup is None:
self.setup = lambda hass, config: False
else:
self.setup = setup
class MockPlatform(object):
""" Provides a fake platform. """
def __init__(self, setup_platform=None, dependencies=[]):
self.DEPENDENCIES = dependencies
self._setup_platform = setup_platform
def setup_platform(self, hass, config, add_devices, discovery_info=None):
if self._setup_platform is not None:
self._setup_platform(hass, config, add_devices, discovery_info)
class MockToggleDevice(ToggleEntity): class MockToggleDevice(ToggleEntity):

View File

@ -0,0 +1,277 @@
"""
tests.test_helper_entity_component
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tests the entity component helper.
"""
# pylint: disable=protected-access,too-many-public-methods
from collections import OrderedDict
import logging
import unittest
from unittest.mock import patch, Mock
import homeassistant.core as ha
import homeassistant.loader as loader
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.components import discovery
import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, MockModule, fire_time_changed)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
class EntityTest(Entity):
def __init__(self, **values):
self._values = values
if 'entity_id' in values:
self.entity_id = values['entity_id']
@property
def name(self):
return self._handle('name')
@property
def should_poll(self):
return self._handle('should_poll')
@property
def unique_id(self):
return self._handle('unique_id')
def _handle(self, attr):
if attr in self._values:
return self._values[attr]
return getattr(super(), attr)
class TestHelpersEntityComponent(unittest.TestCase):
""" Tests homeassistant.helpers.entity_component module. """
def setUp(self): # pylint: disable=invalid-name
"""Initialize a test Home Assistant instance."""
self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name
"""Clean up the test Home Assistant instance."""
self.hass.stop()
def test_setting_up_group(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass,
group_name='everyone')
# No group after setup
assert 0 == len(self.hass.states.entity_ids())
component.add_entities([EntityTest(name='hello')])
# group exists
assert 2 == len(self.hass.states.entity_ids())
assert ['group.everyone'] == self.hass.states.entity_ids('group')
group = self.hass.states.get('group.everyone')
assert ('test_domain.hello',) == group.attributes.get('entity_id')
# group extended
component.add_entities([EntityTest(name='hello2')])
assert 3 == len(self.hass.states.entity_ids())
group = self.hass.states.get('group.everyone')
assert ['test_domain.hello', 'test_domain.hello2'] == \
sorted(group.attributes.get('entity_id'))
def test_polling_only_updates_entities_it_should_poll(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass, 20)
no_poll_ent = EntityTest(should_poll=False)
no_poll_ent.update_ha_state = Mock()
poll_ent = EntityTest(should_poll=True)
poll_ent.update_ha_state = Mock()
component.add_entities([no_poll_ent, poll_ent])
no_poll_ent.update_ha_state.reset_mock()
poll_ent.update_ha_state.reset_mock()
fire_time_changed(self.hass, dt_util.utcnow().replace(second=0))
self.hass.pool.block_till_done()
assert not no_poll_ent.update_ha_state.called
assert poll_ent.update_ha_state.called
def test_update_state_adds_entities(self):
"""Test if updating poll entities cause an entity to be added works."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent1 = EntityTest()
ent2 = EntityTest(should_poll=True)
component.add_entities([ent2])
assert 1 == len(self.hass.states.entity_ids())
ent2.update_ha_state = lambda *_: component.add_entities([ent1])
fire_time_changed(self.hass, dt_util.utcnow().replace(second=0))
self.hass.pool.block_till_done()
assert 2 == len(self.hass.states.entity_ids())
def test_not_adding_duplicate_entities(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert 0 == len(self.hass.states.entity_ids())
component.add_entities([None, EntityTest(unique_id='not_very_unique')])
assert 1 == len(self.hass.states.entity_ids())
component.add_entities([EntityTest(unique_id='not_very_unique')])
assert 1 == len(self.hass.states.entity_ids())
def test_not_assigning_entity_id_if_prescribes_one(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert 'hello.world' not in self.hass.states.entity_ids()
component.add_entities([EntityTest(entity_id='hello.world')])
assert 'hello.world' in self.hass.states.entity_ids()
def test_extract_from_service_returns_all_if_no_entity_id(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service')
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.extract_from_service(call))
def test_extract_from_service_filter_out_non_existing_entities(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['test_domain.test_2', 'test_domain.non_exist']
})
assert ['test_domain.test_2'] == \
[ent.entity_id for ent in component.extract_from_service(call)]
def test_setup_loads_platforms(self):
component_setup = Mock(return_value=True)
platform_setup = Mock(return_value=None)
loader.set_component(
'test_component',
MockModule('test_component', setup=component_setup))
loader.set_component('test_domain.mod2',
MockPlatform(platform_setup, ['test_component']))
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert not component_setup.called
assert not platform_setup.called
component.setup({
DOMAIN: {
'platform': 'mod2',
}
})
assert component_setup.called
assert platform_setup.called
def test_setup_recovers_when_setup_raises(self):
platform1_setup = Mock(side_effect=Exception('Broken'))
platform2_setup = Mock(return_value=None)
loader.set_component('test_domain.mod1', MockPlatform(platform1_setup))
loader.set_component('test_domain.mod2', MockPlatform(platform2_setup))
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert not platform1_setup.called
assert not platform2_setup.called
component.setup(OrderedDict([
(DOMAIN, {'platform': 'mod1'}),
("{} 2".format(DOMAIN), {'platform': 'non_exist'}),
("{} 3".format(DOMAIN), {'platform': 'mod2'}),
]))
assert platform1_setup.called
assert platform2_setup.called
@patch('homeassistant.helpers.entity_component.EntityComponent'
'._setup_platform')
def test_setup_does_discovery(self, mock_setup):
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, discovery_platforms={
'discovery.test': 'platform_test',
})
component.setup({})
self.hass.bus.fire(discovery.EVENT_PLATFORM_DISCOVERED, {
discovery.ATTR_SERVICE: 'discovery.test',
discovery.ATTR_DISCOVERED: 'discovery_info',
})
self.hass.pool.block_till_done()
assert mock_setup.called
assert ('platform_test', {}, 'discovery_info') == \
mock_setup.call_args[0]
@patch('homeassistant.helpers.entity_component.track_utc_time_change')
def test_set_scan_interval_via_config(self, mock_track):
def platform_setup(hass, config, add_devices, discovery_info=None):
add_devices([EntityTest(should_poll=True)])
loader.set_component('test_domain.platform',
MockPlatform(platform_setup))
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.setup({
DOMAIN: {
'platform': 'platform',
'scan_interval': 30,
}
})
assert mock_track.called
assert [0, 30] == list(mock_track.call_args[1]['second'])
@patch('homeassistant.helpers.entity_component.track_utc_time_change')
def test_set_scan_interval_via_platform(self, mock_track):
def platform_setup(hass, config, add_devices, discovery_info=None):
add_devices([EntityTest(should_poll=True)])
platform = MockPlatform(platform_setup)
platform.SCAN_INTERVAL = 30
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.setup({
DOMAIN: {
'platform': 'platform',
}
})
assert mock_track.called
assert [0, 30] == list(mock_track.call_args[1]['second'])