diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index ab5707a0121..8b4e1be2d52 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -50,8 +50,6 @@ class Entity(object): """ ABC for Home Assistant entities. """ # pylint: disable=no-self-use - _hidden = False - # SAFE TO OVERWRITE # The properties and methods here are safe to overwrite when inherting this # class. These may be used to customize the behavior of the entity. @@ -103,13 +101,14 @@ class Entity(object): """ Retrieve latest state. """ pass + entity_id = None + # DO NOT OVERWRITE # These properties and methods are either managed by Home Assistant or they # are used to perform a very specific function. Overwriting these may # produce undesirable effects in the entity's operation. hass = None - entity_id = None def update_ha_state(self, force_refresh=False): """ diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 0450a788809..3382d90b62b 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -1,9 +1,4 @@ -""" -homeassistant.helpers.entity_component -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Provides helpers for components that manage entities. -""" +"""Provides helpers for components that manage entities.""" from threading import Lock from homeassistant.bootstrap import prepare_setup_platform @@ -18,14 +13,14 @@ DEFAULT_SCAN_INTERVAL = 15 class EntityComponent(object): + """Helper class that will help a component manage its entities.""" + # pylint: disable=too-many-instance-attributes # pylint: disable=too-many-arguments - """ - Helper class that will help a component manage its entities. - """ def __init__(self, logger, domain, hass, scan_interval=DEFAULT_SCAN_INTERVAL, discovery_platforms=None, group_name=None): + """Initialize an entity component.""" self.logger = logger self.hass = hass @@ -44,9 +39,10 @@ class EntityComponent(object): def setup(self, config): """ - Sets up a full entity component: - - Loads the platforms from the config - - Will listen for supported discovered platforms + Set up a full entity component. + + Loads the platforms from the config and will listen for supported + discovered platforms. """ self.config = config @@ -57,13 +53,18 @@ class EntityComponent(object): self._setup_platform(p_type, p_config) if self.discovery_platforms: - discovery.listen(self.hass, self.discovery_platforms.keys(), - self._entity_discovered) + discovery.listen( + self.hass, self.discovery_platforms.keys(), + lambda service, info: + self._setup_platform(self.discovery_platforms[service], {}, + info)) def add_entities(self, new_entities): """ - Takes in a list of new entities. For each entity will see if it already - exists. If not, will add it, set it up and push the first state. + Add new entities to this component. + + For each entity will see if it already exists. If not, will add it, + set it up and push the first state. """ with self.lock: for entity in new_entities: @@ -101,8 +102,10 @@ class EntityComponent(object): def extract_from_service(self, service): """ - Takes a service and extracts all known entities. - Will return all if no entity IDs given in service. + Extract all known entities from a service call. + + Will return all entities if no entities specified in call. + Will return an empty list if entities specified but unknown. """ with self.lock: if ATTR_ENTITY_ID not in service.data: @@ -113,7 +116,7 @@ class EntityComponent(object): if entity_id in self.entities] def _update_entity_states(self, now): - """ Update the states of all the entities. """ + """Update the states of all the polling entities.""" with self.lock: # We copy the entities because new entities might be detected # during state update causing deadlocks. @@ -125,16 +128,9 @@ class EntityComponent(object): for entity in entities: entity.update_ha_state(True) - def _entity_discovered(self, service, info): - """ Called when a entity is discovered. """ - if service not in self.discovery_platforms: - return - - self._setup_platform(self.discovery_platforms[service], {}, info) - def _setup_platform(self, platform_type, platform_config, discovery_info=None): - """ Tries to setup a platform for this component. """ + """Setup a platform for this component.""" platform = prepare_setup_platform( self.hass, self.config, self.domain, platform_type) diff --git a/tests/common.py b/tests/common.py index b8108c673fd..350786c8e14 100644 --- a/tests/common.py +++ b/tests/common.py @@ -145,11 +145,26 @@ class MockHTTP(object): class MockModule(object): """ Provides a fake module. """ - def __init__(self, domain, dependencies=[], setup=None): + def __init__(self, domain=None, dependencies=[], setup=None): self.DOMAIN = domain self.DEPENDENCIES = dependencies # Setup a mock setup if none given. - self.setup = lambda hass, config: False if setup is None else setup + if setup is None: + self.setup = lambda hass, config: False + else: + self.setup = setup + + +class MockPlatform(object): + """ Provides a fake platform. """ + + def __init__(self, setup_platform=None, dependencies=[]): + self.DEPENDENCIES = dependencies + self._setup_platform = setup_platform + + def setup_platform(self, hass, config, add_devices, discovery_info=None): + if self._setup_platform is not None: + self._setup_platform(hass, config, add_devices, discovery_info) class MockToggleDevice(ToggleEntity): diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py new file mode 100644 index 00000000000..5f21c2193ce --- /dev/null +++ b/tests/helpers/test_entity_component.py @@ -0,0 +1,236 @@ +""" +tests.test_helper_entity_component +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Tests the entity component helper. +""" +# pylint: disable=protected-access,too-many-public-methods +from collections import OrderedDict +import logging +import unittest +from unittest.mock import patch, Mock + +import homeassistant.core as ha +import homeassistant.loader as loader +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.components import discovery + +from tests.common import get_test_home_assistant, MockPlatform, MockModule + +_LOGGER = logging.getLogger(__name__) +DOMAIN = "test_domain" + + +class EntityTest(Entity): + def __init__(self, **values): + self._values = values + + if 'entity_id' in values: + self.entity_id = values['entity_id'] + + @property + def name(self): + return self._handle('name') + + @property + def should_poll(self): + return self._handle('should_poll') + + @property + def unique_id(self): + return self._handle('unique_id') + + def _handle(self, attr): + if attr in self._values: + return self._values[attr] + return getattr(super(), attr) + + +class TestHelpersEntityComponent(unittest.TestCase): + """ Tests homeassistant.helpers.entity_component module. """ + + def setUp(self): # pylint: disable=invalid-name + """Initialize a test Home Assistant instance.""" + self.hass = get_test_home_assistant() + + def tearDown(self): # pylint: disable=invalid-name + """Clean up the test Home Assistant instance.""" + self.hass.stop() + + def test_setting_up_group(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass, + group_name='everyone') + + # No group after setup + assert 0 == len(self.hass.states.entity_ids()) + + component.add_entities([EntityTest(name='hello')]) + + # group exists + assert 2 == len(self.hass.states.entity_ids()) + assert ['group.everyone'] == self.hass.states.entity_ids('group') + + group = self.hass.states.get('group.everyone') + + assert ('test_domain.hello',) == group.attributes.get('entity_id') + + # group extended + component.add_entities([EntityTest(name='hello2')]) + + assert 3 == len(self.hass.states.entity_ids()) + group = self.hass.states.get('group.everyone') + + assert ['test_domain.hello', 'test_domain.hello2'] == \ + sorted(group.attributes.get('entity_id')) + + @patch('homeassistant.helpers.entity_component.track_utc_time_change') + def test_polling_only_updates_entities_it_should_poll(self, mock_track): + component = EntityComponent(_LOGGER, DOMAIN, self.hass, 20) + + no_poll_ent = EntityTest(should_poll=False) + no_poll_ent.update_ha_state = Mock() + poll_ent = EntityTest(should_poll=True) + poll_ent.update_ha_state = Mock() + + component.add_entities([no_poll_ent]) + assert not mock_track.called + + component.add_entities([poll_ent]) + assert mock_track.called + assert [0, 20, 40] == list(mock_track.call_args[1].get('second')) + + no_poll_ent.update_ha_state.reset_mock() + poll_ent.update_ha_state.reset_mock() + + component._update_entity_states(None) + + assert not no_poll_ent.update_ha_state.called + assert poll_ent.update_ha_state.called + + def test_update_state_adds_entities(self): + """Test if updating poll entities cause an entity to be added works.""" + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + ent1 = EntityTest() + ent2 = EntityTest(should_poll=True) + + component.add_entities([ent2]) + assert 1 == len(self.hass.states.entity_ids()) + ent2.update_ha_state = lambda *_: component.add_entities([ent1]) + component._update_entity_states(None) + assert 2 == len(self.hass.states.entity_ids()) + + def test_not_adding_duplicate_entities(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert 0 == len(self.hass.states.entity_ids()) + + component.add_entities([None, EntityTest(unique_id='not_very_unique')]) + + assert 1 == len(self.hass.states.entity_ids()) + + component.add_entities([EntityTest(unique_id='not_very_unique')]) + + assert 1 == len(self.hass.states.entity_ids()) + + def test_not_assigning_entity_id_if_prescribes_one(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert 'hello.world' not in self.hass.states.entity_ids() + + component.add_entities([EntityTest(entity_id='hello.world')]) + + assert 'hello.world' in self.hass.states.entity_ids() + + def test_extract_from_service_returns_all_if_no_entity_id(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + component.add_entities([ + EntityTest(name='test_1'), + EntityTest(name='test_2'), + ]) + + call = ha.ServiceCall('test', 'service') + + assert ['test_domain.test_1', 'test_domain.test_2'] == \ + sorted(ent.entity_id for ent in + component.extract_from_service(call)) + + def test_extract_from_service_filter_out_non_existing_entities(self): + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + component.add_entities([ + EntityTest(name='test_1'), + EntityTest(name='test_2'), + ]) + + call = ha.ServiceCall('test', 'service', { + 'entity_id': ['test_domain.test_2', 'test_domain.non_exist'] + }) + + assert ['test_domain.test_2'] == \ + [ent.entity_id for ent in component.extract_from_service(call)] + + def test_setup_loads_platforms(self): + component_setup = Mock(return_value=True) + platform_setup = Mock(return_value=None) + loader.set_component( + 'test_component', + MockModule('test_component', setup=component_setup)) + loader.set_component('test_domain.mod2', + MockPlatform(platform_setup, ['test_component'])) + + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert not component_setup.called + assert not platform_setup.called + + component.setup({ + DOMAIN: { + 'platform': 'mod2', + } + }) + + assert component_setup.called + assert platform_setup.called + + def test_setup_recovers_when_setup_raises(self): + platform1_setup = Mock(side_effect=Exception('Broken')) + platform2_setup = Mock(return_value=None) + + loader.set_component('test_domain.mod1', MockPlatform(platform1_setup)) + loader.set_component('test_domain.mod2', MockPlatform(platform2_setup)) + + component = EntityComponent(_LOGGER, DOMAIN, self.hass) + + assert not platform1_setup.called + assert not platform2_setup.called + + component.setup(OrderedDict([ + (DOMAIN, {'platform': 'mod1'}), + ("{} 2".format(DOMAIN), {'platform': 'non_exist'}), + ("{} 3".format(DOMAIN), {'platform': 'mod2'}), + ])) + + assert platform1_setup.called + assert platform2_setup.called + + @patch('homeassistant.helpers.entity_component.EntityComponent' + '._setup_platform') + def test_setup_does_discovery(self, mock_setup): + component = EntityComponent( + _LOGGER, DOMAIN, self.hass, discovery_platforms={ + 'discovery.test': 'platform_test', + }) + + component.setup({}) + + self.hass.bus.fire(discovery.EVENT_PLATFORM_DISCOVERED, { + discovery.ATTR_SERVICE: 'discovery.test', + discovery.ATTR_DISCOVERED: 'discovery_info', + }) + + self.hass.pool.block_till_done() + + assert mock_setup.called + assert ('platform_test', {}, 'discovery_info') == \ + mock_setup.call_args[0]