From a0e6fd6ec5a64033c8f9551cb05f6f0e34c47390 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 8 Apr 2024 01:28:24 +0200 Subject: [PATCH] Add improved typing for event fire and listen methods (#114906) * Add EventType implementation * Update integrations for EventType * Change state_changed to EventType * Fix tests * Remove runtime impact * Add tests * Move to stub file * Apply pre-commit to stub files * Fix ruff PYI checks --------- Co-authored-by: J. Nick Koston --- .pre-commit-config.yaml | 4 +- homeassistant/auth/permissions/events.py | 5 +- .../components/homeassistant/logbook.py | 4 +- homeassistant/components/logbook/__init__.py | 4 +- homeassistant/components/logbook/helpers.py | 7 ++- homeassistant/components/logbook/models.py | 8 ++- homeassistant/components/logbook/processor.py | 6 +- .../components/mobile_app/logbook.py | 4 +- homeassistant/components/recorder/__init__.py | 5 +- homeassistant/components/recorder/core.py | 3 +- .../components/recorder/models/event.py | 6 +- .../recorder/table_managers/__init__.py | 10 ++-- .../recorder/table_managers/event_types.py | 27 ++++++--- homeassistant/components/recorder/tasks.py | 3 +- homeassistant/const.py | 8 ++- homeassistant/core.py | 57 +++++++++++-------- homeassistant/exceptions.py | 10 +++- homeassistant/helpers/event.py | 3 +- homeassistant/util/event_type.py | 20 +++++++ homeassistant/util/event_type.pyi | 25 ++++++++ tests/util/test_event_type.py | 25 ++++++++ 21 files changed, 182 insertions(+), 62 deletions(-) create mode 100644 homeassistant/util/event_type.py create mode 100644 homeassistant/util/event_type.pyi create mode 100644 tests/util/test_event_type.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8280ac326a7..760e7e20676 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: args: - --fix - id: ruff-format - files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.py$ + files: ^((homeassistant|pylint|script|tests)/.+)?[^/]+\.(py|pyi)$ - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: @@ -63,7 +63,7 @@ repos: language: script types: [python] require_serial: true - files: ^(homeassistant|pylint)/.+\.py$ + files: ^(homeassistant|pylint)/.+\.(py|pyi)$ - id: pylint name: pylint entry: script/run-in-env.sh pylint -j 0 --ignore-missing-annotations=y diff --git a/homeassistant/auth/permissions/events.py b/homeassistant/auth/permissions/events.py index 3146cd99787..9f2fb45f9f0 100644 --- a/homeassistant/auth/permissions/events.py +++ b/homeassistant/auth/permissions/events.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Final +from typing import Any, Final from homeassistant.const import ( EVENT_COMPONENT_LOADED, @@ -21,10 +21,11 @@ from homeassistant.helpers.area_registry import EVENT_AREA_REGISTRY_UPDATED from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.issue_registry import EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED +from homeassistant.util.event_type import EventType # These are events that do not contain any sensitive data # Except for state_changed, which is handled accordingly. -SUBSCRIBE_ALLOWLIST: Final[set[str]] = { +SUBSCRIBE_ALLOWLIST: Final[set[EventType[Any] | str]] = { EVENT_AREA_REGISTRY_UPDATED, EVENT_COMPONENT_LOADED, EVENT_CORE_CONFIG_UPDATE, diff --git a/homeassistant/components/homeassistant/logbook.py b/homeassistant/components/homeassistant/logbook.py index 12d6c66b69c..1c67075b671 100644 --- a/homeassistant/components/homeassistant/logbook.py +++ b/homeassistant/components/homeassistant/logbook.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from typing import Any from homeassistant.components.logbook import ( LOGBOOK_ENTRY_ICON, @@ -11,10 +12,11 @@ from homeassistant.components.logbook import ( ) from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.util.event_type import EventType from . import DOMAIN -EVENT_TO_NAME = { +EVENT_TO_NAME: dict[EventType[Any] | str, str] = { EVENT_HOMEASSISTANT_STOP: "stopped", EVENT_HOMEASSISTANT_START: "started", } diff --git a/homeassistant/components/logbook/__init__.py b/homeassistant/components/logbook/__init__.py index f19e64aa6f0..d520cafb80e 100644 --- a/homeassistant/components/logbook/__init__.py +++ b/homeassistant/components/logbook/__init__.py @@ -31,6 +31,7 @@ from homeassistant.helpers.integration_platform import ( ) from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass +from homeassistant.util.event_type import EventType from . import rest_api, websocket_api from .const import ( # noqa: F401 @@ -134,7 +135,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: entities_filter = None external_events: dict[ - str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] + EventType[Any] | str, + tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]], ] = {} hass.data[DOMAIN] = LogbookConfig(external_events, filters, entities_filter) websocket_api.async_setup(hass) diff --git a/homeassistant/components/logbook/helpers.py b/homeassistant/components/logbook/helpers.py index 5c25056c041..8ec953a0afd 100644 --- a/homeassistant/components/logbook/helpers.py +++ b/homeassistant/components/logbook/helpers.py @@ -26,6 +26,7 @@ from homeassistant.core import ( ) from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.event import async_track_state_change_event +from homeassistant.util.event_type import EventType from .const import ALWAYS_CONTINUOUS_DOMAINS, AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN from .models import LogbookConfig @@ -63,7 +64,7 @@ def _async_config_entries_for_ids( def async_determine_event_types( hass: HomeAssistant, entity_ids: list[str] | None, device_ids: list[str] | None -) -> tuple[str, ...]: +) -> tuple[EventType[Any] | str, ...]: """Reduce the event types based on the entity ids and device ids.""" logbook_config: LogbookConfig = hass.data[DOMAIN] external_events = logbook_config.external_events @@ -81,7 +82,7 @@ def async_determine_event_types( # to add them since we have historically included # them when matching only on entities # - intrested_event_types: set[str] = { + intrested_event_types: set[EventType[Any] | str] = { external_event for external_event, domain_call in external_events.items() if domain_call[0] in interested_domains @@ -160,7 +161,7 @@ def async_subscribe_events( hass: HomeAssistant, subscriptions: list[CALLBACK_TYPE], target: Callable[[Event[Any]], None], - event_types: tuple[str, ...], + event_types: tuple[EventType[Any] | str, ...], entities_filter: Callable[[str], bool] | None, entity_ids: list[str] | None, device_ids: list[str] | None, diff --git a/homeassistant/components/logbook/models.py b/homeassistant/components/logbook/models.py index 9409c59985c..2f9b2c8e289 100644 --- a/homeassistant/components/logbook/models.py +++ b/homeassistant/components/logbook/models.py @@ -18,6 +18,7 @@ from homeassistant.components.recorder.models import ( ) from homeassistant.const import ATTR_ICON, EVENT_STATE_CHANGED from homeassistant.core import Context, Event, State, callback +from homeassistant.util.event_type import EventType from homeassistant.util.json import json_loads from homeassistant.util.ulid import ulid_to_bytes @@ -27,7 +28,8 @@ class LogbookConfig: """Configuration for the logbook integration.""" external_events: dict[ - str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] + EventType[Any] | str, + tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]], ] sqlalchemy_filter: Filters | None = None entity_filter: Callable[[str], bool] | None = None @@ -66,7 +68,7 @@ class LazyEventPartialState: ) @cached_property - def event_type(self) -> str | None: + def event_type(self) -> EventType[Any] | str | None: """Return the event type.""" return self.row.event_type @@ -110,7 +112,7 @@ class EventAsRow: icon: str | None = None context_user_id_bin: bytes | None = None context_parent_id_bin: bytes | None = None - event_type: str | None = None + event_type: EventType[Any] | str | None = None state: str | None = None context_only: None = None diff --git a/homeassistant/components/logbook/processor.py b/homeassistant/components/logbook/processor.py index 2180a63b74f..df1eb6a15f2 100644 --- a/homeassistant/components/logbook/processor.py +++ b/homeassistant/components/logbook/processor.py @@ -38,6 +38,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, split_entity_id from homeassistant.helpers import entity_registry as er import homeassistant.util.dt as dt_util +from homeassistant.util.event_type import EventType from .const import ( ATTR_MESSAGE, @@ -75,7 +76,8 @@ class LogbookRun: context_lookup: dict[bytes | None, Row | EventAsRow | None] external_events: dict[ - str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] + EventType[Any] | str, + tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]], ] event_cache: EventCache entity_name_cache: EntityNameCache @@ -90,7 +92,7 @@ class EventProcessor: def __init__( self, hass: HomeAssistant, - event_types: tuple[str, ...], + event_types: tuple[EventType[Any] | str, ...], entity_ids: list[str] | None = None, device_ids: list[str] | None = None, context_id: str | None = None, diff --git a/homeassistant/components/mobile_app/logbook.py b/homeassistant/components/mobile_app/logbook.py index 6a863e6a75b..d9f7f4f04e1 100644 --- a/homeassistant/components/mobile_app/logbook.py +++ b/homeassistant/components/mobile_app/logbook.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from typing import Any from homeassistant.components.logbook import ( LOGBOOK_ENTRY_ENTITY_ID, @@ -12,6 +13,7 @@ from homeassistant.components.logbook import ( ) from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.util.event_type import EventType from .const import DOMAIN @@ -21,7 +23,7 @@ IOS_EVENT_ZONE_EXITED = "ios.zone_exited" ATTR_ZONE = "zone" ATTR_SOURCE_DEVICE_NAME = "sourceDeviceName" ATTR_SOURCE_DEVICE_ID = "sourceDeviceID" -EVENT_TO_DESCRIPTION = { +EVENT_TO_DESCRIPTION: dict[EventType[Any] | str, str] = { IOS_EVENT_ZONE_ENTERED: "entered zone", IOS_EVENT_ZONE_EXITED: "exited zone", } diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index de75207389f..26b9f471b9e 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -25,6 +25,7 @@ from homeassistant.helpers.integration_platform import ( ) from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass +from homeassistant.util.event_type import EventType from . import entity_registry, websocket_api from .const import ( # noqa: F401 @@ -146,7 +147,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass_config_path=hass.config.path(DEFAULT_DB_FILE) ) exclude = conf[CONF_EXCLUDE] - exclude_event_types: set[str] = set(exclude.get(CONF_EVENT_TYPES, [])) + exclude_event_types: set[EventType[Any] | str] = set( + exclude.get(CONF_EVENT_TYPES, []) + ) if EVENT_STATE_CHANGED in exclude_event_types: _LOGGER.error("State change events cannot be excluded, use a filter instead") exclude_event_types.remove(EVENT_STATE_CHANGED) diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 4ae61a0c4ba..1780436168d 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -40,6 +40,7 @@ from homeassistant.helpers.start import async_at_started from homeassistant.helpers.typing import UNDEFINED, UndefinedType import homeassistant.util.dt as dt_util from homeassistant.util.enum import try_parse_enum +from homeassistant.util.event_type import EventType from . import migration, statistics from .const import ( @@ -173,7 +174,7 @@ class Recorder(threading.Thread): db_max_retries: int, db_retry_wait: int, entity_filter: Callable[[str], bool], - exclude_event_types: set[str], + exclude_event_types: set[EventType[Any] | str], ) -> None: """Initialize the recorder.""" threading.Thread.__init__(self, name="Recorder") diff --git a/homeassistant/components/recorder/models/event.py b/homeassistant/components/recorder/models/event.py index 379a6fddb1d..4e5030bfde7 100644 --- a/homeassistant/components/recorder/models/event.py +++ b/homeassistant/components/recorder/models/event.py @@ -2,9 +2,13 @@ from __future__ import annotations +from typing import Any + +from homeassistant.util.event_type import EventType + def extract_event_type_ids( - event_type_to_event_type_id: dict[str, int | None], + event_type_to_event_type_id: dict[EventType[Any] | str, int | None], ) -> list[int]: """Extract event_type ids from event_type_to_event_type_id.""" return [ diff --git a/homeassistant/components/recorder/table_managers/__init__.py b/homeassistant/components/recorder/table_managers/__init__.py index 9a0945dc4d9..c064987ddcb 100644 --- a/homeassistant/components/recorder/table_managers/__init__.py +++ b/homeassistant/components/recorder/table_managers/__init__.py @@ -1,9 +1,11 @@ """Managers for each table.""" -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from lru import LRU +from homeassistant.util.event_type import EventType + if TYPE_CHECKING: from ..core import Recorder @@ -13,7 +15,7 @@ _DataT = TypeVar("_DataT") class BaseTableManager(Generic[_DataT]): """Base class for table managers.""" - _id_map: "LRU[str, int]" + _id_map: "LRU[EventType[Any] | str, int]" def __init__(self, recorder: "Recorder") -> None: """Initialize the table manager. @@ -24,7 +26,7 @@ class BaseTableManager(Generic[_DataT]): """ self.active = False self.recorder = recorder - self._pending: dict[str, _DataT] = {} + self._pending: dict[EventType[Any] | str, _DataT] = {} def get_from_cache(self, data: str) -> int | None: """Resolve data to the id without accessing the underlying database. @@ -34,7 +36,7 @@ class BaseTableManager(Generic[_DataT]): """ return self._id_map.get(data) - def get_pending(self, shared_data: str) -> _DataT | None: + def get_pending(self, shared_data: EventType[Any] | str) -> _DataT | None: """Get pending data that have not be assigned ids yet. This call is not thread-safe and must be called from the diff --git a/homeassistant/components/recorder/table_managers/event_types.py b/homeassistant/components/recorder/table_managers/event_types.py index 94ceab7bf68..73401e8df56 100644 --- a/homeassistant/components/recorder/table_managers/event_types.py +++ b/homeassistant/components/recorder/table_managers/event_types.py @@ -3,12 +3,13 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from lru import LRU from sqlalchemy.orm.session import Session from homeassistant.core import Event +from homeassistant.util.event_type import EventType from ..db_schema import EventTypes from ..queries import find_event_type_ids @@ -29,7 +30,9 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]): def __init__(self, recorder: Recorder) -> None: """Initialize the event type manager.""" super().__init__(recorder, CACHE_SIZE) - self._non_existent_event_types: LRU[str, None] = LRU(CACHE_SIZE) + self._non_existent_event_types: LRU[EventType[Any] | str, None] = LRU( + CACHE_SIZE + ) def load(self, events: list[Event], session: Session) -> None: """Load the event_type to event_type_ids mapping into memory. @@ -44,7 +47,10 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]): ) def get( - self, event_type: str, session: Session, from_recorder: bool = False + self, + event_type: EventType[Any] | str, + session: Session, + from_recorder: bool = False, ) -> int | None: """Resolve event_type to the event_type_id. @@ -54,16 +60,19 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]): return self.get_many((event_type,), session)[event_type] def get_many( - self, event_types: Iterable[str], session: Session, from_recorder: bool = False - ) -> dict[str, int | None]: + self, + event_types: Iterable[EventType[Any] | str], + session: Session, + from_recorder: bool = False, + ) -> dict[EventType[Any] | str, int | None]: """Resolve event_types to event_type_ids. This call is not thread-safe and must be called from the recorder thread. """ - results: dict[str, int | None] = {} - missing: list[str] = [] - non_existent: list[str] = [] + results: dict[EventType[Any] | str, int | None] = {} + missing: list[EventType[Any] | str] = [] + non_existent: list[EventType[Any] | str] = [] for event_type in event_types: if (event_type_id := self._id_map.get(event_type)) is None: @@ -123,7 +132,7 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]): self.clear_non_existent(event_type) self._pending.clear() - def clear_non_existent(self, event_type: str) -> None: + def clear_non_existent(self, event_type: EventType[Any] | str) -> None: """Clear a non-existent event type from the cache. This call is not thread-safe and must be called from the diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index 1b81d7a983f..2d980c849e5 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -12,6 +12,7 @@ import threading from typing import TYPE_CHECKING, Any from homeassistant.helpers.typing import UndefinedType +from homeassistant.util.event_type import EventType from . import entity_registry, purge, statistics from .const import DOMAIN @@ -459,7 +460,7 @@ class EventIdMigrationTask(RecorderTask): class RefreshEventTypesTask(RecorderTask): """An object to insert into the recorder queue to refresh event types.""" - event_types: list[str] + event_types: list[EventType[Any] | str] def run(self, instance: Recorder) -> None: """Refresh event types.""" diff --git a/homeassistant/const.py b/homeassistant/const.py index d52deb98d5b..0eed33c48d7 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -4,7 +4,7 @@ from __future__ import annotations from enum import StrEnum from functools import partial -from typing import Final +from typing import TYPE_CHECKING, Final from .helpers.deprecation import ( DeprecatedConstant, @@ -13,8 +13,12 @@ from .helpers.deprecation import ( check_if_deprecated_constant, dir_with_deprecated_constants, ) +from .util.event_type import EventType from .util.signal_type import SignalType +if TYPE_CHECKING: + from .core import EventStateChangedData + APPLICATION_NAME: Final = "HomeAssistant" MAJOR_VERSION: Final = 2024 MINOR_VERSION: Final = 5 @@ -306,7 +310,7 @@ EVENT_LOGBOOK_ENTRY: Final = "logbook_entry" EVENT_LOGGING_CHANGED: Final = "logging_changed" EVENT_SERVICE_REGISTERED: Final = "service_registered" EVENT_SERVICE_REMOVED: Final = "service_removed" -EVENT_STATE_CHANGED: Final = "state_changed" +EVENT_STATE_CHANGED: EventType[EventStateChangedData] = EventType("state_changed") EVENT_STATE_REPORTED: Final = "state_reported" EVENT_THEMES_UPDATED: Final = "themes_updated" EVENT_PANELS_UPDATED: Final = "panels_updated" diff --git a/homeassistant/core.py b/homeassistant/core.py index ccea82a7eb9..574edf34c9b 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -102,6 +102,7 @@ from .util.async_ import ( run_callback_threadsafe, shutdown_run_callback_threadsafe, ) +from .util.event_type import EventType from .util.executor import InterruptibleThreadPoolExecutor from .util.json import JsonObjectType from .util.read_only_dict import ReadOnlyDict @@ -1216,7 +1217,7 @@ class Event(Generic[_DataT]): def __init__( self, - event_type: str, + event_type: EventType[_DataT] | str, data: _DataT | None = None, origin: EventOrigin = EventOrigin.local, time_fired_timestamp: float | None = None, @@ -1290,7 +1291,7 @@ class Event(Generic[_DataT]): def _event_repr( - event_type: str, origin: EventOrigin, data: Mapping[str, Any] | None + event_type: EventType[_DataT] | str, origin: EventOrigin, data: _DataT | None ) -> str: """Return the representation.""" if data: @@ -1307,13 +1308,13 @@ _FilterableJobType = tuple[ @dataclass(slots=True) -class _OneTimeListener: +class _OneTimeListener(Generic[_DataT]): hass: HomeAssistant - listener_job: HassJob[[Event], Coroutine[Any, Any, None] | None] + listener_job: HassJob[[Event[_DataT]], Coroutine[Any, Any, None] | None] remove: CALLBACK_TYPE | None = None @callback - def __call__(self, event: Event) -> None: + def __call__(self, event: Event[_DataT]) -> None: """Remove listener from event bus and then fire listener.""" if not self.remove: # If the listener was already removed, we don't need to do anything @@ -1341,7 +1342,7 @@ class EventBus: def __init__(self, hass: HomeAssistant) -> None: """Initialize a new event bus.""" - self._listeners: dict[str, list[_FilterableJobType[Any]]] = {} + self._listeners: dict[EventType[Any] | str, list[_FilterableJobType[Any]]] = {} self._match_all_listeners: list[_FilterableJobType[Any]] = [] self._listeners[MATCH_ALL] = self._match_all_listeners self._hass = hass @@ -1356,7 +1357,7 @@ class EventBus: self._debug = _LOGGER.isEnabledFor(logging.DEBUG) @callback - def async_listeners(self) -> dict[str, int]: + def async_listeners(self) -> dict[EventType[Any] | str, int]: """Return dictionary with events and the number of listeners. This method must be run in the event loop. @@ -1364,14 +1365,14 @@ class EventBus: return {key: len(listeners) for key, listeners in self._listeners.items()} @property - def listeners(self) -> dict[str, int]: + def listeners(self) -> dict[EventType[Any] | str, int]: """Return dictionary with events and the number of listeners.""" return run_callback_threadsafe(self._hass.loop, self.async_listeners).result() def fire( self, - event_type: str, - event_data: Mapping[str, Any] | None = None, + event_type: EventType[_DataT] | str, + event_data: _DataT | None = None, origin: EventOrigin = EventOrigin.local, context: Context | None = None, ) -> None: @@ -1383,8 +1384,8 @@ class EventBus: @callback def async_fire( self, - event_type: str, - event_data: Mapping[str, Any] | None = None, + event_type: EventType[_DataT] | str, + event_data: _DataT | None = None, origin: EventOrigin = EventOrigin.local, context: Context | None = None, time_fired: float | None = None, @@ -1402,8 +1403,8 @@ class EventBus: @callback def _async_fire( self, - event_type: str, - event_data: Mapping[str, Any] | None = None, + event_type: EventType[_DataT] | str, + event_data: _DataT | None = None, origin: EventOrigin = EventOrigin.local, context: Context | None = None, time_fired: float | None = None, @@ -1431,7 +1432,7 @@ class EventBus: if not listeners: return - event: Event | None = None + event: Event[_DataT] | None = None for job, event_filter, run_immediately in listeners: if event_filter is not None: @@ -1461,8 +1462,8 @@ class EventBus: def listen( self, - event_type: str, - listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None], + event_type: EventType[_DataT] | str, + listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None], ) -> CALLBACK_TYPE: """Listen for all events or events of a specific type. @@ -1482,7 +1483,7 @@ class EventBus: @callback def async_listen( self, - event_type: str, + event_type: EventType[_DataT] | str, listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None], event_filter: Callable[[_DataT], bool] | None = None, run_immediately: bool = True, @@ -1524,7 +1525,9 @@ class EventBus: @callback def _async_listen_filterable_job( - self, event_type: str, filterable_job: _FilterableJobType[Any] + self, + event_type: EventType[_DataT] | str, + filterable_job: _FilterableJobType[_DataT], ) -> CALLBACK_TYPE: self._listeners.setdefault(event_type, []).append(filterable_job) return functools.partial( @@ -1533,8 +1536,8 @@ class EventBus: def listen_once( self, - event_type: str, - listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None], + event_type: EventType[_DataT] | str, + listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None], ) -> CALLBACK_TYPE: """Listen once for event of a specific type. @@ -1556,8 +1559,8 @@ class EventBus: @callback def async_listen_once( self, - event_type: str, - listener: Callable[[Event[Any]], Coroutine[Any, Any, None] | None], + event_type: EventType[_DataT] | str, + listener: Callable[[Event[_DataT]], Coroutine[Any, Any, None] | None], run_immediately: bool = True, ) -> CALLBACK_TYPE: """Listen once for event of a specific type. @@ -1569,7 +1572,9 @@ class EventBus: This method must be run in the event loop. """ - one_time_listener = _OneTimeListener(self._hass, HassJob(listener)) + one_time_listener: _OneTimeListener[_DataT] = _OneTimeListener( + self._hass, HassJob(listener) + ) remove = self._async_listen_filterable_job( event_type, ( @@ -1587,7 +1592,9 @@ class EventBus: @callback def _async_remove_listener( - self, event_type: str, filterable_job: _FilterableJobType + self, + event_type: EventType[_DataT] | str, + filterable_job: _FilterableJobType[_DataT], ) -> None: """Remove a listener of a specific event_type. diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index bdf4d8c060b..1eb964d82b1 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -4,7 +4,9 @@ from __future__ import annotations from collections.abc import Callable, Generator, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + +from .util.event_type import EventType if TYPE_CHECKING: from .core import Context @@ -271,8 +273,12 @@ class ServiceNotFound(HomeAssistantError): class MaxLengthExceeded(HomeAssistantError): """Raised when a property value has exceeded the max character length.""" - def __init__(self, value: str, property_name: str, max_length: int) -> None: + def __init__( + self, value: EventType[Any] | str, property_name: str, max_length: int + ) -> None: """Initialize error.""" + if TYPE_CHECKING: + value = str(value) super().__init__( translation_domain="homeassistant", translation_key="max_length_exceeded", diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 67feb6c48a4..e3f91320c7b 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -38,6 +38,7 @@ from homeassistant.exceptions import TemplateError from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe +from homeassistant.util.event_type import EventType from .device_registry import ( EVENT_DEVICE_REGISTRY_UPDATED, @@ -90,7 +91,7 @@ class _KeyedEventTracker(Generic[_TypedDictT]): listeners_key: str callbacks_key: str - event_type: str + event_type: EventType[_TypedDictT] | str dispatcher_callable: Callable[ [ HomeAssistant, diff --git a/homeassistant/util/event_type.py b/homeassistant/util/event_type.py new file mode 100644 index 00000000000..e96d45c80a3 --- /dev/null +++ b/homeassistant/util/event_type.py @@ -0,0 +1,20 @@ +"""Implementation for EventType. + +Custom for type checking. See stub file. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Generic + +from typing_extensions import TypeVar + +_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any]) + + +class EventType(str, Generic[_DataT]): + """Custom type for Event.event_type. + + At runtime this is a generic subclass of str. + """ diff --git a/homeassistant/util/event_type.pyi b/homeassistant/util/event_type.pyi new file mode 100644 index 00000000000..4285e54e8c9 --- /dev/null +++ b/homeassistant/util/event_type.pyi @@ -0,0 +1,25 @@ +"""Stub file for event_type. Provide overload for type checking.""" +# ruff: noqa: PYI021 # Allow docstrings + +from collections.abc import Mapping +from typing import Any, Generic + +from typing_extensions import TypeVar + +__all__ = [ + "EventType", +] + +_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any]) + +class EventType(Generic[_DataT]): + """Custom type for Event.event_type. At runtime delegated to str. + + For type checkers pretend to be its own separate class. + """ + + def __init__(self, value: str, /) -> None: ... + def __len__(self) -> int: ... + def __hash__(self) -> int: ... + def __eq__(self, value: object, /) -> bool: ... + def __getitem__(self, index: int) -> str: ... diff --git a/tests/util/test_event_type.py b/tests/util/test_event_type.py new file mode 100644 index 00000000000..3086c8ea075 --- /dev/null +++ b/tests/util/test_event_type.py @@ -0,0 +1,25 @@ +"""Test EventType implementation.""" + +from __future__ import annotations + +import orjson + +from homeassistant.util.event_type import EventType + + +def test_compatibility_with_str() -> None: + """Test EventType. At runtime it should be (almost) fully compatible with str.""" + + event = EventType("Hello World") + assert event == "Hello World" + assert len(event) == 11 + assert hash(event) == hash("Hello World") + d: dict[str | EventType, int] = {EventType("key"): 2} + assert d["key"] == 2 + + +def test_json_dump() -> None: + """Test EventType json dump with orjson.""" + + event = EventType("state_changed") + assert orjson.dumps({"event_type": event}) == b'{"event_type":"state_changed"}'