diff --git a/homeassistant/core.py b/homeassistant/core.py index 94f785cd4ef..3e42b5168c2 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -880,6 +880,9 @@ class EventBus: ) +_StateT = TypeVar("_StateT", bound="State") + + class State: """Object to represent a state within the state machine. @@ -946,7 +949,7 @@ class State: "_", " " ) - def as_dict(self) -> dict: + def as_dict(self) -> dict[str, Collection[Any]]: """Return a dict representation of the State. Async friendly. @@ -971,7 +974,7 @@ class State: return self._as_dict @classmethod - def from_dict(cls, json_dict: dict) -> Any: + def from_dict(cls: type[_StateT], json_dict: dict[str, Any]) -> _StateT | None: """Initialize a state from a dict. Async friendly. @@ -1042,7 +1045,7 @@ class StateMachine: @callback def async_entity_ids( - self, domain_filter: str | Iterable | None = None + self, domain_filter: str | Iterable[str] | None = None ) -> list[str]: """List of entity ids that are being tracked. @@ -1062,7 +1065,7 @@ class StateMachine: @callback def async_entity_ids_count( - self, domain_filter: str | Iterable | None = None + self, domain_filter: str | Iterable[str] | None = None ) -> int: """Count the entity ids that are being tracked. @@ -1078,14 +1081,16 @@ class StateMachine: [None for state in self._states.values() if state.domain in domain_filter] ) - def all(self, domain_filter: str | Iterable | None = None) -> list[State]: + def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]: """Create a list of all states.""" return run_callback_threadsafe( self._loop, self.async_all, domain_filter ).result() @callback - def async_all(self, domain_filter: str | Iterable | None = None) -> list[State]: + def async_all( + self, domain_filter: str | Iterable[str] | None = None + ) -> list[State]: """Create a list of all states matching the filter. This method must be run in the event loop. diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index 56b5b106278..19b1509c162 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from datetime import datetime, timedelta import logging -from typing import Any, cast +from typing import Any, TypeVar, cast from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant, State, callback, valid_entity_id @@ -31,6 +31,8 @@ STATE_DUMP_INTERVAL = timedelta(minutes=15) # How long should a saved state be preserved if the entity no longer exists STATE_EXPIRATION = timedelta(days=7) +_StoredStateT = TypeVar("_StoredStateT", bound="StoredState") + class StoredState: """Object to represent a stored state.""" @@ -45,14 +47,14 @@ class StoredState: return {"state": self.state.as_dict(), "last_seen": self.last_seen} @classmethod - def from_dict(cls, json_dict: dict) -> StoredState: + def from_dict(cls: type[_StoredStateT], json_dict: dict) -> _StoredStateT: """Initialize a stored state from a dict.""" last_seen = json_dict["last_seen"] if isinstance(last_seen, str): last_seen = dt_util.parse_datetime(last_seen) - return cls(State.from_dict(json_dict["state"]), last_seen) + return cls(cast(State, State.from_dict(json_dict["state"])), last_seen) class RestoreStateData: