Allow update entities on add_entities callback (#4114)

* Allow udpate entities on add_entities callback

* fix wrong position

* update force_update to update_before_add

* add unittest for update_befor_add

* fix unittest

* change mocking
This commit is contained in:
Pascal Vizeli 2016-10-30 00:33:11 +02:00 committed by Paulus Schoutsen
parent 5d43d3eb1c
commit 9c0455e3dc
5 changed files with 50 additions and 31 deletions

View File

@ -63,7 +63,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
_LOGGER.error('No sensors added')
return False
hass.loop.create_task(async_add_devices(sensors))
hass.loop.create_task(async_add_devices(sensors, True))
return True
@ -82,8 +82,6 @@ class BinarySensorTemplate(BinarySensorDevice):
self._template = value_template
self._state = None
self._async_render()
@callback
def template_bsensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
@ -115,10 +113,6 @@ class BinarySensorTemplate(BinarySensorDevice):
@asyncio.coroutine
def async_update(self):
"""Update the state from the template."""
self._async_render()
def _async_render(self):
"""Render the state from the template."""
try:
self._state = self._template.async_render().lower() == 'true'
except TemplateError as ex:

View File

@ -61,7 +61,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
_LOGGER.error("No sensors added")
return False
hass.loop.create_task(async_add_devices(sensors))
hass.loop.create_task(async_add_devices(sensors, True))
return True
@ -80,9 +80,6 @@ class SensorTemplate(Entity):
self._template = state_template
self._state = None
# update state
self._async_render()
@callback
def template_sensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
@ -114,10 +111,6 @@ class SensorTemplate(Entity):
@asyncio.coroutine
def async_update(self):
"""Update the state from the template."""
self._async_render()
def _async_render(self):
"""Render the state from the template."""
try:
self._state = self._template.async_render()
except TemplateError as ex:

View File

@ -70,7 +70,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
_LOGGER.error("No switches added")
return False
hass.loop.create_task(async_add_devices(switches))
hass.loop.create_task(async_add_devices(switches, True))
return True
@ -90,8 +90,6 @@ class SwitchTemplate(SwitchDevice):
self._off_script = Script(hass, off_action)
self._state = False
self._async_render()
@callback
def template_switch_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
@ -131,10 +129,6 @@ class SwitchTemplate(SwitchDevice):
@asyncio.coroutine
def async_update(self):
"""Update the state from the template."""
self._async_render()
def _async_render(self):
"""Render the state from the template."""
try:
state = self._template.async_render().lower()

View File

@ -155,14 +155,15 @@ class EntityComponent(object):
self.logger.exception(
'Error while setting up platform %s', platform_type)
def add_entity(self, entity, platform=None):
def add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component."""
return run_coroutine_threadsafe(
self.async_add_entity(entity, platform), self.hass.loop
self.async_add_entity(entity, platform, update_before_add),
self.hass.loop
).result()
@asyncio.coroutine
def async_add_entity(self, entity, platform=None):
def async_add_entity(self, entity, platform=None, update_before_add=False):
"""Add entity to component.
This method must be run in the event loop.
@ -172,6 +173,13 @@ class EntityComponent(object):
entity.hass = self.hass
# update/init entity data
if update_before_add:
if hasattr(entity, 'async_update'):
yield from entity.async_update()
else:
yield from self.hass.loop.run_in_executor(None, entity.update)
if getattr(entity, 'entity_id', None) is None:
object_id = entity.name or DEVICE_DEFAULT_NAME
@ -274,19 +282,21 @@ class EntityPlatform(object):
self.platform_entities = []
self._async_unsub_polling = None
def add_entities(self, new_entities):
def add_entities(self, new_entities, update_before_add=False):
"""Add entities for a single platform."""
run_coroutine_threadsafe(
self.async_add_entities(new_entities), self.component.hass.loop
self.async_add_entities(new_entities, update_before_add),
self.component.hass.loop
).result()
@asyncio.coroutine
def async_add_entities(self, new_entities):
def async_add_entities(self, new_entities, update_before_add=False):
"""Add entities for a single platform async.
This method must be run in the event loop.
"""
tasks = [self._async_process_entity(entity) for entity in new_entities]
tasks = [self._async_process_entity(entity, update_before_add)
for entity in new_entities]
yield from asyncio.gather(*tasks, loop=self.component.hass.loop)
yield from self.component.async_update_group()
@ -301,9 +311,11 @@ class EntityPlatform(object):
second=range(0, 60, self.scan_interval))
@asyncio.coroutine
def _async_process_entity(self, new_entity):
def _async_process_entity(self, new_entity, update_before_add):
"""Add entities to StateMachine."""
ret = yield from self.component.async_add_entity(new_entity, self)
ret = yield from self.component.async_add_entity(
new_entity, self, update_before_add=update_before_add
)
if ret:
self.platform_entities.append(new_entity)

View File

@ -126,6 +126,32 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert 2 == len(self.hass.states.entity_ids())
def test_update_state_adds_entities_with_update_befor_add_true(self):
"""Test if call update befor add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = EntityTest()
ent.update = Mock(spec_set=True)
component.add_entities([ent], True)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert ent.update.called
def test_update_state_adds_entities_with_update_befor_add_false(self):
"""Test if not call update befor add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = EntityTest()
ent.update = Mock(spec_set=True)
component.add_entities([ent], False)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert not ent.update.called
def test_not_adding_duplicate_entities(self):
"""Test for not adding duplicate entities."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)