diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index cc412c88612..92d9baed771 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -30,7 +30,13 @@ from homeassistant.const import ( EVENT_STATE_CHANGED, MATCH_ALL, ) -from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + EventStateChangedData, + HomeAssistant, + callback, +) from homeassistant.helpers.event import ( async_track_time_change, async_track_time_interval, @@ -862,12 +868,12 @@ class Recorder(threading.Thread): self._guarded_process_one_task_or_event_or_recover(queue_.get()) def _pre_process_startup_events( - self, startup_task_or_events: list[RecorderTask | Event] + self, startup_task_or_events: list[RecorderTask | Event[Any]] ) -> None: """Pre process startup events.""" # Prime all the state_attributes and event_data caches # before we start processing events - state_change_events: list[Event] = [] + state_change_events: list[Event[EventStateChangedData]] = [] non_state_change_events: list[Event] = [] for task_or_event in startup_task_or_events: @@ -1019,7 +1025,7 @@ class Recorder(threading.Thread): self.backlog, ) - def _process_one_event(self, event: Event) -> None: + def _process_one_event(self, event: Event[Any]) -> None: if not self.enabled: return if event.event_type == EVENT_STATE_CHANGED: @@ -1076,7 +1082,9 @@ class Recorder(threading.Thread): self._add_to_session(session, dbevent) - def _process_state_changed_event_into_session(self, event: Event) -> None: + def _process_state_changed_event_into_session( + self, event: Event[EventStateChangedData] + ) -> None: """Process a state_changed event into the session.""" state_attributes_manager = self.state_attributes_manager states_meta_manager = self.states_meta_manager diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index eac743c3d75..186b873047b 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -40,7 +40,7 @@ from homeassistant.const import ( MAX_LENGTH_STATE_ENTITY_ID, MAX_LENGTH_STATE_STATE, ) -from homeassistant.core import Context, Event, EventOrigin, State +from homeassistant.core import Context, Event, EventOrigin, EventStateChangedData, State from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null import homeassistant.util.dt as dt_util from homeassistant.util.json import ( @@ -478,10 +478,10 @@ class States(Base): return date_time.isoformat(sep=" ", timespec="seconds") @staticmethod - def from_event(event: Event) -> States: + def from_event(event: Event[EventStateChangedData]) -> States: """Create object from a state_changed event.""" entity_id = event.data["entity_id"] - state: State | None = event.data.get("new_state") + state = event.data["new_state"] dbstate = States( entity_id=entity_id, attributes=None, @@ -576,13 +576,12 @@ class StateAttributes(Base): @staticmethod def shared_attrs_bytes_from_event( - event: Event, + event: Event[EventStateChangedData], dialect: SupportedDialect | None, ) -> bytes: """Create shared_attrs from a state_changed event.""" - state: State | None = event.data.get("new_state") # None state means the state was removed from the state machine - if state is None: + if (state := event.data["new_state"]) is None: return b"{}" if state_info := state.state_info: exclude_attrs = { diff --git a/homeassistant/components/recorder/table_managers/state_attributes.py b/homeassistant/components/recorder/table_managers/state_attributes.py index e2fb9153be8..ec975d310e9 100644 --- a/homeassistant/components/recorder/table_managers/state_attributes.py +++ b/homeassistant/components/recorder/table_managers/state_attributes.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast from sqlalchemy.orm.session import Session -from homeassistant.core import Event +from homeassistant.core import Event, EventStateChangedData from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS from ..db_schema import StateAttributes @@ -38,7 +38,7 @@ class StateAttributesManager(BaseLRUTableManager[StateAttributes]): super().__init__(recorder, CACHE_SIZE) self.active = True # always active - def serialize_from_event(self, event: Event) -> bytes | None: + def serialize_from_event(self, event: Event[EventStateChangedData]) -> bytes | None: """Serialize event data.""" try: return StateAttributes.shared_attrs_bytes_from_event( @@ -47,12 +47,14 @@ class StateAttributesManager(BaseLRUTableManager[StateAttributes]): except JSON_ENCODE_EXCEPTIONS as ex: _LOGGER.warning( "State is not JSON serializable: %s: %s", - event.data.get("new_state"), + event.data["new_state"], ex, ) return None - def load(self, events: list[Event], session: Session) -> None: + def load( + self, events: list[Event[EventStateChangedData]], session: Session + ) -> None: """Load the shared_attrs to attributes_ids mapping into memory from events. This call is not thread-safe and must be called from the diff --git a/homeassistant/components/recorder/table_managers/states_meta.py b/homeassistant/components/recorder/table_managers/states_meta.py index ebc1dab45f3..2c73dcf3a54 100644 --- a/homeassistant/components/recorder/table_managers/states_meta.py +++ b/homeassistant/components/recorder/table_managers/states_meta.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, cast from sqlalchemy.orm.session import Session -from homeassistant.core import Event +from homeassistant.core import Event, EventStateChangedData from ..db_schema import StatesMeta from ..queries import find_all_states_metadata_ids, find_states_metadata_ids @@ -28,7 +28,9 @@ class StatesMetaManager(BaseLRUTableManager[StatesMeta]): self._did_first_load = False super().__init__(recorder, CACHE_SIZE) - def load(self, events: list[Event], session: Session) -> None: + def load( + self, events: list[Event[EventStateChangedData]], session: Session + ) -> None: """Load the entity_id to metadata_id mapping into memory. This call is not thread-safe and must be called from the @@ -37,9 +39,9 @@ class StatesMetaManager(BaseLRUTableManager[StatesMeta]): self._did_first_load = True self.get_many( { - event.data["new_state"].entity_id + new_state.entity_id for event in events - if event.data.get("new_state") is not None + if (new_state := event.data["new_state"]) is not None }, session, True,