Update typing 01 (#48013)

This commit is contained in:
Marc Mueller
2021-03-17 17:34:55 +01:00
committed by GitHub
parent 9011a54e7f
commit e55702d635
11 changed files with 303 additions and 313 deletions

View File

@@ -4,6 +4,8 @@ Core components of Home Assistant.
Home Assistant is a Home Automation framework for observing the state
of entities and react to changes.
"""
from __future__ import annotations
import asyncio
import datetime
import enum
@@ -22,15 +24,10 @@ from typing import (
Callable,
Collection,
Coroutine,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
)
@@ -119,7 +116,7 @@ TIMEOUT_EVENT_START = 15
_LOGGER = logging.getLogger(__name__)
def split_entity_id(entity_id: str) -> List[str]:
def split_entity_id(entity_id: str) -> list[str]:
"""Split a state entity ID into domain and object ID."""
return entity_id.split(".", 1)
@@ -237,7 +234,7 @@ class HomeAssistant:
self.state: CoreState = CoreState.not_running
self.exit_code: int = 0
# If not None, use to signal end-of-loop
self._stopped: Optional[asyncio.Event] = None
self._stopped: asyncio.Event | None = None
# Timeout handler for Core/Helper namespace
self.timeout: TimeoutManager = TimeoutManager()
@@ -342,7 +339,7 @@ class HomeAssistant:
@callback
def async_add_job(
self, target: Callable[..., Any], *args: Any
) -> Optional[asyncio.Future]:
) -> asyncio.Future | None:
"""Add a job from within the event loop.
This method must be run in the event loop.
@@ -359,9 +356,7 @@ class HomeAssistant:
return self.async_add_hass_job(HassJob(target), *args)
@callback
def async_add_hass_job(
self, hassjob: HassJob, *args: Any
) -> Optional[asyncio.Future]:
def async_add_hass_job(self, hassjob: HassJob, *args: Any) -> asyncio.Future | None:
"""Add a HassJob from within the event loop.
This method must be run in the event loop.
@@ -423,9 +418,7 @@ class HomeAssistant:
self._track_task = False
@callback
def async_run_hass_job(
self, hassjob: HassJob, *args: Any
) -> Optional[asyncio.Future]:
def async_run_hass_job(self, hassjob: HassJob, *args: Any) -> asyncio.Future | None:
"""Run a HassJob from within the event loop.
This method must be run in the event loop.
@@ -441,8 +434,8 @@ class HomeAssistant:
@callback
def async_run_job(
self, target: Callable[..., Union[None, Awaitable]], *args: Any
) -> Optional[asyncio.Future]:
self, target: Callable[..., None | Awaitable], *args: Any
) -> asyncio.Future | None:
"""Run a job from within the event loop.
This method must be run in the event loop.
@@ -465,7 +458,7 @@ class HomeAssistant:
"""Block until all pending work is done."""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
start_time: Optional[float] = None
start_time: float | None = None
while self._pending_tasks:
pending = [task for task in self._pending_tasks if not task.done()]
@@ -582,10 +575,10 @@ class Context:
"""The context that triggered something."""
user_id: str = attr.ib(default=None)
parent_id: Optional[str] = attr.ib(default=None)
parent_id: str | None = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex)
def as_dict(self) -> Dict[str, Optional[str]]:
def as_dict(self) -> dict[str, str | None]:
"""Return a dictionary representation of the context."""
return {"id": self.id, "parent_id": self.parent_id, "user_id": self.user_id}
@@ -610,10 +603,10 @@ class Event:
def __init__(
self,
event_type: str,
data: Optional[Dict[str, Any]] = None,
data: dict[str, Any] | None = None,
origin: EventOrigin = EventOrigin.local,
time_fired: Optional[datetime.datetime] = None,
context: Optional[Context] = None,
time_fired: datetime.datetime | None = None,
context: Context | None = None,
) -> None:
"""Initialize a new event."""
self.event_type = event_type
@@ -627,7 +620,7 @@ class Event:
# The only event type that shares context are the TIME_CHANGED
return hash((self.event_type, self.context.id, self.time_fired))
def as_dict(self) -> Dict[str, Any]:
def as_dict(self) -> dict[str, Any]:
"""Create a dict representation of this Event.
Async friendly.
@@ -664,11 +657,11 @@ class EventBus:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus."""
self._listeners: Dict[str, List[Tuple[HassJob, Optional[Callable]]]] = {}
self._listeners: dict[str, list[tuple[HassJob, Callable | None]]] = {}
self._hass = hass
@callback
def async_listeners(self) -> Dict[str, int]:
def async_listeners(self) -> dict[str, int]:
"""Return dictionary with events and the number of listeners.
This method must be run in the event loop.
@@ -676,16 +669,16 @@ class EventBus:
return {key: len(self._listeners[key]) for key in self._listeners}
@property
def listeners(self) -> Dict[str, int]:
def listeners(self) -> dict[str, int]:
"""Return dictionary with events and the number of listeners."""
return run_callback_threadsafe(self._hass.loop, self.async_listeners).result()
def fire(
self,
event_type: str,
event_data: Optional[Dict] = None,
event_data: dict | None = None,
origin: EventOrigin = EventOrigin.local,
context: Optional[Context] = None,
context: Context | None = None,
) -> None:
"""Fire an event."""
self._hass.loop.call_soon_threadsafe(
@@ -696,10 +689,10 @@ class EventBus:
def async_fire(
self,
event_type: str,
event_data: Optional[Dict[str, Any]] = None,
event_data: dict[str, Any] | None = None,
origin: EventOrigin = EventOrigin.local,
context: Optional[Context] = None,
time_fired: Optional[datetime.datetime] = None,
context: Context | None = None,
time_fired: datetime.datetime | None = None,
) -> None:
"""Fire an event.
@@ -751,7 +744,7 @@ class EventBus:
self,
event_type: str,
listener: Callable,
event_filter: Optional[Callable] = None,
event_filter: Callable | None = None,
) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type.
@@ -772,7 +765,7 @@ class EventBus:
@callback
def _async_listen_filterable_job(
self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]]
self, event_type: str, filterable_job: tuple[HassJob, Callable | None]
) -> CALLBACK_TYPE:
self._listeners.setdefault(event_type, []).append(filterable_job)
@@ -811,7 +804,7 @@ class EventBus:
This method must be run in the event loop.
"""
filterable_job: Optional[Tuple[HassJob, Optional[Callable]]] = None
filterable_job: tuple[HassJob, Callable | None] | None = None
@callback
def _onetime_listener(event: Event) -> None:
@@ -835,7 +828,7 @@ class EventBus:
@callback
def _async_remove_listener(
self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]]
self, event_type: str, filterable_job: tuple[HassJob, Callable | None]
) -> None:
"""Remove a listener of a specific event_type.
@@ -884,11 +877,11 @@ class State:
self,
entity_id: str,
state: str,
attributes: Optional[Mapping[str, Any]] = None,
last_changed: Optional[datetime.datetime] = None,
last_updated: Optional[datetime.datetime] = None,
context: Optional[Context] = None,
validate_entity_id: Optional[bool] = True,
attributes: Mapping[str, Any] | None = None,
last_changed: datetime.datetime | None = None,
last_updated: datetime.datetime | None = None,
context: Context | None = None,
validate_entity_id: bool | None = True,
) -> None:
"""Initialize a new state."""
state = str(state)
@@ -912,7 +905,7 @@ class State:
self.last_changed = last_changed or self.last_updated
self.context = context or Context()
self.domain, self.object_id = split_entity_id(self.entity_id)
self._as_dict: Optional[Dict[str, Collection[Any]]] = None
self._as_dict: dict[str, Collection[Any]] | None = None
@property
def name(self) -> str:
@@ -921,7 +914,7 @@ class State:
"_", " "
)
def as_dict(self) -> Dict:
def as_dict(self) -> dict:
"""Return a dict representation of the State.
Async friendly.
@@ -946,7 +939,7 @@ class State:
return self._as_dict
@classmethod
def from_dict(cls, json_dict: Dict) -> Any:
def from_dict(cls, json_dict: dict) -> Any:
"""Initialize a state from a dict.
Async friendly.
@@ -1004,12 +997,12 @@ class StateMachine:
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
"""Initialize state machine."""
self._states: Dict[str, State] = {}
self._reservations: Set[str] = set()
self._states: dict[str, State] = {}
self._reservations: set[str] = set()
self._bus = bus
self._loop = loop
def entity_ids(self, domain_filter: Optional[str] = None) -> List[str]:
def entity_ids(self, domain_filter: str | None = None) -> list[str]:
"""List of entity ids that are being tracked."""
future = run_callback_threadsafe(
self._loop, self.async_entity_ids, domain_filter
@@ -1018,8 +1011,8 @@ class StateMachine:
@callback
def async_entity_ids(
self, domain_filter: Optional[Union[str, Iterable]] = None
) -> List[str]:
self, domain_filter: str | Iterable | None = None
) -> list[str]:
"""List of entity ids that are being tracked.
This method must be run in the event loop.
@@ -1038,7 +1031,7 @@ class StateMachine:
@callback
def async_entity_ids_count(
self, domain_filter: Optional[Union[str, Iterable]] = None
self, domain_filter: str | Iterable | None = None
) -> int:
"""Count the entity ids that are being tracked.
@@ -1054,16 +1047,14 @@ class StateMachine:
[None for state in self._states.values() if state.domain in domain_filter]
)
def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
def all(self, domain_filter: str | Iterable | None = None) -> list[State]:
"""Create a list of all states."""
return run_callback_threadsafe(
self._loop, self.async_all, domain_filter
).result()
@callback
def async_all(
self, domain_filter: Optional[Union[str, Iterable]] = None
) -> List[State]:
def async_all(self, domain_filter: str | Iterable | None = None) -> list[State]:
"""Create a list of all states matching the filter.
This method must be run in the event loop.
@@ -1078,7 +1069,7 @@ class StateMachine:
state for state in self._states.values() if state.domain in domain_filter
]
def get(self, entity_id: str) -> Optional[State]:
def get(self, entity_id: str) -> State | None:
"""Retrieve state of entity_id or None if not found.
Async friendly.
@@ -1103,7 +1094,7 @@ class StateMachine:
).result()
@callback
def async_remove(self, entity_id: str, context: Optional[Context] = None) -> bool:
def async_remove(self, entity_id: str, context: Context | None = None) -> bool:
"""Remove the state of an entity.
Returns boolean to indicate if an entity was removed.
@@ -1131,9 +1122,9 @@ class StateMachine:
self,
entity_id: str,
new_state: str,
attributes: Optional[Mapping[str, Any]] = None,
attributes: Mapping[str, Any] | None = None,
force_update: bool = False,
context: Optional[Context] = None,
context: Context | None = None,
) -> None:
"""Set the state of an entity, add entity if it does not exist.
@@ -1180,9 +1171,9 @@ class StateMachine:
self,
entity_id: str,
new_state: str,
attributes: Optional[Mapping[str, Any]] = None,
attributes: Mapping[str, Any] | None = None,
force_update: bool = False,
context: Optional[Context] = None,
context: Context | None = None,
) -> None:
"""Set the state of an entity, add entity if it does not exist.
@@ -1241,8 +1232,8 @@ class Service:
def __init__(
self,
func: Callable,
schema: Optional[vol.Schema],
context: Optional[Context] = None,
schema: vol.Schema | None,
context: Context | None = None,
) -> None:
"""Initialize a service."""
self.job = HassJob(func)
@@ -1258,8 +1249,8 @@ class ServiceCall:
self,
domain: str,
service: str,
data: Optional[Dict] = None,
context: Optional[Context] = None,
data: dict | None = None,
context: Context | None = None,
) -> None:
"""Initialize a service call."""
self.domain = domain.lower()
@@ -1283,16 +1274,16 @@ class ServiceRegistry:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a service registry."""
self._services: Dict[str, Dict[str, Service]] = {}
self._services: dict[str, dict[str, Service]] = {}
self._hass = hass
@property
def services(self) -> Dict[str, Dict[str, Service]]:
def services(self) -> dict[str, dict[str, Service]]:
"""Return dictionary with per domain a list of available services."""
return run_callback_threadsafe(self._hass.loop, self.async_services).result()
@callback
def async_services(self) -> Dict[str, Dict[str, Service]]:
def async_services(self) -> dict[str, dict[str, Service]]:
"""Return dictionary with per domain a list of available services.
This method must be run in the event loop.
@@ -1311,7 +1302,7 @@ class ServiceRegistry:
domain: str,
service: str,
service_func: Callable,
schema: Optional[vol.Schema] = None,
schema: vol.Schema | None = None,
) -> None:
"""
Register a service.
@@ -1328,7 +1319,7 @@ class ServiceRegistry:
domain: str,
service: str,
service_func: Callable,
schema: Optional[vol.Schema] = None,
schema: vol.Schema | None = None,
) -> None:
"""
Register a service.
@@ -1382,12 +1373,12 @@ class ServiceRegistry:
self,
domain: str,
service: str,
service_data: Optional[Dict] = None,
service_data: dict | None = None,
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
context: Context | None = None,
limit: float | None = SERVICE_CALL_LIMIT,
target: dict | None = None,
) -> bool | None:
"""
Call a service.
@@ -1404,12 +1395,12 @@ class ServiceRegistry:
self,
domain: str,
service: str,
service_data: Optional[Dict] = None,
service_data: dict | None = None,
blocking: bool = False,
context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
target: Optional[Dict] = None,
) -> Optional[bool]:
context: Context | None = None,
limit: float | None = SERVICE_CALL_LIMIT,
target: dict | None = None,
) -> bool | None:
"""
Call a service.
@@ -1497,7 +1488,7 @@ class ServiceRegistry:
return False
def _run_service_in_background(
self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
self, coro_or_task: Coroutine | asyncio.Task, service_call: ServiceCall
) -> None:
"""Run service call in background, catching and logging any exceptions."""
@@ -1542,8 +1533,8 @@ class Config:
self.location_name: str = "Home"
self.time_zone: datetime.tzinfo = dt_util.UTC
self.units: UnitSystem = METRIC_SYSTEM
self.internal_url: Optional[str] = None
self.external_url: Optional[str] = None
self.internal_url: str | None = None
self.external_url: str | None = None
self.config_source: str = "default"
@@ -1551,22 +1542,22 @@ class Config:
self.skip_pip: bool = False
# List of loaded components
self.components: Set[str] = set()
self.components: set[str] = set()
# API (HTTP) server configuration, see components.http.ApiConfig
self.api: Optional[Any] = None
self.api: Any | None = None
# Directory that holds the configuration
self.config_dir: Optional[str] = None
self.config_dir: str | None = None
# List of allowed external dirs to access
self.allowlist_external_dirs: Set[str] = set()
self.allowlist_external_dirs: set[str] = set()
# List of allowed external URLs that integrations may use
self.allowlist_external_urls: Set[str] = set()
self.allowlist_external_urls: set[str] = set()
# Dictionary of Media folders that integrations may use
self.media_dirs: Dict[str, str] = {}
self.media_dirs: dict[str, str] = {}
# If Home Assistant is running in safe mode
self.safe_mode: bool = False
@@ -1574,7 +1565,7 @@ class Config:
# Use legacy template behavior
self.legacy_templates: bool = False
def distance(self, lat: float, lon: float) -> Optional[float]:
def distance(self, lat: float, lon: float) -> float | None:
"""Calculate distance from Home Assistant.
Async friendly.
@@ -1625,7 +1616,7 @@ class Config:
return False
def as_dict(self) -> Dict:
def as_dict(self) -> dict:
"""Create a dictionary representation of the configuration.
Async friendly.
@@ -1670,15 +1661,15 @@ class Config:
self,
*,
source: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
elevation: Optional[int] = None,
unit_system: Optional[str] = None,
location_name: Optional[str] = None,
time_zone: Optional[str] = None,
latitude: float | None = None,
longitude: float | None = None,
elevation: int | None = None,
unit_system: str | None = None,
location_name: str | None = None,
time_zone: str | None = None,
# pylint: disable=dangerous-default-value # _UNDEFs not modified
external_url: Optional[Union[str, dict]] = _UNDEF,
internal_url: Optional[Union[str, dict]] = _UNDEF,
external_url: str | dict | None = _UNDEF,
internal_url: str | dict | None = _UNDEF,
) -> None:
"""Update the configuration from a dictionary."""
self.config_source = source