Add types to event tracker data (#118010)

* Add types to event tracker data

* fixes

* do not test event internals in other tests

* fixes

* Update homeassistant/helpers/event.py

* cleanup

* cleanup
This commit is contained in:
J. Nick Koston 2024-05-24 04:09:39 -10:00 committed by GitHub
parent 7183260d95
commit a8fba691ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 62 deletions

View File

@ -54,30 +54,21 @@ from .sun import get_astral_event_next
from .template import RenderInfo, Template, result_as_boolean from .template import RenderInfo, Template, result_as_boolean
from .typing import TemplateVarsType from .typing import TemplateVarsType
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" _TRACK_STATE_CHANGE_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = HassKey(
TRACK_STATE_CHANGE_LISTENER: HassKey[Callable[[], None]] = HassKey( "track_state_change_data"
"track_state_change_listener"
) )
_TRACK_STATE_ADDED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = (
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks" HassKey("track_state_added_domain_data")
TRACK_STATE_ADDED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_added_domain_listener"
) )
_TRACK_STATE_REMOVED_DOMAIN_DATA: HassKey[_KeyedEventData[EventStateChangedData]] = (
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks" HassKey("track_state_removed_domain_data")
TRACK_STATE_REMOVED_DOMAIN_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_state_removed_domain_listener"
)
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_entity_registry_updated_listener"
)
TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks"
TRACK_DEVICE_REGISTRY_UPDATED_LISTENER: HassKey[Callable[[], None]] = HassKey(
"track_device_registry_updated_listener"
) )
_TRACK_ENTITY_REGISTRY_UPDATED_DATA: HassKey[
_KeyedEventData[EventEntityRegistryUpdatedData]
] = HassKey("track_entity_registry_updated_data")
_TRACK_DEVICE_REGISTRY_UPDATED_DATA: HassKey[
_KeyedEventData[EventDeviceRegistryUpdatedData]
] = HassKey("track_device_registry_updated_data")
_ALL_LISTENER = "all" _ALL_LISTENER = "all"
_DOMAINS_LISTENER = "domains" _DOMAINS_LISTENER = "domains"
@ -99,8 +90,7 @@ _TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any])
class _KeyedEventTracker(Generic[_TypedDictT]): class _KeyedEventTracker(Generic[_TypedDictT]):
"""Class to track events by key.""" """Class to track events by key."""
listeners_key: HassKey[Callable[[], None]] key: HassKey[_KeyedEventData[_TypedDictT]]
callbacks_key: str
event_type: EventType[_TypedDictT] | str event_type: EventType[_TypedDictT] | str
dispatcher_callable: Callable[ dispatcher_callable: Callable[
[ [
@ -120,6 +110,14 @@ class _KeyedEventTracker(Generic[_TypedDictT]):
] ]
@dataclass(slots=True, frozen=True)
class _KeyedEventData(Generic[_TypedDictT]):
"""Class to track data for events by key."""
listener: CALLBACK_TYPE
callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]]
@dataclass(slots=True) @dataclass(slots=True)
class TrackStates: class TrackStates:
"""Class for keeping track of states being tracked. """Class for keeping track of states being tracked.
@ -354,8 +352,7 @@ def _async_state_change_filter(
_KEYED_TRACK_STATE_CHANGE = _KeyedEventTracker( _KEYED_TRACK_STATE_CHANGE = _KeyedEventTracker(
listeners_key=TRACK_STATE_CHANGE_LISTENER, key=_TRACK_STATE_CHANGE_DATA,
callbacks_key=TRACK_STATE_CHANGE_CALLBACKS,
event_type=EVENT_STATE_CHANGED, event_type=EVENT_STATE_CHANGED,
dispatcher_callable=_async_dispatch_entity_id_event, dispatcher_callable=_async_dispatch_entity_id_event,
filter_callable=_async_state_change_filter, filter_callable=_async_state_change_filter,
@ -380,10 +377,10 @@ def _remove_empty_listener() -> None:
"""Remove a listener that does nothing.""" """Remove a listener that does nothing."""
@callback # type: ignore[arg-type] # mypy bug? @callback
def _remove_listener( def _remove_listener(
hass: HomeAssistant, hass: HomeAssistant,
listeners_key: HassKey[Callable[[], None]], tracker: _KeyedEventTracker[_TypedDictT],
keys: Iterable[str], keys: Iterable[str],
job: HassJob[[Event[_TypedDictT]], Any], job: HassJob[[Event[_TypedDictT]], Any],
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]], callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
@ -391,12 +388,11 @@ def _remove_listener(
"""Remove listener.""" """Remove listener."""
for key in keys: for key in keys:
callbacks[key].remove(job) callbacks[key].remove(job)
if len(callbacks[key]) == 0: if not callbacks[key]:
del callbacks[key] del callbacks[key]
if not callbacks: if not callbacks:
hass.data[listeners_key]() hass.data.pop(tracker.key).listener()
del hass.data[listeners_key]
# tracker, not hass is intentionally the first argument here since its # tracker, not hass is intentionally the first argument here since its
@ -411,26 +407,24 @@ def _async_track_event(
"""Track an event by a specific key. """Track an event by a specific key.
This function is intended for internal use only. This function is intended for internal use only.
The dispatcher_callable, filter_callable, event_type, and run_immediately
must always be the same for the listener_key as the first call to this
function will set the listener_key in hass.data.
""" """
if not keys: if not keys:
return _remove_empty_listener return _remove_empty_listener
hass_data = hass.data hass_data = hass.data
callbacks: defaultdict[str, list[HassJob[[Event[_TypedDictT]], Any]]] | None tracker_key = tracker.key
if not (callbacks := hass_data.get(tracker.callbacks_key)): if tracker_key in hass_data:
callbacks = hass_data[tracker.callbacks_key] = defaultdict(list) event_data = hass_data[tracker_key]
callbacks = event_data.callbacks
listeners_key = tracker.listeners_key else:
if tracker.listeners_key not in hass_data: callbacks = defaultdict(list)
hass_data[tracker.listeners_key] = hass.bus.async_listen( listener = hass.bus.async_listen(
tracker.event_type, tracker.event_type,
partial(tracker.dispatcher_callable, hass, callbacks), partial(tracker.dispatcher_callable, hass, callbacks),
event_filter=partial(tracker.filter_callable, hass, callbacks), event_filter=partial(tracker.filter_callable, hass, callbacks),
) )
event_data = _KeyedEventData(listener, callbacks)
hass_data[tracker_key] = event_data
job = HassJob(action, f"track {tracker.event_type} event {keys}", job_type=job_type) job = HassJob(action, f"track {tracker.event_type} event {keys}", job_type=job_type)
@ -441,12 +435,12 @@ def _async_track_event(
# during startup, and we want to avoid the overhead of # during startup, and we want to avoid the overhead of
# creating empty lists and throwing them away. # creating empty lists and throwing them away.
callbacks[keys].append(job) callbacks[keys].append(job)
keys = [keys] keys = (keys,)
else: else:
for key in keys: for key in keys:
callbacks[key].append(job) callbacks[key].append(job)
return partial(_remove_listener, hass, listeners_key, keys, job, callbacks) return partial(_remove_listener, hass, tracker, keys, job, callbacks)
@callback @callback
@ -484,8 +478,7 @@ def _async_entity_registry_updated_filter(
_KEYED_TRACK_ENTITY_REGISTRY_UPDATED = _KeyedEventTracker( _KEYED_TRACK_ENTITY_REGISTRY_UPDATED = _KeyedEventTracker(
listeners_key=TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, key=_TRACK_ENTITY_REGISTRY_UPDATED_DATA,
callbacks_key=TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
event_type=EVENT_ENTITY_REGISTRY_UPDATED, event_type=EVENT_ENTITY_REGISTRY_UPDATED,
dispatcher_callable=_async_dispatch_old_entity_id_or_entity_id_event, dispatcher_callable=_async_dispatch_old_entity_id_or_entity_id_event,
filter_callable=_async_entity_registry_updated_filter, filter_callable=_async_entity_registry_updated_filter,
@ -542,8 +535,7 @@ def _async_dispatch_device_id_event(
_KEYED_TRACK_DEVICE_REGISTRY_UPDATED = _KeyedEventTracker( _KEYED_TRACK_DEVICE_REGISTRY_UPDATED = _KeyedEventTracker(
listeners_key=TRACK_DEVICE_REGISTRY_UPDATED_LISTENER, key=_TRACK_DEVICE_REGISTRY_UPDATED_DATA,
callbacks_key=TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS,
event_type=EVENT_DEVICE_REGISTRY_UPDATED, event_type=EVENT_DEVICE_REGISTRY_UPDATED,
dispatcher_callable=_async_dispatch_device_id_event, dispatcher_callable=_async_dispatch_device_id_event,
filter_callable=_async_device_registry_updated_filter, filter_callable=_async_device_registry_updated_filter,
@ -613,8 +605,7 @@ def async_track_state_added_domain(
_KEYED_TRACK_STATE_ADDED_DOMAIN = _KeyedEventTracker( _KEYED_TRACK_STATE_ADDED_DOMAIN = _KeyedEventTracker(
listeners_key=TRACK_STATE_ADDED_DOMAIN_LISTENER, key=_TRACK_STATE_ADDED_DOMAIN_DATA,
callbacks_key=TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
event_type=EVENT_STATE_CHANGED, event_type=EVENT_STATE_CHANGED,
dispatcher_callable=_async_dispatch_domain_event, dispatcher_callable=_async_dispatch_domain_event,
filter_callable=_async_domain_added_filter, filter_callable=_async_domain_added_filter,
@ -651,8 +642,7 @@ def _async_domain_removed_filter(
_KEYED_TRACK_STATE_REMOVED_DOMAIN = _KeyedEventTracker( _KEYED_TRACK_STATE_REMOVED_DOMAIN = _KeyedEventTracker(
listeners_key=TRACK_STATE_REMOVED_DOMAIN_LISTENER, key=_TRACK_STATE_REMOVED_DOMAIN_DATA,
callbacks_key=TRACK_STATE_REMOVED_DOMAIN_CALLBACKS,
event_type=EVENT_STATE_CHANGED, event_type=EVENT_STATE_CHANGED,
dispatcher_callable=_async_dispatch_domain_event, dispatcher_callable=_async_dispatch_domain_event,
filter_callable=_async_domain_removed_filter, filter_callable=_async_domain_removed_filter,

View File

@ -33,7 +33,6 @@ from homeassistant.const import (
) )
from homeassistant.core import CoreState, HomeAssistant from homeassistant.core import CoreState, HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import common from . import common
@ -901,10 +900,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None:
"group.test_group", "group.test_group",
] ]
assert hass.bus.async_listeners()["state_changed"] == 1 assert hass.bus.async_listeners()["state_changed"] == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["hello.world"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1
with patch( with patch(
"homeassistant.config.load_yaml_config_file", "homeassistant.config.load_yaml_config_file",
@ -920,9 +915,6 @@ async def test_reloading_groups(hass: HomeAssistant) -> None:
"group.hello", "group.hello",
] ]
assert hass.bus.async_listeners()["state_changed"] == 1 assert hass.bus.async_listeners()["state_changed"] == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1
async def test_modify_group(hass: HomeAssistant) -> None: async def test_modify_group(hass: HomeAssistant) -> None:

View File

@ -48,7 +48,6 @@ from homeassistant.const import (
__version__ as hass_version, __version__ as hass_version,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS
from tests.common import async_mock_service from tests.common import async_mock_service
@ -66,9 +65,7 @@ async def test_accessory_cancels_track_state_change_on_stop(
"homeassistant.components.homekit.accessories.HomeAccessory.async_update_state" "homeassistant.components.homekit.accessories.HomeAccessory.async_update_state"
): ):
acc.run() acc.run()
assert len(hass.data[TRACK_STATE_CHANGE_CALLBACKS][entity_id]) == 1
await acc.stop() await acc.stop()
assert entity_id not in hass.data[TRACK_STATE_CHANGE_CALLBACKS]
async def test_home_accessory(hass: HomeAssistant, hk_driver) -> None: async def test_home_accessory(hass: HomeAssistant, hk_driver) -> None: