Add strict typing to core.py (2) - State (#63240)

This commit is contained in:
Marc Mueller 2022-01-04 18:33:56 +01:00 committed by GitHub
parent 5f5adffd5b
commit 3a32fe9a34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 9 deletions

View File

@ -880,6 +880,9 @@ class EventBus:
) )
_StateT = TypeVar("_StateT", bound="State")
class State: class State:
"""Object to represent a state within the state machine. """Object to represent a state within the state machine.
@ -946,7 +949,7 @@ class State:
"_", " " "_", " "
) )
def as_dict(self) -> dict: def as_dict(self) -> dict[str, Collection[Any]]:
"""Return a dict representation of the State. """Return a dict representation of the State.
Async friendly. Async friendly.
@ -971,7 +974,7 @@ class State:
return self._as_dict return self._as_dict
@classmethod @classmethod
def from_dict(cls, json_dict: dict) -> Any: def from_dict(cls: type[_StateT], json_dict: dict[str, Any]) -> _StateT | None:
"""Initialize a state from a dict. """Initialize a state from a dict.
Async friendly. Async friendly.
@ -1042,7 +1045,7 @@ class StateMachine:
@callback @callback
def async_entity_ids( def async_entity_ids(
self, domain_filter: str | Iterable | None = None self, domain_filter: str | Iterable[str] | None = None
) -> list[str]: ) -> list[str]:
"""List of entity ids that are being tracked. """List of entity ids that are being tracked.
@ -1062,7 +1065,7 @@ class StateMachine:
@callback @callback
def async_entity_ids_count( def async_entity_ids_count(
self, domain_filter: str | Iterable | None = None self, domain_filter: str | Iterable[str] | None = None
) -> int: ) -> int:
"""Count the entity ids that are being tracked. """Count the entity ids that are being tracked.
@ -1078,14 +1081,16 @@ class StateMachine:
[None for state in self._states.values() if state.domain in domain_filter] [None for state in self._states.values() if state.domain in domain_filter]
) )
def all(self, domain_filter: str | Iterable | None = None) -> list[State]: def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
"""Create a list of all states.""" """Create a list of all states."""
return run_callback_threadsafe( return run_callback_threadsafe(
self._loop, self.async_all, domain_filter self._loop, self.async_all, domain_filter
).result() ).result()
@callback @callback
def async_all(self, domain_filter: str | Iterable | None = None) -> list[State]: def async_all(
self, domain_filter: str | Iterable[str] | None = None
) -> list[State]:
"""Create a list of all states matching the filter. """Create a list of all states matching the filter.
This method must be run in the event loop. This method must be run in the event loop.

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from typing import Any, cast from typing import Any, TypeVar, cast
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant, State, callback, valid_entity_id from homeassistant.core import HomeAssistant, State, callback, valid_entity_id
@ -31,6 +31,8 @@ STATE_DUMP_INTERVAL = timedelta(minutes=15)
# How long should a saved state be preserved if the entity no longer exists # How long should a saved state be preserved if the entity no longer exists
STATE_EXPIRATION = timedelta(days=7) STATE_EXPIRATION = timedelta(days=7)
_StoredStateT = TypeVar("_StoredStateT", bound="StoredState")
class StoredState: class StoredState:
"""Object to represent a stored state.""" """Object to represent a stored state."""
@ -45,14 +47,14 @@ class StoredState:
return {"state": self.state.as_dict(), "last_seen": self.last_seen} return {"state": self.state.as_dict(), "last_seen": self.last_seen}
@classmethod @classmethod
def from_dict(cls, json_dict: dict) -> StoredState: def from_dict(cls: type[_StoredStateT], json_dict: dict) -> _StoredStateT:
"""Initialize a stored state from a dict.""" """Initialize a stored state from a dict."""
last_seen = json_dict["last_seen"] last_seen = json_dict["last_seen"]
if isinstance(last_seen, str): if isinstance(last_seen, str):
last_seen = dt_util.parse_datetime(last_seen) last_seen = dt_util.parse_datetime(last_seen)
return cls(State.from_dict(json_dict["state"]), last_seen) return cls(cast(State, State.from_dict(json_dict["state"])), last_seen)
class RestoreStateData: class RestoreStateData: