From 90e17fc77f6fcabe84de8402bb8c62946d43126d Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 30 Jan 2016 18:55:52 -0800 Subject: [PATCH 1/3] Add tests for entity component --- homeassistant/helpers/entity.py | 5 +- homeassistant/helpers/entity_component.py | 50 +++-- tests/common.py | 19 +- tests/helpers/test_entity_component.py | 236 ++++++++++++++++++++++ 4 files changed, 278 insertions(+), 32 deletions(-) create mode 100644 tests/helpers/test_entity_component.py diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index ab5707a0121..8b4e1be2d52 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -50,8 +50,6 @@ class Entity(object): """ ABC for Home Assistant entities. """ # pylint: disable=no-self-use - _hidden = False - # SAFE TO OVERWRITE # The properties and methods here are safe to overwrite when inherting this # class. These may be used to customize the behavior of the entity. @@ -103,13 +101,14 @@ class Entity(object): """ Retrieve latest state. """ pass + entity_id = None + # DO NOT OVERWRITE # These properties and methods are either managed by Home Assistant or they # are used to perform a very specific function. Overwriting these may # produce undesirable effects in the entity's operation. hass = None - entity_id = None def update_ha_state(self, force_refresh=False): """ diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 0450a788809..3382d90b62b 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -1,9 +1,4 @@ -""" -homeassistant.helpers.entity_component -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Provides helpers for components that manage entities. -""" +"""Provides helpers for components that manage entities.""" from threading import Lock from homeassistant.bootstrap import prepare_setup_platform @@ -18,14 +13,14 @@ DEFAULT_SCAN_INTERVAL = 15 class EntityComponent(object): + """Helper class that will help a component manage its entities.""" + # pylint: disable=too-many-instance-attributes # pylint: disable=too-many-arguments - """ - Helper class that will help a component manage its entities. - """ def __init__(self, logger, domain, hass, scan_interval=DEFAULT_SCAN_INTERVAL, discovery_platforms=None, group_name=None): + """Initialize an entity component.""" self.logger = logger self.hass = hass @@ -44,9 +39,10 @@ class EntityComponent(object): def setup(self, config): """ - Sets up a full entity component: - - Loads the platforms from the config - - Will listen for supported discovered platforms + Set up a full entity component. + + Loads the platforms from the config and will listen for supported + discovered platforms. """ self.config = config @@ -57,13 +53,18 @@ class EntityComponent(object): self._setup_platform(p_type, p_config) if self.discovery_platforms: - discovery.listen(self.hass, self.discovery_platforms.keys(), - self._entity_discovered) + discovery.listen( + self.hass, self.discovery_platforms.keys(), + lambda service, info: + self._setup_platform(self.discovery_platforms[service], {}, + info)) def add_entities(self, new_entities): """ - Takes in a list of new entities. For each entity will see if it already - exists. If not, will add it, set it up and push the first state. + Add new entities to this component. + + For each entity will see if it already exists. If not, will add it, + set it up and push the first state. """ with self.lock: for entity in new_entities: @@ -101,8 +102,10 @@ class EntityComponent(object): def extract_from_service(self, service): """ - Takes a service and extracts all known entities. - Will return all if no entity IDs given in service. + Extract all known entities from a service call. + + Will return all entities if no entities specified in call. + Will return an empty list if entities specified but unknown. """ with self.lock: if ATTR_ENTITY_ID not in service.data: @@ -113,7 +116,7 @@ class EntityComponent(object): if entity_id in self.entities] def _update_entity_states(self, now): - """ Update the states of all the entities. """ + """Update the states of all the polling entities.""" with self.lock: # We copy the entities because new entities might be detected # during state update causing deadlocks. @@ -125,16 +128,9 @@ class EntityComponent(object): for entity in entities: entity.update_ha_state(True) - def _entity_discovered(self, service, info): - """ Called when a entity is discovered. """ - if service not in self.discovery_platforms: - return - - self._setup_platform(self.discovery_platforms[service], {}, info) - def _setup_platform(self, platform_type, platform_config, discovery_info=None): - """ Tries to setup a platform for this component. """ + """Setup a platform for this component.""" platform = prepare_setup_platform( self.hass, self.config, self.domain, platform_type) diff --git a/tests/common.py b/tests/common.py index b8108c673fd..350786c8e14 100644 --- a/tests/common.py +++ b/tests/common.py @@ -145,11 +145,26 @@ class MockHTTP(object): class MockModule(object): """ Provides a fake module. """ - def __init__(self, domain, dependencies=[], setup=None): + def __init__(self, domain=None, dependencies=[], setup=None): self.DOMAIN = domain self.DEPENDENCIES = dependencies # Setup a mock setup if none given. - self.setup = lambda hass, config: False if setup is None else setup + if setup is None: + self.setup = lambda hass, config: False + else: + self.setup = setup + + +class MockPlatform(object): + """ Provides a fake platform. """ + + def __init__(self, setup_platform=None, dependencies=[]): + self.DEPENDENCIES = dependencies + self._setup_platform = setup_platform + + def setup_platform(self, hass, config, add_devices, discovery_info=None): + if self._setup_platform is not None: + self._setup_platform(hass, config, add_devices, discovery_info) class MockToggleDevice(ToggleEntity): diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py new file mode 100644 index 00000000000..5f21c2193ce --- /dev/null +++ b/tests/helpers/test_entity_component.py @@ -0,0 +1,236 @@ +""" +tests.test_helper_entity_component +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Tests the entity component helper. +""" +# pylint: disable=protected-access,too-many-public-methods +from collections import OrderedDict +import logging +import unittest +from unittest.mock import patch, Mock + +import homeassistant.core as ha +import homeassistant.loader as loader +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.components import discovery + +from tests.common import get_test_home_assistant, MockPlatform, MockModule + +_LOGGER = logging.getLogger(__name__) +DOMAIN = "test_domain" + + +class EntityTest(Entity): + def __init__(self, **values): + self._values = values + + if 'entity_id' in values: + self.entity_id = values['entity_id'] + + @property + def name(self): + return self._handle('name') + + @property + def should_poll(self): + return self._handle('should_poll') + + @property + def unique_id(self): + return self._handle('unique_id') + + def _handle(self, attr): + if attr in self._values: + return self._values[attr] + return getattr(super(), attr) + + +class TestHelpersEntityComponent(unittest.TestCase): + """ Tests homeassistant.helpers.entity_component module. """ + + def setUp(self): # pylint: disable=invalid-name + """Initialize a test Home Assistant instance.""" + self.hass = get_test_home_assistant() + + def tearDown(self): # pylint: disable=invalid-name + """Clean up the test Home Assistant instance.""" + self.hass.stop() + + def test_setting_up_group(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass, + group_name='everyone') + + # No group after setup + assert 0 == len(self.hass.states.entity_ids()) + + component.add_entities([EntityTest(name='hello')]) + + # group exists + assert 2 == len(self.hass.states.entity_ids()) + assert ['group.everyone'] == self.hass.states.entity_ids('group') + + group = self.hass.states.get('group.everyone') + + assert ('test_domain.hello',) == group.attributes.get('entity_id') + + # group extended + component.add_entities([EntityTest(name='hello2')]) + + assert 3 == len(self.hass.states.entity_ids()) + group = self.hass.states.get('group.everyone') + + assert ['test_domain.hello', 'test_domain.hello2'] == \ + sorted(group.attributes.get('entity_id')) + + @patch('homeassistant.helpers.entity_component.track_utc_time_change') + def test_polling_only_updates_entities_it_should_poll(self, mock_track): + component = EntityComponent(_LOGGER, DOMAIN, self.hass, 20) + + no_poll_ent = EntityTest(should_poll=False) + no_poll_ent.update_ha_state = Mock() + poll_ent = EntityTest(should_poll=True) + poll_ent.update_ha_state = Mock() + + component.add_entities([no_poll_ent]) + assert not mock_track.called + + component.add_entities([poll_ent]) + assert mock_track.called + assert [0, 20, 40] == list(mock_track.call_args[1].get('second')) + + no_poll_ent.update_ha_state.reset_mock() + poll_ent.update_ha_state.reset_mock() + + component._update_entity_states(None) + + assert not no_poll_ent.update_ha_state.called + assert poll_ent.update_ha_state.called + + def test_update_state_adds_entities(self): + """Test if updating poll entities cause an entity to be added works.""" + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + ent1 = EntityTest() + ent2 = EntityTest(should_poll=True) + + component.add_entities([ent2]) + assert 1 == len(self.hass.states.entity_ids()) + ent2.update_ha_state = lambda *_: component.add_entities([ent1]) + component._update_entity_states(None) + assert 2 == len(self.hass.states.entity_ids()) + + def test_not_adding_duplicate_entities(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert 0 == len(self.hass.states.entity_ids()) + + component.add_entities([None, EntityTest(unique_id='not_very_unique')]) + + assert 1 == len(self.hass.states.entity_ids()) + + component.add_entities([EntityTest(unique_id='not_very_unique')]) + + assert 1 == len(self.hass.states.entity_ids()) + + def test_not_assigning_entity_id_if_prescribes_one(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert 'hello.world' not in self.hass.states.entity_ids() + + component.add_entities([EntityTest(entity_id='hello.world')]) + + assert 'hello.world' in self.hass.states.entity_ids() + + def test_extract_from_service_returns_all_if_no_entity_id(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + component.add_entities([ + EntityTest(name='test_1'), + EntityTest(name='test_2'), + ]) + + call = ha.ServiceCall('test', 'service') + + assert ['test_domain.test_1', 'test_domain.test_2'] == \ + sorted(ent.entity_id for ent in + component.extract_from_service(call)) + + def test_extract_from_service_filter_out_non_existing_entities(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + component.add_entities([ + EntityTest(name='test_1'), + EntityTest(name='test_2'), + ]) + + call = ha.ServiceCall('test', 'service', { + 'entity_id': ['test_domain.test_2', 'test_domain.non_exist'] + }) + + assert ['test_domain.test_2'] == \ + [ent.entity_id for ent in component.extract_from_service(call)] + + def test_setup_loads_platforms(self): + component_setup = Mock(return_value=True) + platform_setup = Mock(return_value=None) + loader.set_component( + 'test_component', + MockModule('test_component', setup=component_setup)) + loader.set_component('test_domain.mod2', + MockPlatform(platform_setup, ['test_component'])) + + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert not component_setup.called + assert not platform_setup.called + + component.setup({ + DOMAIN: { + 'platform': 'mod2', + } + }) + + assert component_setup.called + assert platform_setup.called + + def test_setup_recovers_when_setup_raises(self): + platform1_setup = Mock(side_effect=Exception('Broken')) + platform2_setup = Mock(return_value=None) + + loader.set_component('test_domain.mod1', MockPlatform(platform1_setup)) + loader.set_component('test_domain.mod2', MockPlatform(platform2_setup)) + + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert not platform1_setup.called + assert not platform2_setup.called + + component.setup(OrderedDict([ + (DOMAIN, {'platform': 'mod1'}), + ("{} 2".format(DOMAIN), {'platform': 'non_exist'}), + ("{} 3".format(DOMAIN), {'platform': 'mod2'}), + ])) + + assert platform1_setup.called + assert platform2_setup.called + + @patch('homeassistant.helpers.entity_component.EntityComponent' + '._setup_platform') + def test_setup_does_discovery(self, mock_setup): + component = EntityComponent( + _LOGGER, DOMAIN, self.hass, discovery_platforms={ + 'discovery.test': 'platform_test', + }) + + component.setup({}) + + self.hass.bus.fire(discovery.EVENT_PLATFORM_DISCOVERED, { + discovery.ATTR_SERVICE: 'discovery.test', + discovery.ATTR_DISCOVERED: 'discovery_info', + }) + + self.hass.pool.block_till_done() + + assert mock_setup.called + assert ('platform_test', {}, 'discovery_info') == \ + mock_setup.call_args[0] From fce8815ab47c562f07a582100d76f43ca60fc55e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 31 Jan 2016 00:55:46 -0800 Subject: [PATCH 2/3] Support custom interval for platforms --- homeassistant/const.py | 2 +- homeassistant/helpers/entity_component.py | 137 +++++++++++++--------- tests/helpers/test_entity_component.py | 63 ++++++++-- 3 files changed, 133 insertions(+), 69 deletions(-) diff --git a/homeassistant/const.py b/homeassistant/const.py index e28c418d9e4..87be5fa5f6f 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -26,7 +26,7 @@ CONF_PASSWORD = "password" CONF_API_KEY = "api_key" CONF_ACCESS_TOKEN = "access_token" CONF_FILENAME = "filename" - +CONF_SCAN_INTERVAL = "scan_interval" CONF_VALUE_TEMPLATE = "value_template" # #### EVENTS #### diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 3382d90b62b..268e4e7b696 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -1,6 +1,7 @@ """Provides helpers for components that manage entities.""" from threading import Lock +from homeassistant.const import CONF_SCAN_INTERVAL from homeassistant.bootstrap import prepare_setup_platform from homeassistant.helpers import config_per_platform from homeassistant.helpers.entity import generate_entity_id @@ -37,6 +38,9 @@ class EntityComponent(object): self.config = None self.lock = Lock() + self.add_entities = EntityPlatform(self, + self.scan_interval).add_entities + def setup(self, config): """ Set up a full entity component. @@ -59,47 +63,6 @@ class EntityComponent(object): self._setup_platform(self.discovery_platforms[service], {}, info)) - def add_entities(self, new_entities): - """ - Add new entities to this component. - - For each entity will see if it already exists. If not, will add it, - set it up and push the first state. - """ - with self.lock: - for entity in new_entities: - if entity is None or entity in self.entities.values(): - continue - - entity.hass = self.hass - - if getattr(entity, 'entity_id', None) is None: - entity.entity_id = generate_entity_id( - self.entity_id_format, entity.name, - self.entities.keys()) - - self.entities[entity.entity_id] = entity - - entity.update_ha_state() - - if self.group is None and self.group_name is not None: - self.group = group.Group(self.hass, self.group_name, - user_defined=False) - - if self.group is not None: - self.group.update_tracked_entity_ids(self.entities.keys()) - - if self.is_polling or \ - not any(entity.should_poll for entity - in self.entities.values()): - return - - self.is_polling = True - - track_utc_time_change( - self.hass, self._update_entity_states, - second=range(0, 60, self.scan_interval)) - def extract_from_service(self, service): """ Extract all known entities from a service call. @@ -115,19 +78,6 @@ class EntityComponent(object): in extract_entity_ids(self.hass, service) if entity_id in self.entities] - def _update_entity_states(self, now): - """Update the states of all the polling entities.""" - with self.lock: - # We copy the entities because new entities might be detected - # during state update causing deadlocks. - entities = list(entity for entity in self.entities.values() - if entity.should_poll) - - self.logger.info("Updating %s entities", self.domain) - - for entity in entities: - entity.update_ha_state(True) - def _setup_platform(self, platform_type, platform_config, discovery_info=None): """Setup a platform for this component.""" @@ -138,12 +88,85 @@ class EntityComponent(object): return try: + # Config > Platform > Component + scan_interval = platform_config.get( + CONF_SCAN_INTERVAL, + getattr(platform, 'SCAN_INTERVAL', self.scan_interval)) platform.setup_platform( - self.hass, platform_config, self.add_entities, discovery_info) + self.hass, platform_config, + EntityPlatform(self, scan_interval).add_entities, + discovery_info) + platform_name = '{}.{}'.format(self.domain, platform_type) + self.hass.config.components.append(platform_name) except Exception: # pylint: disable=broad-except self.logger.exception( 'Error while setting up platform %s', platform_type) return - platform_name = '{}.{}'.format(self.domain, platform_type) - self.hass.config.components.append(platform_name) + def add_entity(self, entity): + """Add entity to component.""" + if entity is None or entity in self.entities.values(): + return False + + entity.hass = self.hass + + if getattr(entity, 'entity_id', None) is None: + entity.entity_id = generate_entity_id( + self.entity_id_format, entity.name, + self.entities.keys()) + + self.entities[entity.entity_id] = entity + entity.update_ha_state() + + return True + + def update_group(self): + """Set up and/or update component group.""" + if self.group is None and self.group_name is not None: + self.group = group.Group(self.hass, self.group_name, + user_defined=False) + + if self.group is not None: + self.group.update_tracked_entity_ids(self.entities.keys()) + + +class EntityPlatform(object): + """Keep track of entities for a single platform.""" + + # pylint: disable=too-few-public-methods + def __init__(self, component, scan_interval): + self.component = component + self.scan_interval = scan_interval + self.platform_entities = [] + self.is_polling = False + + def add_entities(self, new_entities): + """Add entities for a single platform.""" + with self.component.lock: + for entity in new_entities: + if self.component.add_entity(entity): + self.platform_entities.append(entity) + + self.component.update_group() + + if self.is_polling or \ + not any(entity.should_poll for entity + in self.platform_entities): + return + + self.is_polling = True + + track_utc_time_change( + self.component.hass, self._update_entity_states, + second=range(0, 60, self.scan_interval)) + + def _update_entity_states(self, now): + """Update the states of all the polling entities.""" + with self.component.lock: + # We copy the entities because new entities might be detected + # during state update causing deadlocks. + entities = list(entity for entity in self.platform_entities + if entity.should_poll) + + for entity in entities: + entity.update_ha_state(True) diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 5f21c2193ce..68aecd32f5b 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -15,8 +15,10 @@ import homeassistant.loader as loader from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.components import discovery +import homeassistant.util.dt as dt_util -from tests.common import get_test_home_assistant, MockPlatform, MockModule +from tests.common import ( + get_test_home_assistant, MockPlatform, MockModule, fire_time_changed) _LOGGER = logging.getLogger(__name__) DOMAIN = "test_domain" @@ -84,8 +86,7 @@ class TestHelpersEntityComponent(unittest.TestCase): assert ['test_domain.hello', 'test_domain.hello2'] == \ sorted(group.attributes.get('entity_id')) - @patch('homeassistant.helpers.entity_component.track_utc_time_change') - def test_polling_only_updates_entities_it_should_poll(self, mock_track): + def test_polling_only_updates_entities_it_should_poll(self): component = EntityComponent(_LOGGER, DOMAIN, self.hass, 20) no_poll_ent = EntityTest(should_poll=False) @@ -93,17 +94,13 @@ class TestHelpersEntityComponent(unittest.TestCase): poll_ent = EntityTest(should_poll=True) poll_ent.update_ha_state = Mock() - component.add_entities([no_poll_ent]) - assert not mock_track.called - - component.add_entities([poll_ent]) - assert mock_track.called - assert [0, 20, 40] == list(mock_track.call_args[1].get('second')) + component.add_entities([no_poll_ent, poll_ent]) no_poll_ent.update_ha_state.reset_mock() poll_ent.update_ha_state.reset_mock() - component._update_entity_states(None) + fire_time_changed(self.hass, dt_util.utcnow().replace(second=0)) + self.hass.pool.block_till_done() assert not no_poll_ent.update_ha_state.called assert poll_ent.update_ha_state.called @@ -118,7 +115,10 @@ class TestHelpersEntityComponent(unittest.TestCase): component.add_entities([ent2]) assert 1 == len(self.hass.states.entity_ids()) ent2.update_ha_state = lambda *_: component.add_entities([ent1]) - component._update_entity_states(None) + + fire_time_changed(self.hass, dt_util.utcnow().replace(second=0)) + self.hass.pool.block_till_done() + assert 2 == len(self.hass.states.entity_ids()) def test_not_adding_duplicate_entities(self): @@ -234,3 +234,44 @@ class TestHelpersEntityComponent(unittest.TestCase): assert mock_setup.called assert ('platform_test', {}, 'discovery_info') == \ mock_setup.call_args[0] + + @patch('homeassistant.helpers.entity_component.track_utc_time_change') + def test_set_scan_interval_via_config(self, mock_track): + def platform_setup(hass, config, add_devices, discovery_info=None): + add_devices([EntityTest(should_poll=True)]) + + loader.set_component('test_domain.platform', + MockPlatform(platform_setup)) + + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + component.setup({ + DOMAIN: { + 'platform': 'platform', + 'scan_interval': 30, + } + }) + + assert mock_track.called + assert [0, 30] == list(mock_track.call_args[1]['second']) + + @patch('homeassistant.helpers.entity_component.track_utc_time_change') + def test_set_scan_interval_via_platform(self, mock_track): + def platform_setup(hass, config, add_devices, discovery_info=None): + add_devices([EntityTest(should_poll=True)]) + + platform = MockPlatform(platform_setup) + platform.SCAN_INTERVAL = 30 + + loader.set_component('test_domain.platform', platform) + + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + component.setup({ + DOMAIN: { + 'platform': 'platform', + } + }) + + assert mock_track.called + assert [0, 30] == list(mock_track.call_args[1]['second']) From 0b8e0977051794e35e2cb7115daa27364c6b3ce9 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 31 Jan 2016 08:58:30 -0800 Subject: [PATCH 3/3] Remove unused environment util --- homeassistant/util/environment.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 homeassistant/util/environment.py diff --git a/homeassistant/util/environment.py b/homeassistant/util/environment.py deleted file mode 100644 index ea4c69e8f13..00000000000 --- a/homeassistant/util/environment.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -homeassistant.util.environement -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Environement helpers. -""" -import sys - - -def is_virtual(): - """ Return if we run in a virtual environtment. """ - # Check supports venv && virtualenv - return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or - hasattr(sys, 'real_prefix'))