mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Allow update entities on add_entities callback (#4114)
* Allow udpate entities on add_entities callback * fix wrong position * update force_update to update_before_add * add unittest for update_befor_add * fix unittest * change mocking
This commit is contained in:
parent
5d43d3eb1c
commit
9c0455e3dc
@ -63,7 +63,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
|
|||||||
_LOGGER.error('No sensors added')
|
_LOGGER.error('No sensors added')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
hass.loop.create_task(async_add_devices(sensors))
|
hass.loop.create_task(async_add_devices(sensors, True))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -82,8 +82,6 @@ class BinarySensorTemplate(BinarySensorDevice):
|
|||||||
self._template = value_template
|
self._template = value_template
|
||||||
self._state = None
|
self._state = None
|
||||||
|
|
||||||
self._async_render()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def template_bsensor_state_listener(entity, old_state, new_state):
|
def template_bsensor_state_listener(entity, old_state, new_state):
|
||||||
"""Called when the target device changes state."""
|
"""Called when the target device changes state."""
|
||||||
@ -115,10 +113,6 @@ class BinarySensorTemplate(BinarySensorDevice):
|
|||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_update(self):
|
def async_update(self):
|
||||||
"""Update the state from the template."""
|
"""Update the state from the template."""
|
||||||
self._async_render()
|
|
||||||
|
|
||||||
def _async_render(self):
|
|
||||||
"""Render the state from the template."""
|
|
||||||
try:
|
try:
|
||||||
self._state = self._template.async_render().lower() == 'true'
|
self._state = self._template.async_render().lower() == 'true'
|
||||||
except TemplateError as ex:
|
except TemplateError as ex:
|
||||||
|
@ -61,7 +61,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
|
|||||||
_LOGGER.error("No sensors added")
|
_LOGGER.error("No sensors added")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
hass.loop.create_task(async_add_devices(sensors))
|
hass.loop.create_task(async_add_devices(sensors, True))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -80,9 +80,6 @@ class SensorTemplate(Entity):
|
|||||||
self._template = state_template
|
self._template = state_template
|
||||||
self._state = None
|
self._state = None
|
||||||
|
|
||||||
# update state
|
|
||||||
self._async_render()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def template_sensor_state_listener(entity, old_state, new_state):
|
def template_sensor_state_listener(entity, old_state, new_state):
|
||||||
"""Called when the target device changes state."""
|
"""Called when the target device changes state."""
|
||||||
@ -114,10 +111,6 @@ class SensorTemplate(Entity):
|
|||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_update(self):
|
def async_update(self):
|
||||||
"""Update the state from the template."""
|
"""Update the state from the template."""
|
||||||
self._async_render()
|
|
||||||
|
|
||||||
def _async_render(self):
|
|
||||||
"""Render the state from the template."""
|
|
||||||
try:
|
try:
|
||||||
self._state = self._template.async_render()
|
self._state = self._template.async_render()
|
||||||
except TemplateError as ex:
|
except TemplateError as ex:
|
||||||
|
@ -70,7 +70,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
|
|||||||
_LOGGER.error("No switches added")
|
_LOGGER.error("No switches added")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
hass.loop.create_task(async_add_devices(switches))
|
hass.loop.create_task(async_add_devices(switches, True))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -90,8 +90,6 @@ class SwitchTemplate(SwitchDevice):
|
|||||||
self._off_script = Script(hass, off_action)
|
self._off_script = Script(hass, off_action)
|
||||||
self._state = False
|
self._state = False
|
||||||
|
|
||||||
self._async_render()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def template_switch_state_listener(entity, old_state, new_state):
|
def template_switch_state_listener(entity, old_state, new_state):
|
||||||
"""Called when the target device changes state."""
|
"""Called when the target device changes state."""
|
||||||
@ -131,10 +129,6 @@ class SwitchTemplate(SwitchDevice):
|
|||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_update(self):
|
def async_update(self):
|
||||||
"""Update the state from the template."""
|
"""Update the state from the template."""
|
||||||
self._async_render()
|
|
||||||
|
|
||||||
def _async_render(self):
|
|
||||||
"""Render the state from the template."""
|
|
||||||
try:
|
try:
|
||||||
state = self._template.async_render().lower()
|
state = self._template.async_render().lower()
|
||||||
|
|
||||||
|
@ -155,14 +155,15 @@ class EntityComponent(object):
|
|||||||
self.logger.exception(
|
self.logger.exception(
|
||||||
'Error while setting up platform %s', platform_type)
|
'Error while setting up platform %s', platform_type)
|
||||||
|
|
||||||
def add_entity(self, entity, platform=None):
|
def add_entity(self, entity, platform=None, update_before_add=False):
|
||||||
"""Add entity to component."""
|
"""Add entity to component."""
|
||||||
return run_coroutine_threadsafe(
|
return run_coroutine_threadsafe(
|
||||||
self.async_add_entity(entity, platform), self.hass.loop
|
self.async_add_entity(entity, platform, update_before_add),
|
||||||
|
self.hass.loop
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_add_entity(self, entity, platform=None):
|
def async_add_entity(self, entity, platform=None, update_before_add=False):
|
||||||
"""Add entity to component.
|
"""Add entity to component.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
@ -172,6 +173,13 @@ class EntityComponent(object):
|
|||||||
|
|
||||||
entity.hass = self.hass
|
entity.hass = self.hass
|
||||||
|
|
||||||
|
# update/init entity data
|
||||||
|
if update_before_add:
|
||||||
|
if hasattr(entity, 'async_update'):
|
||||||
|
yield from entity.async_update()
|
||||||
|
else:
|
||||||
|
yield from self.hass.loop.run_in_executor(None, entity.update)
|
||||||
|
|
||||||
if getattr(entity, 'entity_id', None) is None:
|
if getattr(entity, 'entity_id', None) is None:
|
||||||
object_id = entity.name or DEVICE_DEFAULT_NAME
|
object_id = entity.name or DEVICE_DEFAULT_NAME
|
||||||
|
|
||||||
@ -274,19 +282,21 @@ class EntityPlatform(object):
|
|||||||
self.platform_entities = []
|
self.platform_entities = []
|
||||||
self._async_unsub_polling = None
|
self._async_unsub_polling = None
|
||||||
|
|
||||||
def add_entities(self, new_entities):
|
def add_entities(self, new_entities, update_before_add=False):
|
||||||
"""Add entities for a single platform."""
|
"""Add entities for a single platform."""
|
||||||
run_coroutine_threadsafe(
|
run_coroutine_threadsafe(
|
||||||
self.async_add_entities(new_entities), self.component.hass.loop
|
self.async_add_entities(new_entities, update_before_add),
|
||||||
|
self.component.hass.loop
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_add_entities(self, new_entities):
|
def async_add_entities(self, new_entities, update_before_add=False):
|
||||||
"""Add entities for a single platform async.
|
"""Add entities for a single platform async.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
tasks = [self._async_process_entity(entity) for entity in new_entities]
|
tasks = [self._async_process_entity(entity, update_before_add)
|
||||||
|
for entity in new_entities]
|
||||||
|
|
||||||
yield from asyncio.gather(*tasks, loop=self.component.hass.loop)
|
yield from asyncio.gather(*tasks, loop=self.component.hass.loop)
|
||||||
yield from self.component.async_update_group()
|
yield from self.component.async_update_group()
|
||||||
@ -301,9 +311,11 @@ class EntityPlatform(object):
|
|||||||
second=range(0, 60, self.scan_interval))
|
second=range(0, 60, self.scan_interval))
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _async_process_entity(self, new_entity):
|
def _async_process_entity(self, new_entity, update_before_add):
|
||||||
"""Add entities to StateMachine."""
|
"""Add entities to StateMachine."""
|
||||||
ret = yield from self.component.async_add_entity(new_entity, self)
|
ret = yield from self.component.async_add_entity(
|
||||||
|
new_entity, self, update_before_add=update_before_add
|
||||||
|
)
|
||||||
if ret:
|
if ret:
|
||||||
self.platform_entities.append(new_entity)
|
self.platform_entities.append(new_entity)
|
||||||
|
|
||||||
|
@ -126,6 +126,32 @@ class TestHelpersEntityComponent(unittest.TestCase):
|
|||||||
|
|
||||||
assert 2 == len(self.hass.states.entity_ids())
|
assert 2 == len(self.hass.states.entity_ids())
|
||||||
|
|
||||||
|
def test_update_state_adds_entities_with_update_befor_add_true(self):
|
||||||
|
"""Test if call update befor add to state machine."""
|
||||||
|
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
|
||||||
|
|
||||||
|
ent = EntityTest()
|
||||||
|
ent.update = Mock(spec_set=True)
|
||||||
|
|
||||||
|
component.add_entities([ent], True)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
|
||||||
|
assert 1 == len(self.hass.states.entity_ids())
|
||||||
|
assert ent.update.called
|
||||||
|
|
||||||
|
def test_update_state_adds_entities_with_update_befor_add_false(self):
|
||||||
|
"""Test if not call update befor add to state machine."""
|
||||||
|
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
|
||||||
|
|
||||||
|
ent = EntityTest()
|
||||||
|
ent.update = Mock(spec_set=True)
|
||||||
|
|
||||||
|
component.add_entities([ent], False)
|
||||||
|
self.hass.block_till_done()
|
||||||
|
|
||||||
|
assert 1 == len(self.hass.states.entity_ids())
|
||||||
|
assert not ent.update.called
|
||||||
|
|
||||||
def test_not_adding_duplicate_entities(self):
|
def test_not_adding_duplicate_entities(self):
|
||||||
"""Test for not adding duplicate entities."""
|
"""Test for not adding duplicate entities."""
|
||||||
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
|
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user