mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 13:47:35 +00:00
Index entity_registry_updated listeners (#37940)
This commit is contained in:
parent
9ae08585dc
commit
910b6c9c2c
@ -27,11 +27,8 @@ from homeassistant.const import (
|
|||||||
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import NoEntitySpecifiedError
|
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||||
from homeassistant.helpers.entity_platform import EntityPlatform
|
from homeassistant.helpers.entity_platform import EntityPlatform
|
||||||
from homeassistant.helpers.entity_registry import (
|
from homeassistant.helpers.entity_registry import RegistryEntry
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED,
|
from homeassistant.helpers.event import Event, async_track_entity_registry_updated_event
|
||||||
RegistryEntry,
|
|
||||||
)
|
|
||||||
from homeassistant.helpers.event import Event
|
|
||||||
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
|
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
|
||||||
from homeassistant.util.async_ import run_callback_threadsafe
|
from homeassistant.util.async_ import run_callback_threadsafe
|
||||||
|
|
||||||
@ -518,8 +515,8 @@ class Entity(ABC):
|
|||||||
if self.registry_entry is not None:
|
if self.registry_entry is not None:
|
||||||
assert self.hass is not None
|
assert self.hass is not None
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
self.hass.bus.async_listen(
|
async_track_entity_registry_updated_event(
|
||||||
EVENT_ENTITY_REGISTRY_UPDATED, self._async_registry_updated
|
self.hass, self.entity_id, self._async_registry_updated
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -532,14 +529,11 @@ class Entity(ABC):
|
|||||||
async def _async_registry_updated(self, event: Event) -> None:
|
async def _async_registry_updated(self, event: Event) -> None:
|
||||||
"""Handle entity registry update."""
|
"""Handle entity registry update."""
|
||||||
data = event.data
|
data = event.data
|
||||||
if data["action"] == "remove" and data["entity_id"] == self.entity_id:
|
if data["action"] == "remove":
|
||||||
await self.async_removed_from_registry()
|
await self.async_removed_from_registry()
|
||||||
await self.async_remove()
|
await self.async_remove()
|
||||||
|
|
||||||
if (
|
if data["action"] != "update":
|
||||||
data["action"] != "update"
|
|
||||||
or data.get("old_entity_id", data["entity_id"]) != self.entity_id
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
assert self.hass is not None
|
assert self.hass is not None
|
||||||
|
@ -17,6 +17,7 @@ from homeassistant.const import (
|
|||||||
SUN_EVENT_SUNSET,
|
SUN_EVENT_SUNSET,
|
||||||
)
|
)
|
||||||
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback
|
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback
|
||||||
|
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||||
from homeassistant.helpers.sun import get_astral_event_next
|
from homeassistant.helpers.sun import get_astral_event_next
|
||||||
from homeassistant.helpers.template import Template
|
from homeassistant.helpers.template import Template
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
@ -26,6 +27,9 @@ from homeassistant.util.async_ import run_callback_threadsafe
|
|||||||
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
|
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
|
||||||
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
|
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
|
||||||
|
|
||||||
|
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
|
||||||
|
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# PyLint does not like the use of threaded_listener_factory
|
# PyLint does not like the use of threaded_listener_factory
|
||||||
@ -137,7 +141,7 @@ track_state_change = threaded_listener_factory(async_track_state_change)
|
|||||||
def async_track_state_change_event(
|
def async_track_state_change_event(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
entity_ids: Union[str, Iterable[str]],
|
entity_ids: Union[str, Iterable[str]],
|
||||||
action: Callable[[Event], None],
|
action: Callable[[Event], Any],
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Track specific state change events indexed by entity_id.
|
"""Track specific state change events indexed by entity_id.
|
||||||
|
|
||||||
@ -186,17 +190,28 @@ def async_track_state_change_event(
|
|||||||
@callback
|
@callback
|
||||||
def remove_listener() -> None:
|
def remove_listener() -> None:
|
||||||
"""Remove state change listener."""
|
"""Remove state change listener."""
|
||||||
_async_remove_state_change_listeners(hass, entity_ids, action)
|
_async_remove_entity_listeners(
|
||||||
|
hass,
|
||||||
|
TRACK_STATE_CHANGE_CALLBACKS,
|
||||||
|
TRACK_STATE_CHANGE_LISTENER,
|
||||||
|
entity_ids,
|
||||||
|
action,
|
||||||
|
)
|
||||||
|
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_remove_state_change_listeners(
|
def _async_remove_entity_listeners(
|
||||||
hass: HomeAssistant, entity_ids: Iterable[str], action: Callable[[Event], None]
|
hass: HomeAssistant,
|
||||||
|
storage_key: str,
|
||||||
|
listener_key: str,
|
||||||
|
entity_ids: Iterable[str],
|
||||||
|
action: Callable[[Event], Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Remove a listener."""
|
"""Remove a listener."""
|
||||||
entity_callbacks = hass.data[TRACK_STATE_CHANGE_CALLBACKS]
|
|
||||||
|
entity_callbacks = hass.data[storage_key]
|
||||||
|
|
||||||
for entity_id in entity_ids:
|
for entity_id in entity_ids:
|
||||||
entity_callbacks[entity_id].remove(action)
|
entity_callbacks[entity_id].remove(action)
|
||||||
@ -204,8 +219,66 @@ def _async_remove_state_change_listeners(
|
|||||||
del entity_callbacks[entity_id]
|
del entity_callbacks[entity_id]
|
||||||
|
|
||||||
if not entity_callbacks:
|
if not entity_callbacks:
|
||||||
hass.data[TRACK_STATE_CHANGE_LISTENER]()
|
hass.data[listener_key]()
|
||||||
del hass.data[TRACK_STATE_CHANGE_LISTENER]
|
del hass.data[listener_key]
|
||||||
|
|
||||||
|
|
||||||
|
@bind_hass
|
||||||
|
def async_track_entity_registry_updated_event(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entity_ids: Union[str, Iterable[str]],
|
||||||
|
action: Callable[[Event], Any],
|
||||||
|
) -> Callable[[], None]:
|
||||||
|
"""Track specific entity registry updated events indexed by entity_id.
|
||||||
|
|
||||||
|
Similar to async_track_state_change_event.
|
||||||
|
"""
|
||||||
|
|
||||||
|
entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {})
|
||||||
|
|
||||||
|
if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data:
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_entity_registry_updated_dispatcher(event: Event) -> None:
|
||||||
|
"""Dispatch entity registry updates by entity_id."""
|
||||||
|
entity_id = event.data.get("old_entity_id", event.data["entity_id"])
|
||||||
|
|
||||||
|
if entity_id not in entity_callbacks:
|
||||||
|
return
|
||||||
|
|
||||||
|
for action in entity_callbacks[entity_id][:]:
|
||||||
|
try:
|
||||||
|
hass.async_run_job(action, event)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception(
|
||||||
|
"Error while processing entity registry update for %s",
|
||||||
|
entity_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(entity_ids, str):
|
||||||
|
entity_ids = [entity_ids]
|
||||||
|
|
||||||
|
entity_ids = [entity_id.lower() for entity_id in entity_ids]
|
||||||
|
|
||||||
|
for entity_id in entity_ids:
|
||||||
|
entity_callbacks.setdefault(entity_id, []).append(action)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def remove_listener() -> None:
|
||||||
|
"""Remove state change listener."""
|
||||||
|
_async_remove_entity_listeners(
|
||||||
|
hass,
|
||||||
|
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
|
||||||
|
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
|
||||||
|
entity_ids,
|
||||||
|
action,
|
||||||
|
)
|
||||||
|
|
||||||
|
return remove_listener
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -338,6 +338,7 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock):
|
|||||||
# Verify state is removed
|
# Verify state is removed
|
||||||
state = hass.states.get("sensor.mqtt_sensor")
|
state = hass.states.get("sensor.mqtt_sensor")
|
||||||
assert state is None
|
assert state is None
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
# Verify retained discovery topic has been cleared
|
# Verify retained discovery topic has been cleared
|
||||||
mqtt_mock.async_publish.assert_called_once_with(
|
mqtt_mock.async_publish.assert_called_once_with(
|
||||||
|
@ -10,6 +10,7 @@ from homeassistant.components import sun
|
|||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import MATCH_ALL
|
||||||
import homeassistant.core as ha
|
import homeassistant.core as ha
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||||
from homeassistant.helpers.event import (
|
from homeassistant.helpers.event import (
|
||||||
async_call_later,
|
async_call_later,
|
||||||
async_track_point_in_time,
|
async_track_point_in_time,
|
||||||
@ -1180,3 +1181,104 @@ async def test_async_track_point_in_time_cancel(hass):
|
|||||||
|
|
||||||
assert len(times) == 1
|
assert len(times) == 1
|
||||||
assert times[0].tzinfo.zone == "US/Hawaii"
|
assert times[0].tzinfo.zone == "US/Hawaii"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_track_entity_registry_updated_event(hass):
|
||||||
|
"""Test tracking entity registry updates for an entity_id."""
|
||||||
|
|
||||||
|
entity_id = "switch.puppy_feeder"
|
||||||
|
new_entity_id = "switch.dog_feeder"
|
||||||
|
untracked_entity_id = "switch.kitty_feeder"
|
||||||
|
|
||||||
|
hass.states.async_set(entity_id, "on")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
event_data = []
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
|
def run_callback(event):
|
||||||
|
event_data.append(event.data)
|
||||||
|
|
||||||
|
unsub1 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||||
|
entity_id, run_callback
|
||||||
|
)
|
||||||
|
unsub2 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||||
|
new_entity_id, run_callback
|
||||||
|
)
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||||
|
)
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||||
|
{"action": "create", "entity_id": untracked_entity_id},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||||
|
{
|
||||||
|
"action": "update",
|
||||||
|
"entity_id": new_entity_id,
|
||||||
|
"old_entity_id": entity_id,
|
||||||
|
"changes": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": new_entity_id}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
unsub1()
|
||||||
|
unsub2()
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||||
|
)
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": new_entity_id}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert event_data[0] == {"action": "create", "entity_id": "switch.puppy_feeder"}
|
||||||
|
assert event_data[1] == {
|
||||||
|
"action": "update",
|
||||||
|
"changes": {},
|
||||||
|
"entity_id": "switch.dog_feeder",
|
||||||
|
"old_entity_id": "switch.puppy_feeder",
|
||||||
|
}
|
||||||
|
assert event_data[2] == {"action": "remove", "entity_id": "switch.dog_feeder"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_track_entity_registry_updated_event_with_a_callback_that_throws(
|
||||||
|
hass,
|
||||||
|
):
|
||||||
|
"""Test tracking entity registry updates for an entity_id when one callback throws."""
|
||||||
|
|
||||||
|
entity_id = "switch.puppy_feeder"
|
||||||
|
|
||||||
|
hass.states.async_set(entity_id, "on")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
event_data = []
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
|
def run_callback(event):
|
||||||
|
event_data.append(event.data)
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
|
def failing_callback(event):
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
unsub1 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||||
|
entity_id, failing_callback
|
||||||
|
)
|
||||||
|
unsub2 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||||
|
entity_id, run_callback
|
||||||
|
)
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
unsub1()
|
||||||
|
unsub2()
|
||||||
|
|
||||||
|
assert event_data[0] == {"action": "create", "entity_id": "switch.puppy_feeder"}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user