mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Entity to handle updates via events (#24733)
* Entity to handle updates via events * Fix a bug * Update entity.py
This commit is contained in:
parent
9e0636eefa
commit
06af6f19a3
@ -10,6 +10,7 @@ from homeassistant.const import (
|
||||
ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON,
|
||||
STATE_UNAVAILABLE, STATE_UNKNOWN, TEMP_CELSIUS, TEMP_FAHRENHEIT,
|
||||
ATTR_ENTITY_PICTURE, ATTR_SUPPORTED_FEATURES, ATTR_DEVICE_CLASS)
|
||||
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.config import DATA_CUSTOMIZE
|
||||
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||
@ -78,8 +79,8 @@ class Entity:
|
||||
# Process updates in parallel
|
||||
parallel_updates = None
|
||||
|
||||
# Name in the entity registry
|
||||
registry_name = None
|
||||
# Entry in the entity registry
|
||||
registry_entry = None
|
||||
|
||||
# Hold list for functions to call on remove.
|
||||
_on_remove = None
|
||||
@ -259,7 +260,9 @@ class Entity:
|
||||
if unit_of_measurement is not None:
|
||||
attr[ATTR_UNIT_OF_MEASUREMENT] = unit_of_measurement
|
||||
|
||||
name = self.registry_name or self.name
|
||||
entry = self.registry_entry
|
||||
# pylint: disable=consider-using-ternary
|
||||
name = (entry and entry.name) or self.name
|
||||
if name is not None:
|
||||
attr[ATTR_FRIENDLY_NAME] = name
|
||||
|
||||
@ -391,6 +394,7 @@ class Entity:
|
||||
|
||||
async def async_remove(self):
|
||||
"""Remove entity from Home Assistant."""
|
||||
await self.async_internal_will_remove_from_hass()
|
||||
await self.async_will_remove_from_hass()
|
||||
|
||||
if self._on_remove is not None:
|
||||
@ -399,27 +403,52 @@ class Entity:
|
||||
|
||||
self.hass.states.async_remove(self.entity_id)
|
||||
|
||||
@callback
|
||||
def async_registry_updated(self, old, new):
|
||||
"""Handle entity registry update."""
|
||||
self.registry_name = new.name
|
||||
|
||||
if new.entity_id == self.entity_id:
|
||||
self.async_schedule_update_ha_state()
|
||||
return
|
||||
|
||||
async def readd():
|
||||
"""Remove and add entity again."""
|
||||
await self.async_remove()
|
||||
await self.platform.async_add_entities([self])
|
||||
|
||||
self.hass.async_create_task(readd())
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
"""Run when entity about to be added to hass.
|
||||
|
||||
To be extended by integrations.
|
||||
"""
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
"""Run when entity will be removed from hass.
|
||||
|
||||
To be extended by integrations.
|
||||
"""
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass.
|
||||
|
||||
Not to be extended by integrations.
|
||||
"""
|
||||
if self.registry_entry is not None:
|
||||
self.async_on_remove(self.hass.bus.async_listen(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, self._async_registry_updated))
|
||||
|
||||
async def async_internal_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass.
|
||||
|
||||
Not to be extended by integrations.
|
||||
"""
|
||||
|
||||
async def _async_registry_updated(self, event):
|
||||
"""Handle entity registry update."""
|
||||
data = event.data
|
||||
if data['action'] != 'update' and data.get(
|
||||
'old_entity_id', data['entity_id']) != self.entity_id:
|
||||
return
|
||||
|
||||
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
|
||||
old = self.registry_entry
|
||||
self.registry_entry = ent_reg.async_get(data['entity_id'])
|
||||
|
||||
if self.registry_entry.entity_id == old.entity_id:
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
await self.async_remove()
|
||||
|
||||
self.entity_id = self.registry_entry.entity_id
|
||||
await self.platform.async_add_entities([self])
|
||||
|
||||
def __eq__(self, other):
|
||||
"""Return the comparison."""
|
||||
|
@ -320,9 +320,8 @@ class EntityPlatform:
|
||||
'"{} {}"'.format(self.platform_name, entity.unique_id))
|
||||
return
|
||||
|
||||
entity.registry_entry = entry
|
||||
entity.entity_id = entry.entity_id
|
||||
entity.registry_name = entry.name
|
||||
entity.async_on_remove(entry.add_update_listener(entity))
|
||||
|
||||
# We won't generate an entity ID if the platform has already set one
|
||||
# We will however make sure that platform cannot pick a registered ID
|
||||
@ -360,6 +359,7 @@ class EntityPlatform:
|
||||
self.entities[entity_id] = entity
|
||||
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
||||
|
||||
await entity.async_internal_added_to_hass()
|
||||
await entity.async_added_to_hass()
|
||||
|
||||
await entity.async_update_ha_state()
|
||||
|
@ -12,7 +12,6 @@ from collections import OrderedDict
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import List, Optional, cast
|
||||
import weakref
|
||||
|
||||
import attr
|
||||
|
||||
@ -50,8 +49,6 @@ class RegistryEntry:
|
||||
disabled_by = attr.ib(
|
||||
type=str, default=None,
|
||||
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
|
||||
update_listeners = attr.ib(type=list, default=attr.Factory(list),
|
||||
repr=False)
|
||||
domain = attr.ib(type=str, init=False, repr=False)
|
||||
|
||||
@domain.default
|
||||
@ -64,18 +61,6 @@ class RegistryEntry:
|
||||
"""Return if entry is disabled."""
|
||||
return self.disabled_by is not None
|
||||
|
||||
def add_update_listener(self, listener):
|
||||
"""Listen for when entry is updated.
|
||||
|
||||
Listener: Callback function(old_entry, new_entry)
|
||||
|
||||
Returns function to unlisten.
|
||||
"""
|
||||
weak_listener = weakref.ref(listener)
|
||||
self.update_listeners.append(weak_listener)
|
||||
|
||||
return lambda: self.update_listeners.remove(weak_listener)
|
||||
|
||||
|
||||
class EntityRegistry:
|
||||
"""Class to hold a registry of entities."""
|
||||
@ -247,26 +232,17 @@ class EntityRegistry:
|
||||
|
||||
new = self.entities[entity_id] = attr.evolve(old, **changes)
|
||||
|
||||
to_remove = []
|
||||
for listener_ref in new.update_listeners:
|
||||
listener = listener_ref()
|
||||
if listener is None:
|
||||
to_remove.append(listener_ref)
|
||||
else:
|
||||
try:
|
||||
listener.async_registry_updated(old, new)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error calling update listener')
|
||||
|
||||
for ref in to_remove:
|
||||
new.update_listeners.remove(ref)
|
||||
|
||||
self.async_schedule_save()
|
||||
|
||||
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, {
|
||||
data = {
|
||||
'action': 'update',
|
||||
'entity_id': entity_id
|
||||
})
|
||||
'entity_id': entity_id,
|
||||
}
|
||||
|
||||
if old.entity_id != entity_id:
|
||||
data['old_entity_id'] = old.entity_id
|
||||
|
||||
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)
|
||||
|
||||
return new
|
||||
|
||||
|
@ -186,18 +186,18 @@ class RestoreStateData():
|
||||
class RestoreEntity(Entity):
|
||||
"""Mixin class for restoring previous entity state."""
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Register this entity as a restorable entity."""
|
||||
_, data = await asyncio.gather(
|
||||
super().async_added_to_hass(),
|
||||
super().async_internal_added_to_hass(),
|
||||
RestoreStateData.async_get_instance(self.hass),
|
||||
)
|
||||
data.async_restore_entity_added(self.entity_id)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
async def async_internal_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
_, data = await asyncio.gather(
|
||||
super().async_will_remove_from_hass(),
|
||||
super().async_internal_will_remove_from_hass(),
|
||||
RestoreStateData.async_get_instance(self.hass),
|
||||
)
|
||||
data.async_restore_entity_removed(self.entity_id)
|
||||
|
@ -104,12 +104,12 @@ async def test_dump_data(hass):
|
||||
entity = Entity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b0'
|
||||
await entity.async_added_to_hass()
|
||||
await entity.async_internal_added_to_hass()
|
||||
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b1'
|
||||
await entity.async_added_to_hass()
|
||||
await entity.async_internal_added_to_hass()
|
||||
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
now = dt_util.utcnow()
|
||||
@ -144,7 +144,7 @@ async def test_dump_data(hass):
|
||||
assert written_states[1]['state']['state'] == 'off'
|
||||
|
||||
# Test that removed entities are not persisted
|
||||
await entity.async_will_remove_from_hass()
|
||||
await entity.async_remove()
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
||||
) as mock_write_data, patch.object(
|
||||
@ -170,12 +170,12 @@ async def test_dump_error(hass):
|
||||
entity = Entity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b0'
|
||||
await entity.async_added_to_hass()
|
||||
await entity.async_internal_added_to_hass()
|
||||
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b1'
|
||||
await entity.async_added_to_hass()
|
||||
await entity.async_internal_added_to_hass()
|
||||
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
|
||||
@ -206,7 +206,7 @@ async def test_state_saved_on_remove(hass):
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b0'
|
||||
await entity.async_added_to_hass()
|
||||
await entity.async_internal_added_to_hass()
|
||||
|
||||
hass.states.async_set('input_boolean.b0', 'on')
|
||||
|
||||
@ -215,7 +215,7 @@ async def test_state_saved_on_remove(hass):
|
||||
# No last states should currently be saved
|
||||
assert not data.last_states
|
||||
|
||||
await entity.async_will_remove_from_hass()
|
||||
await entity.async_remove()
|
||||
|
||||
# We should store the input boolean state when it is removed
|
||||
assert data.last_states['input_boolean.b0'].state.state == 'on'
|
||||
|
Loading…
x
Reference in New Issue
Block a user