Update typing 02 (#48014)

This commit is contained in:
Marc Mueller 2021-03-17 18:34:19 +01:00 committed by GitHub
parent 86d3baa34e
commit 6fb2e63e49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 717 additions and 706 deletions

View File

@ -1,6 +1,8 @@
"""Helper methods for components within Home Assistant."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Any, Iterable, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Iterable, Sequence
from homeassistant.const import CONF_PLATFORM
@ -8,7 +10,7 @@ if TYPE_CHECKING:
from .typing import ConfigType
def config_per_platform(config: "ConfigType", domain: str) -> Iterable[Tuple[Any, Any]]:
def config_per_platform(config: "ConfigType", domain: str) -> Iterable[tuple[Any, Any]]:
"""Break a component config into different platforms.
For example, will find 'switch', 'switch 2', 'switch 3', .. etc

View File

@ -1,8 +1,10 @@
"""Helper for aiohttp webclient stuff."""
from __future__ import annotations
import asyncio
from ssl import SSLContext
import sys
from typing import Any, Awaitable, Optional, Union, cast
from typing import Any, Awaitable, cast
import aiohttp
from aiohttp import web
@ -87,7 +89,7 @@ async def async_aiohttp_proxy_web(
web_coro: Awaitable[aiohttp.ClientResponse],
buffer_size: int = 102400,
timeout: int = 10,
) -> Optional[web.StreamResponse]:
) -> web.StreamResponse | None:
"""Stream websession request to aiohttp web response."""
try:
with async_timeout.timeout(timeout):
@ -118,7 +120,7 @@ async def async_aiohttp_proxy_stream(
hass: HomeAssistantType,
request: web.BaseRequest,
stream: aiohttp.StreamReader,
content_type: Optional[str],
content_type: str | None,
buffer_size: int = 102400,
timeout: int = 10,
) -> web.StreamResponse:
@ -175,7 +177,7 @@ def _async_get_connector(
return cast(aiohttp.BaseConnector, hass.data[key])
if verify_ssl:
ssl_context: Union[bool, SSLContext] = ssl_util.client_context()
ssl_context: bool | SSLContext = ssl_util.client_context()
else:
ssl_context = False

View File

@ -1,6 +1,8 @@
"""Provide a way to connect devices to one physical location."""
from __future__ import annotations
from collections import OrderedDict
from typing import Container, Dict, Iterable, List, MutableMapping, Optional, cast
from typing import Container, Iterable, MutableMapping, cast
import attr
@ -26,7 +28,7 @@ class AreaEntry:
name: str = attr.ib()
normalized_name: str = attr.ib()
id: Optional[str] = attr.ib(default=None)
id: str | None = attr.ib(default=None)
def generate_id(self, existing_ids: Container[str]) -> None:
"""Initialize ID."""
@ -46,15 +48,15 @@ class AreaRegistry:
self.hass = hass
self.areas: MutableMapping[str, AreaEntry] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._normalized_name_area_idx: Dict[str, str] = {}
self._normalized_name_area_idx: dict[str, str] = {}
@callback
def async_get_area(self, area_id: str) -> Optional[AreaEntry]:
def async_get_area(self, area_id: str) -> AreaEntry | None:
"""Get area by id."""
return self.areas.get(area_id)
@callback
def async_get_area_by_name(self, name: str) -> Optional[AreaEntry]:
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name."""
normalized_name = normalize_area_name(name)
if normalized_name not in self._normalized_name_area_idx:
@ -171,7 +173,7 @@ class AreaRegistry:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> Dict[str, List[Dict[str, Optional[str]]]]:
def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]:
"""Return data of area registry to store in a file."""
data = {}

View File

@ -5,7 +5,7 @@ from collections import OrderedDict
import logging
import os
from pathlib import Path
from typing import List, NamedTuple, Optional
from typing import NamedTuple
import voluptuous as vol
@ -35,8 +35,8 @@ class CheckConfigError(NamedTuple):
"""Configuration check error."""
message: str
domain: Optional[str]
config: Optional[ConfigType]
domain: str | None
config: ConfigType | None
class HomeAssistantConfig(OrderedDict):
@ -45,13 +45,13 @@ class HomeAssistantConfig(OrderedDict):
def __init__(self) -> None:
"""Initialize HA config."""
super().__init__()
self.errors: List[CheckConfigError] = []
self.errors: list[CheckConfigError] = []
def add_error(
self,
message: str,
domain: Optional[str] = None,
config: Optional[ConfigType] = None,
domain: str | None = None,
config: ConfigType | None = None,
) -> HomeAssistantConfig:
"""Add a single error."""
self.errors.append(CheckConfigError(str(message), domain, config))

View File

@ -1,9 +1,11 @@
"""Helper to deal with YAML + storage."""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from dataclasses import dataclass
import logging
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast
from typing import Any, Awaitable, Callable, Iterable, Optional, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -72,9 +74,9 @@ class IDManager:
def __init__(self) -> None:
"""Initiate the ID manager."""
self.collections: List[Dict[str, Any]] = []
self.collections: list[dict[str, Any]] = []
def add_collection(self, collection: Dict[str, Any]) -> None:
def add_collection(self, collection: dict[str, Any]) -> None:
"""Add a collection to check for ID usage."""
self.collections.append(collection)
@ -98,17 +100,17 @@ class IDManager:
class ObservableCollection(ABC):
"""Base collection type that can be observed."""
def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None):
def __init__(self, logger: logging.Logger, id_manager: IDManager | None = None):
"""Initialize the base collection."""
self.logger = logger
self.id_manager = id_manager or IDManager()
self.data: Dict[str, dict] = {}
self.listeners: List[ChangeListener] = []
self.data: dict[str, dict] = {}
self.listeners: list[ChangeListener] = []
self.id_manager.add_collection(self.data)
@callback
def async_items(self) -> List[dict]:
def async_items(self) -> list[dict]:
"""Return list of items in collection."""
return list(self.data.values())
@ -134,7 +136,7 @@ class ObservableCollection(ABC):
class YamlCollection(ObservableCollection):
"""Offer a collection based on static data."""
async def async_load(self, data: List[dict]) -> None:
async def async_load(self, data: list[dict]) -> None:
"""Load the YAML collection. Overrides existing data."""
old_ids = set(self.data)
@ -171,7 +173,7 @@ class StorageCollection(ObservableCollection):
self,
store: Store,
logger: logging.Logger,
id_manager: Optional[IDManager] = None,
id_manager: IDManager | None = None,
):
"""Initialize the storage collection."""
super().__init__(logger, id_manager)
@ -182,7 +184,7 @@ class StorageCollection(ObservableCollection):
"""Home Assistant object."""
return self.store.hass
async def _async_load_data(self) -> Optional[dict]:
async def _async_load_data(self) -> dict | None:
"""Load the data."""
return cast(Optional[dict], await self.store.async_load())
@ -274,7 +276,7 @@ class IDLessCollection(ObservableCollection):
counter = 0
async def async_load(self, data: List[dict]) -> None:
async def async_load(self, data: list[dict]) -> None:
"""Load the collection. Overrides existing data."""
await self.notify_changes(
[

View File

@ -1,4 +1,6 @@
"""Offer reusable conditions."""
from __future__ import annotations
import asyncio
from collections import deque
from contextlib import contextmanager
@ -7,7 +9,7 @@ import functools as ft
import logging
import re
import sys
from typing import Any, Callable, Container, Generator, List, Optional, Set, Union, cast
from typing import Any, Callable, Container, Generator, cast
from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import (
@ -124,7 +126,7 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
async def async_from_config(
hass: HomeAssistant,
config: Union[ConfigType, Template],
config: ConfigType | Template,
config_validation: bool = True,
) -> ConditionCheckerType:
"""Turn a condition configuration into a method.
@ -267,10 +269,10 @@ async def async_not_from_config(
def numeric_state(
hass: HomeAssistant,
entity: Union[None, str, State],
below: Optional[Union[float, str]] = None,
above: Optional[Union[float, str]] = None,
value_template: Optional[Template] = None,
entity: None | str | State,
below: float | str | None = None,
above: float | str | None = None,
value_template: Template | None = None,
variables: TemplateVarsType = None,
) -> bool:
"""Test a numeric state condition."""
@ -288,12 +290,12 @@ def numeric_state(
def async_numeric_state(
hass: HomeAssistant,
entity: Union[None, str, State],
below: Optional[Union[float, str]] = None,
above: Optional[Union[float, str]] = None,
value_template: Optional[Template] = None,
entity: None | str | State,
below: float | str | None = None,
above: float | str | None = None,
value_template: Template | None = None,
variables: TemplateVarsType = None,
attribute: Optional[str] = None,
attribute: str | None = None,
) -> bool:
"""Test a numeric state condition."""
if entity is None:
@ -456,10 +458,10 @@ def async_numeric_state_from_config(
def state(
hass: HomeAssistant,
entity: Union[None, str, State],
entity: None | str | State,
req_state: Any,
for_period: Optional[timedelta] = None,
attribute: Optional[str] = None,
for_period: timedelta | None = None,
attribute: str | None = None,
) -> bool:
"""Test if state matches requirements.
@ -526,7 +528,7 @@ def state_from_config(
if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config)
entity_ids = config.get(CONF_ENTITY_ID, [])
req_states: Union[str, List[str]] = config.get(CONF_STATE, [])
req_states: str | list[str] = config.get(CONF_STATE, [])
for_period = config.get("for")
attribute = config.get(CONF_ATTRIBUTE)
@ -560,10 +562,10 @@ def state_from_config(
def sun(
hass: HomeAssistant,
before: Optional[str] = None,
after: Optional[str] = None,
before_offset: Optional[timedelta] = None,
after_offset: Optional[timedelta] = None,
before: str | None = None,
after: str | None = None,
before_offset: timedelta | None = None,
after_offset: timedelta | None = None,
) -> bool:
"""Test if current time matches sun requirements."""
utcnow = dt_util.utcnow()
@ -673,9 +675,9 @@ def async_template_from_config(
def time(
hass: HomeAssistant,
before: Optional[Union[dt_util.dt.time, str]] = None,
after: Optional[Union[dt_util.dt.time, str]] = None,
weekday: Union[None, str, Container[str]] = None,
before: dt_util.dt.time | str | None = None,
after: dt_util.dt.time | str | None = None,
weekday: None | str | Container[str] = None,
) -> bool:
"""Test if local time condition matches.
@ -752,8 +754,8 @@ def time_from_config(
def zone(
hass: HomeAssistant,
zone_ent: Union[None, str, State],
entity: Union[None, str, State],
zone_ent: None | str | State,
entity: None | str | State,
) -> bool:
"""Test if zone-condition matches.
@ -858,8 +860,8 @@ async def async_device_from_config(
async def async_validate_condition_config(
hass: HomeAssistant, config: Union[ConfigType, Template]
) -> Union[ConfigType, Template]:
hass: HomeAssistant, config: ConfigType | Template
) -> ConfigType | Template:
"""Validate config."""
if isinstance(config, Template):
return config
@ -884,9 +886,9 @@ async def async_validate_condition_config(
@callback
def async_extract_entities(config: Union[ConfigType, Template]) -> Set[str]:
def async_extract_entities(config: ConfigType | Template) -> set[str]:
"""Extract entities from a condition."""
referenced: Set[str] = set()
referenced: set[str] = set()
to_process = deque([config])
while to_process:
@ -912,7 +914,7 @@ def async_extract_entities(config: Union[ConfigType, Template]) -> Set[str]:
@callback
def async_extract_devices(config: Union[ConfigType, Template]) -> Set[str]:
def async_extract_devices(config: ConfigType | Template) -> set[str]:
"""Extract devices from a condition."""
referenced = set()
to_process = deque([config])

View File

@ -1,5 +1,7 @@
"""Helpers for data entry flows for config entries."""
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from __future__ import annotations
from typing import Any, Awaitable, Callable, Union
from homeassistant import config_entries
@ -27,8 +29,8 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
self.CONNECTION_CLASS = connection_class # pylint: disable=invalid-name
async def async_step_user(
self, user_input: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Handle a flow initialized by the user."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -38,8 +40,8 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
return await self.async_step_confirm()
async def async_step_confirm(
self, user_input: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Confirm setup."""
if user_input is None:
self._set_confirm_only()
@ -68,8 +70,8 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
return self.async_create_entry(title=self._title, data={})
async def async_step_discovery(
self, discovery_info: Dict[str, Any]
) -> Dict[str, Any]:
self, discovery_info: dict[str, Any]
) -> dict[str, Any]:
"""Handle a flow initialized by discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -84,7 +86,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
async_step_homekit = async_step_discovery
async_step_dhcp = async_step_discovery
async def async_step_import(self, _: Optional[Dict[str, Any]]) -> Dict[str, Any]:
async def async_step_import(self, _: dict[str, Any] | None) -> dict[str, Any]:
"""Handle a flow initialized by import."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -133,8 +135,8 @@ class WebhookFlowHandler(config_entries.ConfigFlow):
self._allow_multiple = allow_multiple
async def async_step_user(
self, user_input: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Handle a user initiated set up flow to create a webhook."""
if not self._allow_multiple and self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")

View File

@ -5,12 +5,14 @@ This module exists of the following parts:
- OAuth2 implementation that works with local provided client ID/secret
"""
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod
import asyncio
import logging
import secrets
import time
from typing import Any, Awaitable, Callable, Dict, Optional, cast
from typing import Any, Awaitable, Callable, Dict, cast
from aiohttp import client, web
import async_timeout
@ -231,7 +233,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
return {}
async def async_step_pick_implementation(
self, user_input: Optional[dict] = None
self, user_input: dict | None = None
) -> dict:
"""Handle a flow start."""
implementations = await async_get_implementations(self.hass, self.DOMAIN)
@ -260,8 +262,8 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
)
async def async_step_auth(
self, user_input: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Create an entry for auth."""
# Flow has been triggered by external data
if user_input:
@ -286,8 +288,8 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
return self.async_external_step(step_id="auth", url=url)
async def async_step_creation(
self, user_input: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Create config entry from external data."""
token = await self.flow_impl.async_resolve_external_data(self.external_data)
# Force int for non-compliant oauth2 providers
@ -312,8 +314,8 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
return self.async_create_entry(title=self.flow_impl.name, data=data)
async def async_step_discovery(
self, discovery_info: Dict[str, Any]
) -> Dict[str, Any]:
self, discovery_info: dict[str, Any]
) -> dict[str, Any]:
"""Handle a flow initialized by discovery."""
await self.async_set_unique_id(self.DOMAIN)
@ -354,7 +356,7 @@ def async_register_implementation(
async def async_get_implementations(
hass: HomeAssistant, domain: str
) -> Dict[str, AbstractOAuth2Implementation]:
) -> dict[str, AbstractOAuth2Implementation]:
"""Return OAuth2 implementations for specified domain."""
registered = cast(
Dict[str, AbstractOAuth2Implementation],
@ -392,7 +394,7 @@ def async_add_implementation_provider(
hass: HomeAssistant,
provider_domain: str,
async_provide_implementation: Callable[
[HomeAssistant, str], Awaitable[Optional[AbstractOAuth2Implementation]]
[HomeAssistant, str], Awaitable[AbstractOAuth2Implementation | None]
],
) -> None:
"""Add an implementation provider.
@ -516,7 +518,7 @@ def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
@callback
def _decode_jwt(hass: HomeAssistant, encoded: str) -> Optional[dict]:
def _decode_jwt(hass: HomeAssistant, encoded: str) -> dict | None:
"""JWT encode data."""
secret = cast(str, hass.data.get(DATA_JWT_SECRET))

View File

@ -1,4 +1,6 @@
"""Helpers for config validation using voluptuous."""
from __future__ import annotations
from datetime import (
date as date_sys,
datetime as datetime_sys,
@ -12,19 +14,7 @@ from numbers import Number
import os
import re
from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed
from typing import (
Any,
Callable,
Dict,
Hashable,
List,
Optional,
Pattern,
Type,
TypeVar,
Union,
cast,
)
from typing import Any, Callable, Dict, Hashable, Pattern, TypeVar, cast
from urllib.parse import urlparse
from uuid import UUID
@ -131,7 +121,7 @@ def path(value: Any) -> str:
def has_at_least_one_key(*keys: str) -> Callable:
"""Validate that at least one key exists."""
def validate(obj: Dict) -> Dict:
def validate(obj: dict) -> dict:
"""Test keys exist in dict."""
if not isinstance(obj, dict):
raise vol.Invalid("expected dictionary")
@ -144,10 +134,10 @@ def has_at_least_one_key(*keys: str) -> Callable:
return validate
def has_at_most_one_key(*keys: str) -> Callable[[Dict], Dict]:
def has_at_most_one_key(*keys: str) -> Callable[[dict], dict]:
"""Validate that zero keys exist or one key exists."""
def validate(obj: Dict) -> Dict:
def validate(obj: dict) -> dict:
"""Test zero keys exist or one key exists in dict."""
if not isinstance(obj, dict):
raise vol.Invalid("expected dictionary")
@ -253,7 +243,7 @@ def isdir(value: Any) -> str:
return dir_in
def ensure_list(value: Union[T, List[T], None]) -> List[T]:
def ensure_list(value: T | list[T] | None) -> list[T]:
"""Wrap value in list if it is not one."""
if value is None:
return []
@ -269,7 +259,7 @@ def entity_id(value: Any) -> str:
raise vol.Invalid(f"Entity ID {value} is an invalid entity ID")
def entity_ids(value: Union[str, List]) -> List[str]:
def entity_ids(value: str | list) -> list[str]:
"""Validate Entity IDs."""
if value is None:
raise vol.Invalid("Entity IDs can not be None")
@ -284,7 +274,7 @@ comp_entity_ids = vol.Any(
)
def entity_domain(domain: Union[str, List[str]]) -> Callable[[Any], str]:
def entity_domain(domain: str | list[str]) -> Callable[[Any], str]:
"""Validate that entity belong to domain."""
ent_domain = entities_domain(domain)
@ -298,9 +288,7 @@ def entity_domain(domain: Union[str, List[str]]) -> Callable[[Any], str]:
return validate
def entities_domain(
domain: Union[str, List[str]]
) -> Callable[[Union[str, List]], List[str]]:
def entities_domain(domain: str | list[str]) -> Callable[[str | list], list[str]]:
"""Validate that entities belong to domain."""
if isinstance(domain, str):
@ -312,7 +300,7 @@ def entities_domain(
def check_invalid(val: str) -> bool:
return val not in domain
def validate(values: Union[str, List]) -> List[str]:
def validate(values: str | list) -> list[str]:
"""Test if entity domain is domain."""
values = entity_ids(values)
for ent_id in values:
@ -325,7 +313,7 @@ def entities_domain(
return validate
def enum(enumClass: Type[Enum]) -> vol.All:
def enum(enumClass: type[Enum]) -> vol.All:
"""Create validator for specified enum."""
return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__)
@ -423,7 +411,7 @@ def time_period_str(value: str) -> timedelta:
return offset
def time_period_seconds(value: Union[float, str]) -> timedelta:
def time_period_seconds(value: float | str) -> timedelta:
"""Validate and transform seconds to a time offset."""
try:
return timedelta(seconds=float(value))
@ -450,7 +438,7 @@ positive_time_period_dict = vol.All(time_period_dict, positive_timedelta)
positive_time_period = vol.All(time_period, positive_timedelta)
def remove_falsy(value: List[T]) -> List[T]:
def remove_falsy(value: list[T]) -> list[T]:
"""Remove falsy values from a list."""
return [v for v in value if v]
@ -477,7 +465,7 @@ def slug(value: Any) -> str:
def schema_with_slug_keys(
value_schema: Union[T, Callable], *, slug_validator: Callable[[Any], str] = slug
value_schema: T | Callable, *, slug_validator: Callable[[Any], str] = slug
) -> Callable:
"""Ensure dicts have slugs as keys.
@ -486,7 +474,7 @@ def schema_with_slug_keys(
"""
schema = vol.Schema({str: value_schema})
def verify(value: Dict) -> Dict:
def verify(value: dict) -> dict:
"""Validate all keys are slugs and then the value_schema."""
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")
@ -547,7 +535,7 @@ unit_system = vol.All(
)
def template(value: Optional[Any]) -> template_helper.Template:
def template(value: Any | None) -> template_helper.Template:
"""Validate a jinja2 template."""
if value is None:
raise vol.Invalid("template value is None")
@ -563,7 +551,7 @@ def template(value: Optional[Any]) -> template_helper.Template:
raise vol.Invalid(f"invalid template ({ex})") from ex
def dynamic_template(value: Optional[Any]) -> template_helper.Template:
def dynamic_template(value: Any | None) -> template_helper.Template:
"""Validate a dynamic (non static) jinja2 template."""
if value is None:
raise vol.Invalid("template value is None")
@ -632,7 +620,7 @@ def time_zone(value: str) -> str:
weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)])
def socket_timeout(value: Optional[Any]) -> object:
def socket_timeout(value: Any | None) -> object:
"""Validate timeout float > 0.0.
None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object.
@ -681,7 +669,7 @@ def uuid4_hex(value: Any) -> str:
return result.hex
def ensure_list_csv(value: Any) -> List:
def ensure_list_csv(value: Any) -> list:
"""Ensure that input is a list or make one from comma-separated string."""
if isinstance(value, str):
return [member.strip() for member in value.split(",")]
@ -709,9 +697,9 @@ class multi_select:
def deprecated(
key: str,
replacement_key: Optional[str] = None,
default: Optional[Any] = None,
) -> Callable[[Dict], Dict]:
replacement_key: str | None = None,
default: Any | None = None,
) -> Callable[[dict], dict]:
"""
Log key as deprecated and provide a replacement (if exists).
@ -743,7 +731,7 @@ def deprecated(
" please remove it from your configuration"
)
def validator(config: Dict) -> Dict:
def validator(config: dict) -> dict:
"""Check if key is in config and log warning."""
if key in config:
try:
@ -781,14 +769,14 @@ def deprecated(
def key_value_schemas(
key: str, value_schemas: Dict[str, vol.Schema]
) -> Callable[[Any], Dict[str, Any]]:
key: str, value_schemas: dict[str, vol.Schema]
) -> Callable[[Any], dict[str, Any]]:
"""Create a validator that validates based on a value for specific key.
This gives better error messages.
"""
def key_value_validator(value: Any) -> Dict[str, Any]:
def key_value_validator(value: Any) -> dict[str, Any]:
if not isinstance(value, dict):
raise vol.Invalid("Expected a dictionary")
@ -809,10 +797,10 @@ def key_value_schemas(
def key_dependency(
key: Hashable, dependency: Hashable
) -> Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]:
) -> Callable[[dict[Hashable, Any]], dict[Hashable, Any]]:
"""Validate that all dependencies exist for key."""
def validator(value: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
def validator(value: dict[Hashable, Any]) -> dict[Hashable, Any]:
"""Test dependencies."""
if not isinstance(value, dict):
raise vol.Invalid("key dependencies require a dict")
@ -1247,7 +1235,7 @@ def determine_script_action(action: dict) -> str:
return SCRIPT_ACTION_CALL_SERVICE
ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_CALL_SERVICE: SERVICE_SCHEMA,
SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA,
SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA,

View File

@ -1,6 +1,7 @@
"""Helpers for the data entry flow."""
from __future__ import annotations
from typing import Any, Dict
from typing import Any
from aiohttp import web
import voluptuous as vol
@ -20,7 +21,7 @@ class _BaseFlowManagerView(HomeAssistantView):
self._flow_mgr = flow_mgr
# pylint: disable=no-self-use
def _prepare_result_json(self, result: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_result_json(self, result: dict[str, Any]) -> dict[str, Any]:
"""Convert result to JSON."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
data = result.copy()
@ -58,7 +59,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
extra=vol.ALLOW_EXTRA,
)
)
async def post(self, request: web.Request, data: Dict[str, Any]) -> web.Response:
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Handle a POST request."""
if isinstance(data["handler"], list):
handler = tuple(data["handler"])
@ -99,7 +100,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(
self, request: web.Request, flow_id: str, data: Dict[str, Any]
self, request: web.Request, flow_id: str, data: dict[str, Any]
) -> web.Response:
"""Handle a POST request."""
try:

View File

@ -1,7 +1,9 @@
"""Debounce helper."""
from __future__ import annotations
import asyncio
from logging import Logger
from typing import Any, Awaitable, Callable, Optional
from typing import Any, Awaitable, Callable
from homeassistant.core import HassJob, HomeAssistant, callback
@ -16,7 +18,7 @@ class Debouncer:
*,
cooldown: float,
immediate: bool,
function: Optional[Callable[..., Awaitable[Any]]] = None,
function: Callable[..., Awaitable[Any]] | None = None,
):
"""Initialize debounce.
@ -29,13 +31,13 @@ class Debouncer:
self._function = function
self.cooldown = cooldown
self.immediate = immediate
self._timer_task: Optional[asyncio.TimerHandle] = None
self._timer_task: asyncio.TimerHandle | None = None
self._execute_at_end_of_timer: bool = False
self._execute_lock = asyncio.Lock()
self._job: Optional[HassJob] = None if function is None else HassJob(function)
self._job: HassJob | None = None if function is None else HassJob(function)
@property
def function(self) -> Optional[Callable[..., Awaitable[Any]]]:
def function(self) -> Callable[..., Awaitable[Any]] | None:
"""Return the function being wrapped by the Debouncer."""
return self._function

View File

@ -1,8 +1,10 @@
"""Deprecation helpers for Home Assistant."""
from __future__ import annotations
import functools
import inspect
import logging
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable
from ..helpers.frame import MissingIntegrationFrame, get_integration_frame
@ -49,8 +51,8 @@ def deprecated_substitute(substitute_name: str) -> Callable[..., Callable]:
def get_deprecated(
config: Dict[str, Any], new_name: str, old_name: str, default: Optional[Any] = None
) -> Optional[Any]:
config: dict[str, Any], new_name: str, old_name: str, default: Any | None = None
) -> Any | None:
"""Allow an old config name to be deprecated with a replacement.
If the new config isn't found, but the old one is, the old value is used
@ -85,7 +87,7 @@ def deprecated_function(replacement: str) -> Callable[..., Callable]:
"""Decorate function as deprecated."""
@functools.wraps(func)
def deprecated_func(*args: tuple, **kwargs: Dict[str, Any]) -> Any:
def deprecated_func(*args: tuple, **kwargs: dict[str, Any]) -> Any:
"""Wrap for the original function."""
logger = logging.getLogger(func.__module__)
try:

View File

@ -1,8 +1,10 @@
"""Provide a way to connect entities belonging to one device."""
from __future__ import annotations
from collections import OrderedDict
import logging
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, cast
import attr
@ -50,21 +52,21 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
class DeviceEntry:
"""Device Registry Entry."""
config_entries: Set[str] = attr.ib(converter=set, factory=set)
connections: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set)
identifiers: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set)
manufacturer: Optional[str] = attr.ib(default=None)
model: Optional[str] = attr.ib(default=None)
name: Optional[str] = attr.ib(default=None)
sw_version: Optional[str] = attr.ib(default=None)
via_device_id: Optional[str] = attr.ib(default=None)
area_id: Optional[str] = attr.ib(default=None)
name_by_user: Optional[str] = attr.ib(default=None)
entry_type: Optional[str] = attr.ib(default=None)
config_entries: set[str] = attr.ib(converter=set, factory=set)
connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
identifiers: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
manufacturer: str | None = attr.ib(default=None)
model: str | None = attr.ib(default=None)
name: str | None = attr.ib(default=None)
sw_version: str | None = attr.ib(default=None)
via_device_id: str | None = attr.ib(default=None)
area_id: str | None = attr.ib(default=None)
name_by_user: str | None = attr.ib(default=None)
entry_type: str | None = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex)
# This value is not stored, just used to keep track of events to fire.
is_new: bool = attr.ib(default=False)
disabled_by: Optional[str] = attr.ib(
disabled_by: str | None = attr.ib(
default=None,
validator=attr.validators.in_(
(
@ -75,7 +77,7 @@ class DeviceEntry:
)
),
)
suggested_area: Optional[str] = attr.ib(default=None)
suggested_area: str | None = attr.ib(default=None)
@property
def disabled(self) -> bool:
@ -87,17 +89,17 @@ class DeviceEntry:
class DeletedDeviceEntry:
"""Deleted Device Registry Entry."""
config_entries: Set[str] = attr.ib()
connections: Set[Tuple[str, str]] = attr.ib()
identifiers: Set[Tuple[str, str]] = attr.ib()
config_entries: set[str] = attr.ib()
connections: set[tuple[str, str]] = attr.ib()
identifiers: set[tuple[str, str]] = attr.ib()
id: str = attr.ib()
orphaned_timestamp: Optional[float] = attr.ib()
orphaned_timestamp: float | None = attr.ib()
def to_device_entry(
self,
config_entry_id: str,
connections: Set[Tuple[str, str]],
identifiers: Set[Tuple[str, str]],
connections: set[tuple[str, str]],
identifiers: set[tuple[str, str]],
) -> DeviceEntry:
"""Create DeviceEntry from DeletedDeviceEntry."""
return DeviceEntry(
@ -133,9 +135,9 @@ def format_mac(mac: str) -> str:
class DeviceRegistry:
"""Class to hold a registry of devices."""
devices: Dict[str, DeviceEntry]
deleted_devices: Dict[str, DeletedDeviceEntry]
_devices_index: Dict[str, Dict[str, Dict[Tuple[str, str], str]]]
devices: dict[str, DeviceEntry]
deleted_devices: dict[str, DeletedDeviceEntry]
_devices_index: dict[str, dict[str, dict[tuple[str, str], str]]]
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the device registry."""
@ -144,16 +146,16 @@ class DeviceRegistry:
self._clear_index()
@callback
def async_get(self, device_id: str) -> Optional[DeviceEntry]:
def async_get(self, device_id: str) -> DeviceEntry | None:
"""Get device."""
return self.devices.get(device_id)
@callback
def async_get_device(
self,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]] = None,
) -> Optional[DeviceEntry]:
identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None = None,
) -> DeviceEntry | None:
"""Check if device is registered."""
device_id = self._async_get_device_id_from_index(
REGISTERED_DEVICE, identifiers, connections
@ -164,9 +166,9 @@ class DeviceRegistry:
def _async_get_deleted_device(
self,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]],
) -> Optional[DeletedDeviceEntry]:
identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None,
) -> DeletedDeviceEntry | None:
"""Check if device is deleted."""
device_id = self._async_get_device_id_from_index(
DELETED_DEVICE, identifiers, connections
@ -178,9 +180,9 @@ class DeviceRegistry:
def _async_get_device_id_from_index(
self,
index: str,
identifiers: Set[Tuple[str, str]],
connections: Optional[Set[Tuple[str, str]]],
) -> Optional[str]:
identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None,
) -> str | None:
"""Check if device has previously been registered."""
devices_index = self._devices_index[index]
for identifier in identifiers:
@ -193,7 +195,7 @@ class DeviceRegistry:
return devices_index[IDX_CONNECTIONS][connection]
return None
def _add_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None:
def _add_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Add a device and index it."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._devices_index[DELETED_DEVICE]
@ -204,7 +206,7 @@ class DeviceRegistry:
_add_device_to_index(devices_index, device)
def _remove_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None:
def _remove_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Remove a device and remove it from the index."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._devices_index[DELETED_DEVICE]
@ -243,21 +245,21 @@ class DeviceRegistry:
self,
*,
config_entry_id: str,
connections: Optional[Set[Tuple[str, str]]] = None,
identifiers: Optional[Set[Tuple[str, str]]] = None,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
default_manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
default_model: Union[str, None, UndefinedType] = UNDEFINED,
default_name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device: Optional[Tuple[str, str]] = None,
connections: set[tuple[str, str]] | None = None,
identifiers: set[tuple[str, str]] | None = None,
manufacturer: str | None | UndefinedType = UNDEFINED,
model: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
default_manufacturer: str | None | UndefinedType = UNDEFINED,
default_model: str | None | UndefinedType = UNDEFINED,
default_name: str | None | UndefinedType = UNDEFINED,
sw_version: str | None | UndefinedType = UNDEFINED,
entry_type: str | None | UndefinedType = UNDEFINED,
via_device: tuple[str, str] | None = None,
# To disable a device if it gets created
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
disabled_by: str | None | UndefinedType = UNDEFINED,
suggested_area: str | None | UndefinedType = UNDEFINED,
) -> DeviceEntry | None:
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
return None
@ -294,7 +296,7 @@ class DeviceRegistry:
if via_device is not None:
via = self.async_get_device({via_device})
via_device_id: Union[str, UndefinedType] = via.id if via else UNDEFINED
via_device_id: str | UndefinedType = via.id if via else UNDEFINED
else:
via_device_id = UNDEFINED
@ -318,18 +320,18 @@ class DeviceRegistry:
self,
device_id: str,
*,
area_id: Union[str, None, UndefinedType] = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
area_id: str | None | UndefinedType = UNDEFINED,
manufacturer: str | None | UndefinedType = UNDEFINED,
model: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
name_by_user: str | None | UndefinedType = UNDEFINED,
new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
sw_version: str | None | UndefinedType = UNDEFINED,
via_device_id: str | None | UndefinedType = UNDEFINED,
remove_config_entry_id: str | UndefinedType = UNDEFINED,
disabled_by: str | None | UndefinedType = UNDEFINED,
suggested_area: str | None | UndefinedType = UNDEFINED,
) -> DeviceEntry | None:
"""Update properties of a device."""
return self._async_update_device(
device_id,
@ -351,26 +353,26 @@ class DeviceRegistry:
self,
device_id: str,
*,
add_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
merge_connections: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
merge_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
new_identifiers: Union[Set[Tuple[str, str]], UndefinedType] = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
add_config_entry_id: str | UndefinedType = UNDEFINED,
remove_config_entry_id: str | UndefinedType = UNDEFINED,
merge_connections: set[tuple[str, str]] | UndefinedType = UNDEFINED,
merge_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
manufacturer: str | None | UndefinedType = UNDEFINED,
model: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
sw_version: str | None | UndefinedType = UNDEFINED,
entry_type: str | None | UndefinedType = UNDEFINED,
via_device_id: str | None | UndefinedType = UNDEFINED,
area_id: str | None | UndefinedType = UNDEFINED,
name_by_user: str | None | UndefinedType = UNDEFINED,
disabled_by: str | None | UndefinedType = UNDEFINED,
suggested_area: str | None | UndefinedType = UNDEFINED,
) -> DeviceEntry | None:
"""Update device attributes."""
old = self.devices[device_id]
changes: Dict[str, Any] = {}
changes: dict[str, Any] = {}
config_entries = old.config_entries
@ -529,7 +531,7 @@ class DeviceRegistry:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]:
def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
"""Return data of device registry to store in a file."""
data = {}
@ -637,7 +639,7 @@ async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry:
@callback
def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> List[DeviceEntry]:
def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> list[DeviceEntry]:
"""Return entries that match an area."""
return [device for device in registry.devices.values() if device.area_id == area_id]
@ -645,7 +647,7 @@ def async_entries_for_area(registry: DeviceRegistry, area_id: str) -> List[Devic
@callback
def async_entries_for_config_entry(
registry: DeviceRegistry, config_entry_id: str
) -> List[DeviceEntry]:
) -> list[DeviceEntry]:
"""Return entries that match a config entry."""
return [
device
@ -769,7 +771,7 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean)
def _normalize_connections(connections: Set[Tuple[str, str]]) -> Set[Tuple[str, str]]:
def _normalize_connections(connections: set[tuple[str, str]]) -> set[tuple[str, str]]:
"""Normalize connections to ensure we can match mac addresses."""
return {
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
@ -778,8 +780,8 @@ def _normalize_connections(connections: Set[Tuple[str, str]]) -> Set[Tuple[str,
def _add_device_to_index(
devices_index: Dict[str, Dict[Tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry],
devices_index: dict[str, dict[tuple[str, str], str]],
device: DeviceEntry | DeletedDeviceEntry,
) -> None:
"""Add a device to the index."""
for identifier in device.identifiers:
@ -789,8 +791,8 @@ def _add_device_to_index(
def _remove_device_from_index(
devices_index: Dict[str, Dict[Tuple[str, str], str]],
device: Union[DeviceEntry, DeletedDeviceEntry],
devices_index: dict[str, dict[tuple[str, str], str]],
device: DeviceEntry | DeletedDeviceEntry,
) -> None:
"""Remove a device from the index."""
for identifier in device.identifiers:

View File

@ -5,7 +5,9 @@ There are two different types of discoveries that can be fired/listened for.
- listen_platform/discover_platform is for platforms. These are used by
components to allow discovery of their platforms.
"""
from typing import Any, Callable, Dict, Optional, TypedDict
from __future__ import annotations
from typing import Any, Callable, TypedDict
from homeassistant import core, setup
from homeassistant.core import CALLBACK_TYPE
@ -26,8 +28,8 @@ class DiscoveryDict(TypedDict):
"""Discovery data."""
service: str
platform: Optional[str]
discovered: Optional[DiscoveryInfoType]
platform: str | None
discovered: DiscoveryInfoType | None
@core.callback
@ -76,8 +78,8 @@ def discover(
async def async_discover(
hass: core.HomeAssistant,
service: str,
discovered: Optional[DiscoveryInfoType],
component: Optional[str],
discovered: DiscoveryInfoType | None,
component: str | None,
hass_config: ConfigType,
) -> None:
"""Fire discovery event. Can ensure a component is loaded."""
@ -97,7 +99,7 @@ async def async_discover(
def async_listen_platform(
hass: core.HomeAssistant,
component: str,
callback: Callable[[str, Optional[Dict[str, Any]]], Any],
callback: Callable[[str, dict[str, Any] | None], Any],
) -> None:
"""Register a platform loader listener.

View File

@ -1,11 +1,13 @@
"""An abstract class for entities."""
from __future__ import annotations
from abc import ABC
import asyncio
from datetime import datetime, timedelta
import functools as ft
import logging
from timeit import default_timer as timer
from typing import Any, Awaitable, Dict, Iterable, List, Optional
from typing import Any, Awaitable, Iterable
from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.const import (
@ -42,16 +44,16 @@ SOURCE_PLATFORM_CONFIG = "platform_config"
@callback
@bind_hass
def entity_sources(hass: HomeAssistant) -> Dict[str, Dict[str, str]]:
def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]:
"""Get the entity sources."""
return hass.data.get(DATA_ENTITY_SOURCE, {})
def generate_entity_id(
entity_id_format: str,
name: Optional[str],
current_ids: Optional[List[str]] = None,
hass: Optional[HomeAssistant] = None,
name: str | None,
current_ids: list[str] | None = None,
hass: HomeAssistant | None = None,
) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs."""
return async_generate_entity_id(entity_id_format, name, current_ids, hass)
@ -60,9 +62,9 @@ def generate_entity_id(
@callback
def async_generate_entity_id(
entity_id_format: str,
name: Optional[str],
current_ids: Optional[Iterable[str]] = None,
hass: Optional[HomeAssistant] = None,
name: str | None,
current_ids: Iterable[str] | None = None,
hass: HomeAssistant | None = None,
) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower()
@ -98,7 +100,7 @@ class Entity(ABC):
hass: HomeAssistant = None # type: ignore
# Owning platform instance. Will be set by EntityPlatform
platform: Optional[EntityPlatform] = None
platform: EntityPlatform | None = None
# If we reported if this entity was slow
_slow_reported = False
@ -110,17 +112,17 @@ class Entity(ABC):
_update_staged = False
# Process updates in parallel
parallel_updates: Optional[asyncio.Semaphore] = None
parallel_updates: asyncio.Semaphore | None = None
# Entry in the entity registry
registry_entry: Optional[RegistryEntry] = None
registry_entry: RegistryEntry | None = None
# Hold list for functions to call on remove.
_on_remove: Optional[List[CALLBACK_TYPE]] = None
_on_remove: list[CALLBACK_TYPE] | None = None
# Context
_context: Optional[Context] = None
_context_set: Optional[datetime] = None
_context: Context | None = None
_context_set: datetime | None = None
# If entity is added to an entity platform
_added = False
@ -134,12 +136,12 @@ class Entity(ABC):
return True
@property
def unique_id(self) -> Optional[str]:
def unique_id(self) -> str | None:
"""Return a unique ID."""
return None
@property
def name(self) -> Optional[str]:
def name(self) -> str | None:
"""Return the name of the entity."""
return None
@ -149,7 +151,7 @@ class Entity(ABC):
return STATE_UNKNOWN
@property
def capability_attributes(self) -> Optional[Dict[str, Any]]:
def capability_attributes(self) -> dict[str, Any] | None:
"""Return the capability attributes.
Attributes that explain the capabilities of an entity.
@ -160,7 +162,7 @@ class Entity(ABC):
return None
@property
def state_attributes(self) -> Optional[Dict[str, Any]]:
def state_attributes(self) -> dict[str, Any] | None:
"""Return the state attributes.
Implemented by component base class, should not be extended by integrations.
@ -169,7 +171,7 @@ class Entity(ABC):
return None
@property
def device_state_attributes(self) -> Optional[Dict[str, Any]]:
def device_state_attributes(self) -> dict[str, Any] | None:
"""Return entity specific state attributes.
This method is deprecated, platform classes should implement
@ -178,7 +180,7 @@ class Entity(ABC):
return None
@property
def extra_state_attributes(self) -> Optional[Dict[str, Any]]:
def extra_state_attributes(self) -> dict[str, Any] | None:
"""Return entity specific state attributes.
Implemented by platform classes. Convention for attribute names
@ -187,7 +189,7 @@ class Entity(ABC):
return None
@property
def device_info(self) -> Optional[Dict[str, Any]]:
def device_info(self) -> dict[str, Any] | None:
"""Return device specific attributes.
Implemented by platform classes.
@ -195,22 +197,22 @@ class Entity(ABC):
return None
@property
def device_class(self) -> Optional[str]:
def device_class(self) -> str | None:
"""Return the class of this device, from component DEVICE_CLASSES."""
return None
@property
def unit_of_measurement(self) -> Optional[str]:
def unit_of_measurement(self) -> str | None:
"""Return the unit of measurement of this entity, if any."""
return None
@property
def icon(self) -> Optional[str]:
def icon(self) -> str | None:
"""Return the icon to use in the frontend, if any."""
return None
@property
def entity_picture(self) -> Optional[str]:
def entity_picture(self) -> str | None:
"""Return the entity picture to use in the frontend, if any."""
return None
@ -234,7 +236,7 @@ class Entity(ABC):
return False
@property
def supported_features(self) -> Optional[int]:
def supported_features(self) -> int | None:
"""Flag supported features."""
return None
@ -516,7 +518,7 @@ class Entity(ABC):
self,
hass: HomeAssistant,
platform: EntityPlatform,
parallel_updates: Optional[asyncio.Semaphore],
parallel_updates: asyncio.Semaphore | None,
) -> None:
"""Start adding an entity to a platform."""
if self._added:

View File

@ -1,10 +1,12 @@
"""Helpers for components that manage entities."""
from __future__ import annotations
import asyncio
from datetime import timedelta
from itertools import chain
import logging
from types import ModuleType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Iterable
import voluptuous as vol
@ -76,10 +78,10 @@ class EntityComponent:
self.domain = domain
self.scan_interval = scan_interval
self.config: Optional[ConfigType] = None
self.config: ConfigType | None = None
self._platforms: Dict[
Union[str, Tuple[str, Optional[timedelta], Optional[str]]], EntityPlatform
self._platforms: dict[
str | tuple[str, timedelta | None, str | None], EntityPlatform
] = {domain: self._async_init_entity_platform(domain, None)}
self.async_add_entities = self._platforms[domain].async_add_entities
self.add_entities = self._platforms[domain].add_entities
@ -93,7 +95,7 @@ class EntityComponent:
platform.entities.values() for platform in self._platforms.values()
)
def get_entity(self, entity_id: str) -> Optional[entity.Entity]:
def get_entity(self, entity_id: str) -> entity.Entity | None:
"""Get an entity."""
for platform in self._platforms.values():
entity_obj = platform.entities.get(entity_id)
@ -125,7 +127,7 @@ class EntityComponent:
# Generic discovery listener for loading platform dynamically
# Refer to: homeassistant.helpers.discovery.async_load_platform()
async def component_platform_discovered(
platform: str, info: Optional[Dict[str, Any]]
platform: str, info: dict[str, Any] | None
) -> None:
"""Handle the loading of a platform."""
await self.async_setup_platform(platform, {}, info)
@ -176,7 +178,7 @@ class EntityComponent:
async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True
) -> List[entity.Entity]:
) -> list[entity.Entity]:
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
@ -191,9 +193,9 @@ class EntityComponent:
def async_register_entity_service(
self,
name: str,
schema: Union[Dict[str, Any], vol.Schema],
func: Union[str, Callable[..., Any]],
required_features: Optional[List[int]] = None,
schema: dict[str, Any] | vol.Schema,
func: str | Callable[..., Any],
required_features: list[int] | None = None,
) -> None:
"""Register an entity service."""
if isinstance(schema, dict):
@ -211,7 +213,7 @@ class EntityComponent:
self,
platform_type: str,
platform_config: ConfigType,
discovery_info: Optional[DiscoveryInfoType] = None,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up a platform for this component."""
if self.config is None:
@ -274,7 +276,7 @@ class EntityComponent:
async def async_prepare_reload(
self, *, skip_reset: bool = False
) -> Optional[ConfigType]:
) -> ConfigType | None:
"""Prepare reloading this entity component.
This method must be run in the event loop.
@ -303,9 +305,9 @@ class EntityComponent:
def _async_init_entity_platform(
self,
platform_type: str,
platform: Optional[ModuleType],
scan_interval: Optional[timedelta] = None,
entity_namespace: Optional[str] = None,
platform: ModuleType | None,
scan_interval: timedelta | None = None,
entity_namespace: str | None = None,
) -> EntityPlatform:
"""Initialize an entity platform."""
if scan_interval is None:

View File

@ -6,7 +6,7 @@ from contextvars import ContextVar
from datetime import datetime, timedelta
from logging import Logger
from types import ModuleType
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Callable, Coroutine, Iterable
from homeassistant import config_entries
from homeassistant.const import ATTR_RESTORED, DEVICE_DEFAULT_NAME
@ -49,9 +49,9 @@ class EntityPlatform:
logger: Logger,
domain: str,
platform_name: str,
platform: Optional[ModuleType],
platform: ModuleType | None,
scan_interval: timedelta,
entity_namespace: Optional[str],
entity_namespace: str | None,
):
"""Initialize the entity platform."""
self.hass = hass
@ -61,18 +61,18 @@ class EntityPlatform:
self.platform = platform
self.scan_interval = scan_interval
self.entity_namespace = entity_namespace
self.config_entry: Optional[config_entries.ConfigEntry] = None
self.entities: Dict[str, Entity] = {}
self._tasks: List[asyncio.Future] = []
self.config_entry: config_entries.ConfigEntry | None = None
self.entities: dict[str, Entity] = {}
self._tasks: list[asyncio.Future] = []
# Stop tracking tasks after setup is completed
self._setup_complete = False
# Method to cancel the state change listener
self._async_unsub_polling: Optional[CALLBACK_TYPE] = None
self._async_unsub_polling: CALLBACK_TYPE | None = None
# Method to cancel the retry of setup
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None
self._process_updates: Optional[asyncio.Lock] = None
self._async_cancel_retry_setup: CALLBACK_TYPE | None = None
self._process_updates: asyncio.Lock | None = None
self.parallel_updates: Optional[asyncio.Semaphore] = None
self.parallel_updates: asyncio.Semaphore | None = None
# Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities
@ -89,7 +89,7 @@ class EntityPlatform:
@callback
def _get_parallel_updates_semaphore(
self, entity_has_async_update: bool
) -> Optional[asyncio.Semaphore]:
) -> asyncio.Semaphore | None:
"""Get or create a semaphore for parallel updates.
Semaphore will be created on demand because we base it off if update method is async or not.
@ -364,7 +364,7 @@ class EntityPlatform:
return
requested_entity_id = None
suggested_object_id: Optional[str] = None
suggested_object_id: str | None = None
# Get entity_id from unique ID registration
if entity.unique_id is not None:
@ -378,7 +378,7 @@ class EntityPlatform:
suggested_object_id = f"{self.entity_namespace} {suggested_object_id}"
if self.config_entry is not None:
config_entry_id: Optional[str] = self.config_entry.entry_id
config_entry_id: str | None = self.config_entry.entry_id
else:
config_entry_id = None
@ -408,7 +408,7 @@ class EntityPlatform:
if device:
device_id = device.id
disabled_by: Optional[str] = None
disabled_by: str | None = None
if not entity.entity_registry_enabled_default:
disabled_by = DISABLED_INTEGRATION
@ -550,7 +550,7 @@ class EntityPlatform:
async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True
) -> List[Entity]:
) -> list[Entity]:
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
@ -621,7 +621,7 @@ class EntityPlatform:
await asyncio.gather(*tasks)
current_platform: ContextVar[Optional[EntityPlatform]] = ContextVar(
current_platform: ContextVar[EntityPlatform | None] = ContextVar(
"current_platform", default=None
)
@ -629,7 +629,7 @@ current_platform: ContextVar[Optional[EntityPlatform]] = ContextVar(
@callback
def async_get_platforms(
hass: HomeAssistantType, integration_name: str
) -> List[EntityPlatform]:
) -> list[EntityPlatform]:
"""Find existing platforms."""
if (
DATA_ENTITY_PLATFORM not in hass.data
@ -637,6 +637,6 @@ def async_get_platforms(
):
return []
platforms: List[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name]
platforms: list[EntityPlatform] = hass.data[DATA_ENTITY_PLATFORM][integration_name]
return platforms

View File

@ -7,20 +7,11 @@ The Entity Registry will persist itself 10 seconds after a new entity is
registered. Registering a new entity while a timer is in progress resets the
timer.
"""
from __future__ import annotations
from collections import OrderedDict
import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, cast
import attr
@ -80,12 +71,12 @@ class RegistryEntry:
entity_id: str = attr.ib()
unique_id: str = attr.ib()
platform: str = attr.ib()
name: Optional[str] = attr.ib(default=None)
icon: Optional[str] = attr.ib(default=None)
device_id: Optional[str] = attr.ib(default=None)
area_id: Optional[str] = attr.ib(default=None)
config_entry_id: Optional[str] = attr.ib(default=None)
disabled_by: Optional[str] = attr.ib(
name: str | None = attr.ib(default=None)
icon: str | None = attr.ib(default=None)
device_id: str | None = attr.ib(default=None)
area_id: str | None = attr.ib(default=None)
config_entry_id: str | None = attr.ib(default=None)
disabled_by: str | None = attr.ib(
default=None,
validator=attr.validators.in_(
(
@ -98,13 +89,13 @@ class RegistryEntry:
)
),
)
capabilities: Optional[Dict[str, Any]] = attr.ib(default=None)
capabilities: dict[str, Any] | None = attr.ib(default=None)
supported_features: int = attr.ib(default=0)
device_class: Optional[str] = attr.ib(default=None)
unit_of_measurement: Optional[str] = attr.ib(default=None)
device_class: str | None = attr.ib(default=None)
unit_of_measurement: str | None = attr.ib(default=None)
# As set by integration
original_name: Optional[str] = attr.ib(default=None)
original_icon: Optional[str] = attr.ib(default=None)
original_name: str | None = attr.ib(default=None)
original_icon: str | None = attr.ib(default=None)
domain: str = attr.ib(init=False, repr=False)
@domain.default
@ -120,7 +111,7 @@ class RegistryEntry:
@callback
def write_unavailable_state(self, hass: HomeAssistantType) -> None:
"""Write the unavailable state to the state machine."""
attrs: Dict[str, Any] = {ATTR_RESTORED: True}
attrs: dict[str, Any] = {ATTR_RESTORED: True}
if self.capabilities is not None:
attrs.update(self.capabilities)
@ -151,8 +142,8 @@ class EntityRegistry:
def __init__(self, hass: HomeAssistantType):
"""Initialize the registry."""
self.hass = hass
self.entities: Dict[str, RegistryEntry]
self._index: Dict[Tuple[str, str, str], str] = {}
self.entities: dict[str, RegistryEntry]
self._index: dict[tuple[str, str, str], str] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified
@ -161,7 +152,7 @@ class EntityRegistry:
@callback
def async_get_device_class_lookup(self, domain_device_classes: set) -> dict:
"""Return a lookup for the device class by domain."""
lookup: Dict[str, Dict[Tuple[Any, Any], str]] = {}
lookup: dict[str, dict[tuple[Any, Any], str]] = {}
for entity in self.entities.values():
if not entity.device_id:
continue
@ -180,14 +171,14 @@ class EntityRegistry:
return entity_id in self.entities
@callback
def async_get(self, entity_id: str) -> Optional[RegistryEntry]:
def async_get(self, entity_id: str) -> RegistryEntry | None:
"""Get EntityEntry for an entity_id."""
return self.entities.get(entity_id)
@callback
def async_get_entity_id(
self, domain: str, platform: str, unique_id: str
) -> Optional[str]:
) -> str | None:
"""Check if an entity_id is currently registered."""
return self._index.get((domain, platform, unique_id))
@ -196,7 +187,7 @@ class EntityRegistry:
self,
domain: str,
suggested_object_id: str,
known_object_ids: Optional[Iterable[str]] = None,
known_object_ids: Iterable[str] | None = None,
) -> str:
"""Generate an entity ID that does not conflict.
@ -226,20 +217,20 @@ class EntityRegistry:
unique_id: str,
*,
# To influence entity ID generation
suggested_object_id: Optional[str] = None,
known_object_ids: Optional[Iterable[str]] = None,
suggested_object_id: str | None = None,
known_object_ids: Iterable[str] | None = None,
# To disable an entity if it gets created
disabled_by: Optional[str] = None,
disabled_by: str | None = None,
# Data that we want entry to have
config_entry: Optional["ConfigEntry"] = None,
device_id: Optional[str] = None,
area_id: Optional[str] = None,
capabilities: Optional[Dict[str, Any]] = None,
supported_features: Optional[int] = None,
device_class: Optional[str] = None,
unit_of_measurement: Optional[str] = None,
original_name: Optional[str] = None,
original_icon: Optional[str] = None,
config_entry: "ConfigEntry" | None = None,
device_id: str | None = None,
area_id: str | None = None,
capabilities: dict[str, Any] | None = None,
supported_features: int | None = None,
device_class: str | None = None,
unit_of_measurement: str | None = None,
original_name: str | None = None,
original_icon: str | None = None,
) -> RegistryEntry:
"""Get entity. Create if it doesn't exist."""
config_entry_id = None
@ -363,12 +354,12 @@ class EntityRegistry:
self,
entity_id: str,
*,
name: Union[str, None, UndefinedType] = UNDEFINED,
icon: Union[str, None, UndefinedType] = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED,
new_entity_id: Union[str, UndefinedType] = UNDEFINED,
new_unique_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
icon: str | None | UndefinedType = UNDEFINED,
area_id: str | None | UndefinedType = UNDEFINED,
new_entity_id: str | UndefinedType = UNDEFINED,
new_unique_id: str | UndefinedType = UNDEFINED,
disabled_by: str | None | UndefinedType = UNDEFINED,
) -> RegistryEntry:
"""Update properties of an entity."""
return self._async_update_entity(
@ -386,20 +377,20 @@ class EntityRegistry:
self,
entity_id: str,
*,
name: Union[str, None, UndefinedType] = UNDEFINED,
icon: Union[str, None, UndefinedType] = UNDEFINED,
config_entry_id: Union[str, None, UndefinedType] = UNDEFINED,
new_entity_id: Union[str, UndefinedType] = UNDEFINED,
device_id: Union[str, None, UndefinedType] = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED,
new_unique_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
capabilities: Union[Dict[str, Any], None, UndefinedType] = UNDEFINED,
supported_features: Union[int, UndefinedType] = UNDEFINED,
device_class: Union[str, None, UndefinedType] = UNDEFINED,
unit_of_measurement: Union[str, None, UndefinedType] = UNDEFINED,
original_name: Union[str, None, UndefinedType] = UNDEFINED,
original_icon: Union[str, None, UndefinedType] = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
icon: str | None | UndefinedType = UNDEFINED,
config_entry_id: str | None | UndefinedType = UNDEFINED,
new_entity_id: str | UndefinedType = UNDEFINED,
device_id: str | None | UndefinedType = UNDEFINED,
area_id: str | None | UndefinedType = UNDEFINED,
new_unique_id: str | UndefinedType = UNDEFINED,
disabled_by: str | None | UndefinedType = UNDEFINED,
capabilities: dict[str, Any] | None | UndefinedType = UNDEFINED,
supported_features: int | UndefinedType = UNDEFINED,
device_class: str | None | UndefinedType = UNDEFINED,
unit_of_measurement: str | None | UndefinedType = UNDEFINED,
original_name: str | None | UndefinedType = UNDEFINED,
original_icon: str | None | UndefinedType = UNDEFINED,
) -> RegistryEntry:
"""Private facing update properties method."""
old = self.entities[entity_id]
@ -479,7 +470,7 @@ class EntityRegistry:
old_conf_load_func=load_yaml,
old_conf_migrate_func=_async_migrate,
)
entities: Dict[str, RegistryEntry] = OrderedDict()
entities: dict[str, RegistryEntry] = OrderedDict()
if data is not None:
for entity in data["entities"]:
@ -516,7 +507,7 @@ class EntityRegistry:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> Dict[str, Any]:
def _data_to_save(self) -> dict[str, Any]:
"""Return data of entity registry to store in a file."""
data = {}
@ -605,7 +596,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
@callback
def async_entries_for_device(
registry: EntityRegistry, device_id: str, include_disabled_entities: bool = False
) -> List[RegistryEntry]:
) -> list[RegistryEntry]:
"""Return entries that match a device."""
return [
entry
@ -618,7 +609,7 @@ def async_entries_for_device(
@callback
def async_entries_for_area(
registry: EntityRegistry, area_id: str
) -> List[RegistryEntry]:
) -> list[RegistryEntry]:
"""Return entries that match an area."""
return [entry for entry in registry.entities.values() if entry.area_id == area_id]
@ -626,7 +617,7 @@ def async_entries_for_area(
@callback
def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str
) -> List[RegistryEntry]:
) -> list[RegistryEntry]:
"""Return entries that match a config entry."""
return [
entry
@ -665,7 +656,7 @@ def async_config_entry_disabled_by_changed(
)
async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
async def _async_migrate(entities: dict[str, Any]) -> dict[str, list[dict[str, Any]]]:
"""Migrate the YAML config file to storage helper format."""
return {
"entities": [
@ -721,7 +712,7 @@ def async_setup_entity_restore(
async def async_migrate_entries(
hass: HomeAssistantType,
config_entry_id: str,
entry_callback: Callable[[RegistryEntry], Optional[dict]],
entry_callback: Callable[[RegistryEntry], dict | None],
) -> None:
"""Migrator of unique IDs."""
ent_reg = await async_get_registry(hass)

View File

@ -1,8 +1,10 @@
"""A class to hold entity values."""
from __future__ import annotations
from collections import OrderedDict
import fnmatch
import re
from typing import Any, Dict, Optional, Pattern
from typing import Any, Pattern
from homeassistant.core import split_entity_id
@ -14,17 +16,17 @@ class EntityValues:
def __init__(
self,
exact: Optional[Dict[str, Dict[str, str]]] = None,
domain: Optional[Dict[str, Dict[str, str]]] = None,
glob: Optional[Dict[str, Dict[str, str]]] = None,
exact: dict[str, dict[str, str]] | None = None,
domain: dict[str, dict[str, str]] | None = None,
glob: dict[str, dict[str, str]] | None = None,
) -> None:
"""Initialize an EntityConfigDict."""
self._cache: Dict[str, Dict[str, str]] = {}
self._cache: dict[str, dict[str, str]] = {}
self._exact = exact
self._domain = domain
if glob is None:
compiled: Optional[Dict[Pattern[str], Any]] = None
compiled: dict[Pattern[str], Any] | None = None
else:
compiled = OrderedDict()
for key, value in glob.items():
@ -32,7 +34,7 @@ class EntityValues:
self._glob = compiled
def get(self, entity_id: str) -> Dict[str, str]:
def get(self, entity_id: str) -> dict[str, str]:
"""Get config for an entity id."""
if entity_id in self._cache:
return self._cache[entity_id]

View File

@ -1,7 +1,9 @@
"""Helper class to implement include/exclude of entities and domains."""
from __future__ import annotations
import fnmatch
import re
from typing import Callable, Dict, List, Pattern
from typing import Callable, Pattern
import voluptuous as vol
@ -19,7 +21,7 @@ CONF_EXCLUDE_ENTITIES = "exclude_entities"
CONF_ENTITY_GLOBS = "entity_globs"
def convert_filter(config: Dict[str, List[str]]) -> Callable[[str], bool]:
def convert_filter(config: dict[str, list[str]]) -> Callable[[str], bool]:
"""Convert the filter schema into a filter."""
filt = generate_filter(
config[CONF_INCLUDE_DOMAINS],
@ -57,7 +59,7 @@ FILTER_SCHEMA = vol.All(BASE_FILTER_SCHEMA, convert_filter)
def convert_include_exclude_filter(
config: Dict[str, Dict[str, List[str]]]
config: dict[str, dict[str, list[str]]]
) -> Callable[[str], bool]:
"""Convert the include exclude filter schema into a filter."""
include = config[CONF_INCLUDE]
@ -107,7 +109,7 @@ def _glob_to_re(glob: str) -> Pattern[str]:
return re.compile(fnmatch.translate(glob))
def _test_against_patterns(patterns: List[Pattern[str]], entity_id: str) -> bool:
def _test_against_patterns(patterns: list[Pattern[str]], entity_id: str) -> bool:
"""Test entity against list of patterns, true if any match."""
for pattern in patterns:
if pattern.match(entity_id):
@ -119,12 +121,12 @@ def _test_against_patterns(patterns: List[Pattern[str]], entity_id: str) -> bool
# It's safe since we don't modify it. And None causes typing warnings
# pylint: disable=dangerous-default-value
def generate_filter(
include_domains: List[str],
include_entities: List[str],
exclude_domains: List[str],
exclude_entities: List[str],
include_entity_globs: List[str] = [],
exclude_entity_globs: List[str] = [],
include_domains: list[str],
include_entities: list[str],
exclude_domains: list[str],
exclude_entities: list[str],
include_entity_globs: list[str] = [],
exclude_entity_globs: list[str] = [],
) -> Callable[[str], bool]:
"""Return a function that will filter entities based on the args."""
include_d = set(include_domains)

View File

@ -1,4 +1,6 @@
"""Helpers for listening to events."""
from __future__ import annotations
import asyncio
import copy
from dataclasses import dataclass
@ -6,18 +8,7 @@ from datetime import datetime, timedelta
import functools as ft
import logging
import time
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from typing import Any, Awaitable, Callable, Iterable, List
import attr
@ -79,8 +70,8 @@ class TrackStates:
"""
all_states: bool
entities: Set
domains: Set
entities: set
domains: set
@dataclass
@ -94,7 +85,7 @@ class TrackTemplate:
template: Template
variables: TemplateVarsType
rate_limit: Optional[timedelta] = None
rate_limit: timedelta | None = None
@dataclass
@ -146,10 +137,10 @@ def threaded_listener_factory(
@bind_hass
def async_track_state_change(
hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]],
entity_ids: str | Iterable[str],
action: Callable[[str, State, State], None],
from_state: Union[None, str, Iterable[str]] = None,
to_state: Union[None, str, Iterable[str]] = None,
from_state: None | str | Iterable[str] = None,
to_state: None | str | Iterable[str] = None,
) -> CALLBACK_TYPE:
"""Track specific state changes.
@ -240,7 +231,7 @@ track_state_change = threaded_listener_factory(async_track_state_change)
@bind_hass
def async_track_state_change_event(
hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]],
entity_ids: str | Iterable[str],
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track specific state change events indexed by entity_id.
@ -337,7 +328,7 @@ def _async_remove_indexed_listeners(
@bind_hass
def async_track_entity_registry_updated_event(
hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]],
entity_ids: str | Iterable[str],
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track specific entity registry updated events indexed by entity_id.
@ -402,7 +393,7 @@ def async_track_entity_registry_updated_event(
@callback
def _async_dispatch_domain_event(
hass: HomeAssistant, event: Event, callbacks: Dict[str, List]
hass: HomeAssistant, event: Event, callbacks: dict[str, list]
) -> None:
domain = split_entity_id(event.data["entity_id"])[0]
@ -423,7 +414,7 @@ def _async_dispatch_domain_event(
@bind_hass
def async_track_state_added_domain(
hass: HomeAssistant,
domains: Union[str, Iterable[str]],
domains: str | Iterable[str],
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track state change events when an entity is added to domains."""
@ -476,7 +467,7 @@ def async_track_state_added_domain(
@bind_hass
def async_track_state_removed_domain(
hass: HomeAssistant,
domains: Union[str, Iterable[str]],
domains: str | Iterable[str],
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track state change events when an entity is removed from domains."""
@ -527,7 +518,7 @@ def async_track_state_removed_domain(
@callback
def _async_string_to_lower_list(instr: Union[str, Iterable[str]]) -> List[str]:
def _async_string_to_lower_list(instr: str | Iterable[str]) -> list[str]:
if isinstance(instr, str):
return [instr.lower()]
@ -546,7 +537,7 @@ class _TrackStateChangeFiltered:
"""Handle removal / refresh of tracker init."""
self.hass = hass
self._action = action
self._listeners: Dict[str, Callable] = {}
self._listeners: dict[str, Callable] = {}
self._last_track_states: TrackStates = track_states
@callback
@ -569,7 +560,7 @@ class _TrackStateChangeFiltered:
self._setup_entities_listener(track_states.domains, track_states.entities)
@property
def listeners(self) -> Dict:
def listeners(self) -> dict:
"""State changes that will cause a re-render."""
track_states = self._last_track_states
return {
@ -628,7 +619,7 @@ class _TrackStateChangeFiltered:
self._listeners.pop(listener_name)()
@callback
def _setup_entities_listener(self, domains: Set, entities: Set) -> None:
def _setup_entities_listener(self, domains: set, entities: set) -> None:
if domains:
entities = entities.copy()
entities.update(self.hass.states.async_entity_ids(domains))
@ -642,7 +633,7 @@ class _TrackStateChangeFiltered:
)
@callback
def _setup_domains_listener(self, domains: Set) -> None:
def _setup_domains_listener(self, domains: set) -> None:
if not domains:
return
@ -691,8 +682,8 @@ def async_track_state_change_filtered(
def async_track_template(
hass: HomeAssistant,
template: Template,
action: Callable[[str, Optional[State], Optional[State]], None],
variables: Optional[TemplateVarsType] = None,
action: Callable[[str, State | None, State | None], None],
variables: TemplateVarsType | None = None,
) -> Callable[[], None]:
"""Add a listener that fires when a a template evaluates to 'true'.
@ -734,7 +725,7 @@ def async_track_template(
@callback
def _template_changed_listener(
event: Event, updates: List[TrackTemplateResult]
event: Event, updates: list[TrackTemplateResult]
) -> None:
"""Check if condition is correct and run action."""
track_result = updates.pop()
@ -792,12 +783,12 @@ class _TrackTemplateResultInfo:
track_template_.template.hass = hass
self._track_templates = track_templates
self._last_result: Dict[Template, Union[str, TemplateError]] = {}
self._last_result: dict[Template, str | TemplateError] = {}
self._rate_limit = KeyedRateLimit(hass)
self._info: Dict[Template, RenderInfo] = {}
self._track_state_changes: Optional[_TrackStateChangeFiltered] = None
self._time_listeners: Dict[Template, Callable] = {}
self._info: dict[Template, RenderInfo] = {}
self._track_state_changes: _TrackStateChangeFiltered | None = None
self._time_listeners: dict[Template, Callable] = {}
def async_setup(self, raise_on_template_error: bool) -> None:
"""Activation of template tracking."""
@ -826,7 +817,7 @@ class _TrackTemplateResultInfo:
)
@property
def listeners(self) -> Dict:
def listeners(self) -> dict:
"""State changes that will cause a re-render."""
assert self._track_state_changes
return {
@ -882,8 +873,8 @@ class _TrackTemplateResultInfo:
self,
track_template_: TrackTemplate,
now: datetime,
event: Optional[Event],
) -> Union[bool, TrackTemplateResult]:
event: Event | None,
) -> bool | TrackTemplateResult:
"""Re-render the template if conditions match.
Returns False if the template was not be re-rendered
@ -927,7 +918,7 @@ class _TrackTemplateResultInfo:
)
try:
result: Union[str, TemplateError] = info.result()
result: str | TemplateError = info.result()
except TemplateError as ex:
result = ex
@ -945,9 +936,9 @@ class _TrackTemplateResultInfo:
@callback
def _refresh(
self,
event: Optional[Event],
track_templates: Optional[Iterable[TrackTemplate]] = None,
replayed: Optional[bool] = False,
event: Event | None,
track_templates: Iterable[TrackTemplate] | None = None,
replayed: bool | None = False,
) -> None:
"""Refresh the template.
@ -1076,16 +1067,16 @@ def async_track_same_state(
hass: HomeAssistant,
period: timedelta,
action: Callable[..., None],
async_check_same_func: Callable[[str, Optional[State], Optional[State]], bool],
entity_ids: Union[str, Iterable[str]] = MATCH_ALL,
async_check_same_func: Callable[[str, State | None, State | None], bool],
entity_ids: str | Iterable[str] = MATCH_ALL,
) -> CALLBACK_TYPE:
"""Track the state of entities for a period and run an action.
If async_check_func is None it use the state of orig_value.
Without entity_ids we track all state changes.
"""
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
async_remove_state_for_cancel: CALLBACK_TYPE | None = None
async_remove_state_for_listener: CALLBACK_TYPE | None = None
job = HassJob(action)
@ -1113,8 +1104,8 @@ def async_track_same_state(
def state_for_cancel_listener(event: Event) -> None:
"""Fire on changes and cancel for listener if changed."""
entity: str = event.data["entity_id"]
from_state: Optional[State] = event.data.get("old_state")
to_state: Optional[State] = event.data.get("new_state")
from_state: State | None = event.data.get("old_state")
to_state: State | None = event.data.get("new_state")
if not async_check_same_func(entity, from_state, to_state):
clear_listener()
@ -1144,7 +1135,7 @@ track_same_state = threaded_listener_factory(async_track_same_state)
@bind_hass
def async_track_point_in_time(
hass: HomeAssistant,
action: Union[HassJob, Callable[..., None]],
action: HassJob | Callable[..., None],
point_in_time: datetime,
) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in time."""
@ -1165,7 +1156,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@bind_hass
def async_track_point_in_utc_time(
hass: HomeAssistant,
action: Union[HassJob, Callable[..., None]],
action: HassJob | Callable[..., None],
point_in_time: datetime,
) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in UTC time."""
@ -1176,7 +1167,7 @@ def async_track_point_in_utc_time(
# having to figure out how to call the action every time its called.
job = action if isinstance(action, HassJob) else HassJob(action)
cancel_callback: Optional[asyncio.TimerHandle] = None
cancel_callback: asyncio.TimerHandle | None = None
@callback
def run_action() -> None:
@ -1217,7 +1208,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim
@callback
@bind_hass
def async_call_later(
hass: HomeAssistant, delay: float, action: Union[HassJob, Callable[..., None]]
hass: HomeAssistant, delay: float, action: HassJob | Callable[..., None]
) -> CALLBACK_TYPE:
"""Add a listener that is called in <delay>."""
return async_track_point_in_utc_time(
@ -1232,7 +1223,7 @@ call_later = threaded_listener_factory(async_call_later)
@bind_hass
def async_track_time_interval(
hass: HomeAssistant,
action: Callable[..., Union[None, Awaitable]],
action: Callable[..., None | Awaitable],
interval: timedelta,
) -> CALLBACK_TYPE:
"""Add a listener that fires repetitively at every timedelta interval."""
@ -1276,9 +1267,9 @@ class SunListener:
hass: HomeAssistant = attr.ib()
job: HassJob = attr.ib()
event: str = attr.ib()
offset: Optional[timedelta] = attr.ib()
_unsub_sun: Optional[CALLBACK_TYPE] = attr.ib(default=None)
_unsub_config: Optional[CALLBACK_TYPE] = attr.ib(default=None)
offset: timedelta | None = attr.ib()
_unsub_sun: CALLBACK_TYPE | None = attr.ib(default=None)
_unsub_config: CALLBACK_TYPE | None = attr.ib(default=None)
@callback
def async_attach(self) -> None:
@ -1332,7 +1323,7 @@ class SunListener:
@callback
@bind_hass
def async_track_sunrise(
hass: HomeAssistant, action: Callable[..., None], offset: Optional[timedelta] = None
hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None
) -> CALLBACK_TYPE:
"""Add a listener that will fire a specified offset from sunrise daily."""
listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNRISE, offset)
@ -1346,7 +1337,7 @@ track_sunrise = threaded_listener_factory(async_track_sunrise)
@callback
@bind_hass
def async_track_sunset(
hass: HomeAssistant, action: Callable[..., None], offset: Optional[timedelta] = None
hass: HomeAssistant, action: Callable[..., None], offset: timedelta | None = None
) -> CALLBACK_TYPE:
"""Add a listener that will fire a specified offset from sunset daily."""
listener = SunListener(hass, HassJob(action), SUN_EVENT_SUNSET, offset)
@ -1365,9 +1356,9 @@ time_tracker_utcnow = dt_util.utcnow
def async_track_utc_time_change(
hass: HomeAssistant,
action: Callable[..., None],
hour: Optional[Any] = None,
minute: Optional[Any] = None,
second: Optional[Any] = None,
hour: Any | None = None,
minute: Any | None = None,
second: Any | None = None,
local: bool = False,
) -> CALLBACK_TYPE:
"""Add a listener that will fire if time matches a pattern."""
@ -1394,7 +1385,7 @@ def async_track_utc_time_change(
localized_now, matching_seconds, matching_minutes, matching_hours
)
time_listener: Optional[CALLBACK_TYPE] = None
time_listener: CALLBACK_TYPE | None = None
@callback
def pattern_time_change_listener(_: datetime) -> None:
@ -1431,9 +1422,9 @@ track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
def async_track_time_change(
hass: HomeAssistant,
action: Callable[..., None],
hour: Optional[Any] = None,
minute: Optional[Any] = None,
second: Optional[Any] = None,
hour: Any | None = None,
minute: Any | None = None,
second: Any | None = None,
) -> CALLBACK_TYPE:
"""Add a listener that will fire if UTC time matches a pattern."""
return async_track_utc_time_change(hass, action, hour, minute, second, local=True)
@ -1442,9 +1433,7 @@ def async_track_time_change(
track_time_change = threaded_listener_factory(async_track_time_change)
def process_state_match(
parameter: Union[None, str, Iterable[str]]
) -> Callable[[str], bool]:
def process_state_match(parameter: None | str | Iterable[str]) -> Callable[[str], bool]:
"""Convert parameter to function that matches input against parameter."""
if parameter is None or parameter == MATCH_ALL:
return lambda _: True
@ -1459,7 +1448,7 @@ def process_state_match(
@callback
def _entities_domains_from_render_infos(
render_infos: Iterable[RenderInfo],
) -> Tuple[Set, Set]:
) -> tuple[set, set]:
"""Combine from multiple RenderInfo."""
entities = set()
domains = set()
@ -1520,7 +1509,7 @@ def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
@callback
def _rate_limit_for_event(
event: Event, info: RenderInfo, track_template_: TrackTemplate
) -> Optional[timedelta]:
) -> timedelta | None:
"""Determine the rate limit for an event."""
entity_id = event.data.get(ATTR_ENTITY_ID)
@ -1532,7 +1521,7 @@ def _rate_limit_for_event(
if track_template_.rate_limit is not None:
return track_template_.rate_limit
rate_limit: Optional[timedelta] = info.rate_limit
rate_limit: timedelta | None = info.rate_limit
return rate_limit

View File

@ -1,9 +1,11 @@
"""Provide frame helper for finding the current frame context."""
from __future__ import annotations
import asyncio
import functools
import logging
from traceback import FrameSummary, extract_stack
from typing import Any, Callable, Optional, Tuple, TypeVar, cast
from typing import Any, Callable, TypeVar, cast
from homeassistant.exceptions import HomeAssistantError
@ -13,8 +15,8 @@ CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-na
def get_integration_frame(
exclude_integrations: Optional[set] = None,
) -> Tuple[FrameSummary, str, str]:
exclude_integrations: set | None = None,
) -> tuple[FrameSummary, str, str]:
"""Return the frame, integration and integration path of the current stack frame."""
found_frame = None
if not exclude_integrations:
@ -64,7 +66,7 @@ def report(what: str) -> None:
def report_integration(
what: str, integration_frame: Tuple[FrameSummary, str, str]
what: str, integration_frame: tuple[FrameSummary, str, str]
) -> None:
"""Report incorrect usage in an integration.

View File

@ -1,6 +1,8 @@
"""Helper for httpx."""
from __future__ import annotations
import sys
from typing import Any, Callable, Optional
from typing import Any, Callable
import httpx
@ -29,7 +31,7 @@ def get_async_client(
"""
key = DATA_ASYNC_CLIENT if verify_ssl else DATA_ASYNC_CLIENT_NOVERIFY
client: Optional[httpx.AsyncClient] = hass.data.get(key)
client: httpx.AsyncClient | None = hass.data.get(key)
if client is None:
client = hass.data[key] = create_async_httpx_client(hass, verify_ssl)

View File

@ -1,9 +1,9 @@
"""Icon helper methods."""
from typing import Optional
from __future__ import annotations
def icon_for_battery_level(
battery_level: Optional[int] = None, charging: bool = False
battery_level: int | None = None, charging: bool = False
) -> str:
"""Return a battery icon valid identifier."""
icon = "mdi:battery"
@ -20,7 +20,7 @@ def icon_for_battery_level(
return icon
def icon_for_signal_level(signal_level: Optional[int] = None) -> str:
def icon_for_signal_level(signal_level: int | None = None) -> str:
"""Return a signal icon valid identifier."""
if signal_level is None or signal_level == 0:
return "mdi:signal-cellular-outline"

View File

@ -1,5 +1,6 @@
"""Helper to create a unique instance ID."""
from typing import Dict, Optional
from __future__ import annotations
import uuid
from homeassistant.core import HomeAssistant
@ -17,7 +18,7 @@ async def async_get(hass: HomeAssistant) -> str:
"""Get unique ID for the hass instance."""
store = storage.Store(hass, DATA_VERSION, DATA_KEY, True)
data: Optional[Dict[str, str]] = await storage.async_migrator( # type: ignore
data: dict[str, str] | None = await storage.async_migrator( # type: ignore
hass,
hass.config.path(LEGACY_UUID_FILE),
store,

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import re
from typing import Any, Callable, Dict, Iterable, Optional
from typing import Any, Callable, Dict, Iterable
import voluptuous as vol
@ -52,9 +52,9 @@ async def async_handle(
hass: HomeAssistantType,
platform: str,
intent_type: str,
slots: Optional[_SlotsType] = None,
text_input: Optional[str] = None,
context: Optional[Context] = None,
slots: _SlotsType | None = None,
text_input: str | None = None,
context: Context | None = None,
) -> IntentResponse:
"""Handle an intent."""
handler: IntentHandler = hass.data.get(DATA_KEY, {}).get(intent_type)
@ -103,7 +103,7 @@ class IntentUnexpectedError(IntentError):
@callback
@bind_hass
def async_match_state(
hass: HomeAssistantType, name: str, states: Optional[Iterable[State]] = None
hass: HomeAssistantType, name: str, states: Iterable[State] | None = None
) -> State:
"""Find a state that matches the name."""
if states is None:
@ -127,10 +127,10 @@ def async_test_feature(state: State, feature: int, feature_name: str) -> None:
class IntentHandler:
"""Intent handler registration."""
intent_type: Optional[str] = None
slot_schema: Optional[vol.Schema] = None
_slot_schema: Optional[vol.Schema] = None
platforms: Optional[Iterable[str]] = []
intent_type: str | None = None
slot_schema: vol.Schema | None = None
_slot_schema: vol.Schema | None = None
platforms: Iterable[str] | None = []
@callback
def async_can_handle(self, intent_obj: Intent) -> bool:
@ -163,7 +163,7 @@ class IntentHandler:
return f"<{self.__class__.__name__} - {self.intent_type}>"
def _fuzzymatch(name: str, items: Iterable[T], key: Callable[[T], str]) -> Optional[T]:
def _fuzzymatch(name: str, items: Iterable[T], key: Callable[[T], str]) -> T | None:
"""Fuzzy matching function."""
matches = []
pattern = ".*?".join(name)
@ -226,7 +226,7 @@ class Intent:
platform: str,
intent_type: str,
slots: _SlotsType,
text_input: Optional[str],
text_input: str | None,
context: Context,
) -> None:
"""Initialize an intent."""
@ -246,15 +246,15 @@ class Intent:
class IntentResponse:
"""Response to an intent."""
def __init__(self, intent: Optional[Intent] = None) -> None:
def __init__(self, intent: Intent | None = None) -> None:
"""Initialize an IntentResponse."""
self.intent = intent
self.speech: Dict[str, Dict[str, Any]] = {}
self.card: Dict[str, Dict[str, str]] = {}
self.speech: dict[str, dict[str, Any]] = {}
self.card: dict[str, dict[str, str]] = {}
@callback
def async_set_speech(
self, speech: str, speech_type: str = "plain", extra_data: Optional[Any] = None
self, speech: str, speech_type: str = "plain", extra_data: Any | None = None
) -> None:
"""Set speech response."""
self.speech[speech_type] = {"speech": speech, "extra_data": extra_data}
@ -267,6 +267,6 @@ class IntentResponse:
self.card[card_type] = {"title": title, "content": content}
@callback
def as_dict(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
def as_dict(self) -> dict[str, dict[str, dict[str, Any]]]:
"""Return a dictionary representation of an intent response."""
return {"speech": self.speech, "card": self.card}

View File

@ -1,7 +1,8 @@
"""Location helpers for Home Assistant."""
from __future__ import annotations
import logging
from typing import Optional, Sequence
from typing import Sequence
import voluptuous as vol
@ -25,9 +26,7 @@ def has_location(state: State) -> bool:
)
def closest(
latitude: float, longitude: float, states: Sequence[State]
) -> Optional[State]:
def closest(latitude: float, longitude: float, states: Sequence[State]) -> State | None:
"""Return closest state to point.
Async friendly.
@ -50,8 +49,8 @@ def closest(
def find_coordinates(
hass: HomeAssistantType, entity_id: str, recursion_history: Optional[list] = None
) -> Optional[str]:
hass: HomeAssistantType, entity_id: str, recursion_history: list | None = None
) -> str | None:
"""Find the gps coordinates of the entity in the form of '90.000,180.000'."""
entity_state = hass.states.get(entity_id)

View File

@ -1,7 +1,9 @@
"""Helpers for logging allowing more advanced logging styles to be used."""
from __future__ import annotations
import inspect
import logging
from typing import Any, Mapping, MutableMapping, Optional, Tuple
from typing import Any, Mapping, MutableMapping
class KeywordMessage:
@ -26,7 +28,7 @@ class KeywordStyleAdapter(logging.LoggerAdapter):
"""Represents an adapter wrapping the logger allowing KeywordMessages."""
def __init__(
self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None
self, logger: logging.Logger, extra: Mapping[str, Any] | None = None
) -> None:
"""Initialize a new StyleAdapter for the provided logger."""
super().__init__(logger, extra or {})
@ -41,7 +43,7 @@ class KeywordStyleAdapter(logging.LoggerAdapter):
def process(
self, msg: Any, kwargs: MutableMapping[str, Any]
) -> Tuple[Any, MutableMapping[str, Any]]:
) -> tuple[Any, MutableMapping[str, Any]]:
"""Process the keyword args in preparation for logging."""
return (
msg,

View File

@ -1,6 +1,8 @@
"""Network helpers."""
from __future__ import annotations
from ipaddress import ip_address
from typing import Optional, cast
from typing import cast
import yarl
@ -117,7 +119,7 @@ def get_url(
raise NoURLAvailableError
def _get_request_host() -> Optional[str]:
def _get_request_host() -> str | None:
"""Get the host address of the current request."""
request = http.current_request.get()
if request is None:

View File

@ -1,8 +1,10 @@
"""Ratelimit helper."""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import logging
from typing import Any, Callable, Dict, Hashable, Optional
from typing import Any, Callable, Hashable
from homeassistant.core import HomeAssistant, callback
import homeassistant.util.dt as dt_util
@ -19,8 +21,8 @@ class KeyedRateLimit:
):
"""Initialize ratelimit tracker."""
self.hass = hass
self._last_triggered: Dict[Hashable, datetime] = {}
self._rate_limit_timers: Dict[Hashable, asyncio.TimerHandle] = {}
self._last_triggered: dict[Hashable, datetime] = {}
self._rate_limit_timers: dict[Hashable, asyncio.TimerHandle] = {}
@callback
def async_has_timer(self, key: Hashable) -> bool:
@ -30,7 +32,7 @@ class KeyedRateLimit:
return key in self._rate_limit_timers
@callback
def async_triggered(self, key: Hashable, now: Optional[datetime] = None) -> None:
def async_triggered(self, key: Hashable, now: datetime | None = None) -> None:
"""Call when the action we are tracking was triggered."""
self.async_cancel_timer(key)
self._last_triggered[key] = now or dt_util.utcnow()
@ -54,11 +56,11 @@ class KeyedRateLimit:
def async_schedule_action(
self,
key: Hashable,
rate_limit: Optional[timedelta],
rate_limit: timedelta | None,
now: datetime,
action: Callable,
*args: Any,
) -> Optional[datetime]:
) -> datetime | None:
"""Check rate limits and schedule an action if we hit the limit.
If the rate limit is hit:

View File

@ -1,8 +1,9 @@
"""Class to reload platforms."""
from __future__ import annotations
import asyncio
import logging
from typing import Dict, Iterable, List, Optional
from typing import Iterable
from homeassistant import config as conf_util
from homeassistant.const import SERVICE_RELOAD
@ -61,7 +62,7 @@ async def _resetup_platform(
if not conf:
return
root_config: Dict = {integration_platform: []}
root_config: dict = {integration_platform: []}
# Extract only the config for template, ignore the rest.
for p_type, p_config in config_per_platform(conf, integration_platform):
if p_type != integration_name:
@ -101,7 +102,7 @@ async def _async_setup_platform(
hass: HomeAssistantType,
integration_name: str,
integration_platform: str,
platform_configs: List[Dict],
platform_configs: list[dict],
) -> None:
"""Platform for the first time when new configuration is added."""
if integration_platform not in hass.data:
@ -119,7 +120,7 @@ async def _async_setup_platform(
async def _async_reconfig_platform(
platform: EntityPlatform, platform_configs: List[Dict]
platform: EntityPlatform, platform_configs: list[dict]
) -> None:
"""Reconfigure an already loaded platform."""
await platform.async_reset()
@ -129,7 +130,7 @@ async def _async_reconfig_platform(
async def async_integration_yaml_config(
hass: HomeAssistantType, integration_name: str
) -> Optional[ConfigType]:
) -> ConfigType | None:
"""Fetch the latest yaml configuration for an integration."""
integration = await async_get_integration(hass, integration_name)
@ -141,7 +142,7 @@ async def async_integration_yaml_config(
@callback
def async_get_platform_without_config_entry(
hass: HomeAssistantType, integration_name: str, integration_platform_name: str
) -> Optional[EntityPlatform]:
) -> EntityPlatform | None:
"""Find an existing platform that is not a config entry."""
for integration_platform in async_get_platforms(hass, integration_name):
if integration_platform.config_entry is not None:

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import logging
from typing import Any, Dict, List, Optional, Set, cast
from typing import Any, cast
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
@ -45,12 +45,12 @@ class StoredState:
self.state = state
self.last_seen = last_seen
def as_dict(self) -> Dict[str, Any]:
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the stored state."""
return {"state": self.state.as_dict(), "last_seen": self.last_seen}
@classmethod
def from_dict(cls, json_dict: Dict) -> StoredState:
def from_dict(cls, json_dict: dict) -> StoredState:
"""Initialize a stored state from a dict."""
last_seen = json_dict["last_seen"]
@ -106,11 +106,11 @@ class RestoreStateData:
self.store: Store = Store(
hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder
)
self.last_states: Dict[str, StoredState] = {}
self.entity_ids: Set[str] = set()
self.last_states: dict[str, StoredState] = {}
self.entity_ids: set[str] = set()
@callback
def async_get_stored_states(self) -> List[StoredState]:
def async_get_stored_states(self) -> list[StoredState]:
"""Get the set of states which should be stored.
This includes the states of all registered entities, as well as the
@ -249,7 +249,7 @@ class RestoreEntity(Entity):
)
data.async_restore_entity_removed(self.entity_id)
async def async_get_last_state(self) -> Optional[State]:
async def async_get_last_state(self) -> State | None:
"""Get the entity state from the previous run."""
if self.hass is None or self.entity_id is None:
# Return None if this entity isn't added to hass yet

View File

@ -1,4 +1,6 @@
"""Helpers to execute scripts."""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
@ -6,18 +8,7 @@ from functools import partial
import itertools
import logging
from types import MappingProxyType
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from typing import Any, Callable, Dict, Sequence, Union, cast
import async_timeout
import voluptuous as vol
@ -232,8 +223,8 @@ STATIC_VALIDATION_ACTION_TYPES = (
async def async_validate_actions_config(
hass: HomeAssistant, actions: List[ConfigType]
) -> List[ConfigType]:
hass: HomeAssistant, actions: list[ConfigType]
) -> list[ConfigType]:
"""Validate a list of actions."""
return await asyncio.gather(
*[async_validate_action_config(hass, action) for action in actions]
@ -300,8 +291,8 @@ class _ScriptRun:
self,
hass: HomeAssistant,
script: "Script",
variables: Dict[str, Any],
context: Optional[Context],
variables: dict[str, Any],
context: Context | None,
log_exceptions: bool,
) -> None:
self._hass = hass
@ -310,7 +301,7 @@ class _ScriptRun:
self._context = context
self._log_exceptions = log_exceptions
self._step = -1
self._action: Optional[Dict[str, Any]] = None
self._action: dict[str, Any] | None = None
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
@ -890,7 +881,7 @@ async def _async_stop_scripts_at_shutdown(hass, event):
_VarsType = Union[Dict[str, Any], MappingProxyType]
def _referenced_extract_ids(data: Dict[str, Any], key: str, found: Set[str]) -> None:
def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None:
"""Extract referenced IDs."""
if not data:
return
@ -913,20 +904,20 @@ class Script:
def __init__(
self,
hass: HomeAssistant,
sequence: Sequence[Dict[str, Any]],
sequence: Sequence[dict[str, Any]],
name: str,
domain: str,
*,
# Used in "Running <running_description>" log message
running_description: Optional[str] = None,
change_listener: Optional[Callable[..., Any]] = None,
running_description: str | None = None,
change_listener: Callable[..., Any] | None = None,
script_mode: str = DEFAULT_SCRIPT_MODE,
max_runs: int = DEFAULT_MAX,
max_exceeded: str = DEFAULT_MAX_EXCEEDED,
logger: Optional[logging.Logger] = None,
logger: logging.Logger | None = None,
log_exceptions: bool = True,
top_level: bool = True,
variables: Optional[ScriptVariables] = None,
variables: ScriptVariables | None = None,
) -> None:
"""Initialize the script."""
all_scripts = hass.data.get(DATA_SCRIPTS)
@ -959,25 +950,25 @@ class Script:
self._log_exceptions = log_exceptions
self.last_action = None
self.last_triggered: Optional[datetime] = None
self.last_triggered: datetime | None = None
self._runs: List[_ScriptRun] = []
self._runs: list[_ScriptRun] = []
self.max_runs = max_runs
self._max_exceeded = max_exceeded
if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock()
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
self._repeat_script: Dict[int, Script] = {}
self._choose_data: Dict[int, Dict[str, Any]] = {}
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, dict[str, Any]] = {}
self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None
self.variables = variables
self._variables_dynamic = template.is_complex(variables)
if self._variables_dynamic:
template.attach(hass, variables)
@property
def change_listener(self) -> Optional[Callable[..., Any]]:
def change_listener(self) -> Callable[..., Any] | None:
"""Return the change_listener."""
return self._change_listener
@ -991,13 +982,13 @@ class Script:
):
self._change_listener_job = HassJob(change_listener)
def _set_logger(self, logger: Optional[logging.Logger] = None) -> None:
def _set_logger(self, logger: logging.Logger | None = None) -> None:
if logger:
self._logger = logger
else:
self._logger = logging.getLogger(f"{__name__}.{slugify(self.name)}")
def update_logger(self, logger: Optional[logging.Logger] = None) -> None:
def update_logger(self, logger: logging.Logger | None = None) -> None:
"""Update logger."""
self._set_logger(logger)
for script in self._repeat_script.values():
@ -1038,7 +1029,7 @@ class Script:
if self._referenced_devices is not None:
return self._referenced_devices
referenced: Set[str] = set()
referenced: set[str] = set()
for step in self.sequence:
action = cv.determine_script_action(step)
@ -1067,7 +1058,7 @@ class Script:
if self._referenced_entities is not None:
return self._referenced_entities
referenced: Set[str] = set()
referenced: set[str] = set()
for step in self.sequence:
action = cv.determine_script_action(step)
@ -1091,7 +1082,7 @@ class Script:
return referenced
def run(
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
self, variables: _VarsType | None = None, context: Context | None = None
) -> None:
"""Run script."""
asyncio.run_coroutine_threadsafe(
@ -1100,9 +1091,9 @@ class Script:
async def async_run(
self,
run_variables: Optional[_VarsType] = None,
context: Optional[Context] = None,
started_action: Optional[Callable[..., Any]] = None,
run_variables: _VarsType | None = None,
context: Context | None = None,
started_action: Callable[..., Any] | None = None,
) -> None:
"""Run script."""
if context is None:

View File

@ -1,5 +1,7 @@
"""Script variables."""
from typing import Any, Dict, Mapping, Optional
from __future__ import annotations
from typing import Any, Mapping
from homeassistant.core import HomeAssistant, callback
@ -9,20 +11,20 @@ from . import template
class ScriptVariables:
"""Class to hold and render script variables."""
def __init__(self, variables: Dict[str, Any]):
def __init__(self, variables: dict[str, Any]):
"""Initialize script variables."""
self.variables = variables
self._has_template: Optional[bool] = None
self._has_template: bool | None = None
@callback
def async_render(
self,
hass: HomeAssistant,
run_variables: Optional[Mapping[str, Any]],
run_variables: Mapping[str, Any] | None,
*,
render_as_defaults: bool = True,
limited: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Render script variables.
The run variables are used to compute the static variables.

View File

@ -10,14 +10,9 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypedDict,
Union,
cast,
)
@ -79,8 +74,8 @@ class ServiceParams(TypedDict):
domain: str
service: str
service_data: Dict[str, Any]
target: Optional[Dict]
service_data: dict[str, Any]
target: dict | None
@dataclasses.dataclass
@ -88,17 +83,17 @@ class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: Set[str] = dataclasses.field(default_factory=set)
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: Set[str] = dataclasses.field(default_factory=set)
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: Set[str] = dataclasses.field(default_factory=set)
missing_areas: Set[str] = dataclasses.field(default_factory=set)
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: Set[str]) -> None:
def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items."""
parts = []
for label, items in (
@ -137,7 +132,7 @@ async def async_call_from_config(
blocking: bool = False,
variables: TemplateVarsType = None,
validate_config: bool = True,
context: Optional[ha.Context] = None,
context: ha.Context | None = None,
) -> None:
"""Call a service based on a config hash."""
try:
@ -235,7 +230,7 @@ def async_prepare_call_from_config(
@bind_hass
def extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]:
) -> set[str]:
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
@ -251,7 +246,7 @@ async def async_extract_entities(
entities: Iterable[Entity],
service_call: ha.ServiceCall,
expand_group: bool = True,
) -> List[Entity]:
) -> list[Entity]:
"""Extract a list of entity objects from a service call.
Will convert group entity ids to the entity ids it represents.
@ -287,7 +282,7 @@ async def async_extract_entities(
@bind_hass
async def async_extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]:
) -> set[str]:
"""Extract a set of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
@ -408,7 +403,7 @@ def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JS
def _load_services_files(
hass: HomeAssistantType, integrations: Iterable[Integration]
) -> List[JSON_TYPE]:
) -> list[JSON_TYPE]:
"""Load service files for multiple intergrations."""
return [_load_services_file(hass, integration) for integration in integrations]
@ -416,7 +411,7 @@ def _load_services_files(
@bind_hass
async def async_get_all_descriptions(
hass: HomeAssistantType,
) -> Dict[str, Dict[str, Any]]:
) -> dict[str, dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
format_cache_key = "{}.{}".format
@ -448,7 +443,7 @@ async def async_get_all_descriptions(
loaded[domain] = content
# Build response
descriptions: Dict[str, Dict[str, Any]] = {}
descriptions: dict[str, dict[str, Any]] = {}
for domain in services:
descriptions[domain] = {}
@ -483,7 +478,7 @@ async def async_get_all_descriptions(
@ha.callback
@bind_hass
def async_set_service_schema(
hass: HomeAssistantType, domain: str, service: str, schema: Dict[str, Any]
hass: HomeAssistantType, domain: str, service: str, schema: dict[str, Any]
) -> None:
"""Register a description for a service."""
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
@ -504,9 +499,9 @@ def async_set_service_schema(
async def entity_service_call(
hass: HomeAssistantType,
platforms: Iterable["EntityPlatform"],
func: Union[str, Callable[..., Any]],
func: str | Callable[..., Any],
call: ha.ServiceCall,
required_features: Optional[Iterable[int]] = None,
required_features: Iterable[int] | None = None,
) -> None:
"""Handle an entity service call.
@ -516,17 +511,17 @@ async def entity_service_call(
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
entity_perms: Optional[
entity_perms: None | (
Callable[[str, str], bool]
] = user.permissions.check_entity
) = user.permissions.check_entity
else:
entity_perms = None
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
if target_all_entities:
referenced: Optional[SelectedEntities] = None
all_referenced: Optional[Set[str]] = None
referenced: SelectedEntities | None = None
all_referenced: set[str] | None = None
else:
# A set of entities we're trying to target.
referenced = await async_extract_referenced_entity_ids(hass, call, True)
@ -534,7 +529,7 @@ async def entity_service_call(
# If the service function is a string, we'll pass it the service call data
if isinstance(func, str):
data: Union[Dict, ha.ServiceCall] = {
data: dict | ha.ServiceCall = {
key: val
for key, val in call.data.items()
if key not in cv.ENTITY_SERVICE_FIELDS
@ -546,7 +541,7 @@ async def entity_service_call(
# Check the permissions
# A list with entities to call the service on.
entity_candidates: List["Entity"] = []
entity_candidates: list["Entity"] = []
if entity_perms is None:
for platform in platforms:
@ -662,8 +657,8 @@ async def entity_service_call(
async def _handle_entity_call(
hass: HomeAssistantType,
entity: Entity,
func: Union[str, Callable[..., Any]],
data: Union[Dict, ha.ServiceCall],
func: str | Callable[..., Any],
data: dict | ha.ServiceCall,
context: ha.Context,
) -> None:
"""Handle calling service method."""
@ -693,7 +688,7 @@ def async_register_admin_service(
hass: HomeAssistantType,
domain: str,
service: str,
service_func: Callable[[ha.ServiceCall], Optional[Awaitable]],
service_func: Callable[[ha.ServiceCall], Awaitable | None],
schema: vol.Schema = vol.Schema({}, extra=vol.PREVENT_EXTRA),
) -> None:
"""Register a service that requires admin access."""

View File

@ -29,7 +29,7 @@ The following cases will never be passed to your function:
from __future__ import annotations
from types import MappingProxyType
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State, callback
@ -66,7 +66,7 @@ ExtraCheckTypeFunc = Callable[
async def create_checker(
hass: HomeAssistant,
_domain: str,
extra_significant_check: Optional[ExtraCheckTypeFunc] = None,
extra_significant_check: ExtraCheckTypeFunc | None = None,
) -> SignificantlyChangedChecker:
"""Create a significantly changed checker for a domain."""
await _initialize(hass)
@ -90,15 +90,15 @@ async def _initialize(hass: HomeAssistant) -> None:
await async_process_integration_platforms(hass, PLATFORM, process_platform)
def either_one_none(val1: Optional[Any], val2: Optional[Any]) -> bool:
def either_one_none(val1: Any | None, val2: Any | None) -> bool:
"""Test if exactly one value is None."""
return (val1 is None and val2 is not None) or (val1 is not None and val2 is None)
def check_numeric_changed(
val1: Optional[Union[int, float]],
val2: Optional[Union[int, float]],
change: Union[int, float],
val1: int | float | None,
val2: int | float | None,
change: int | float,
) -> bool:
"""Check if two numeric values have changed."""
if val1 is None and val2 is None:
@ -125,22 +125,22 @@ class SignificantlyChangedChecker:
def __init__(
self,
hass: HomeAssistant,
extra_significant_check: Optional[ExtraCheckTypeFunc] = None,
extra_significant_check: ExtraCheckTypeFunc | None = None,
) -> None:
"""Test if an entity has significantly changed."""
self.hass = hass
self.last_approved_entities: Dict[str, Tuple[State, Any]] = {}
self.last_approved_entities: dict[str, tuple[State, Any]] = {}
self.extra_significant_check = extra_significant_check
@callback
def async_is_significant_change(
self, new_state: State, *, extra_arg: Optional[Any] = None
self, new_state: State, *, extra_arg: Any | None = None
) -> bool:
"""Return if this was a significant change.
Extra kwargs are passed to the extra significant checker.
"""
old_data: Optional[Tuple[State, Any]] = self.last_approved_entities.get(
old_data: tuple[State, Any] | None = self.last_approved_entities.get(
new_state.entity_id
)
@ -164,9 +164,7 @@ class SignificantlyChangedChecker:
self.last_approved_entities[new_state.entity_id] = (new_state, extra_arg)
return True
functions: Optional[Dict[str, CheckTypeFunc]] = self.hass.data.get(
DATA_FUNCTIONS
)
functions: dict[str, CheckTypeFunc] | None = self.hass.data.get(DATA_FUNCTIONS)
if functions is None:
raise RuntimeError("Significant Change not initialized")

View File

@ -1,7 +1,9 @@
"""Helper to help coordinating calls."""
from __future__ import annotations
import asyncio
import functools
from typing import Callable, Optional, TypeVar, cast
from typing import Callable, TypeVar, cast
from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
@ -24,7 +26,7 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
@bind_hass
@functools.wraps(func)
def wrapped(hass: HomeAssistant) -> T:
obj: Optional[T] = hass.data.get(data_key)
obj: T | None = hass.data.get(data_key)
if obj is None:
obj = hass.data[data_key] = func(hass)
return obj

View File

@ -1,10 +1,12 @@
"""Helpers that help with state related things."""
from __future__ import annotations
import asyncio
from collections import defaultdict
import datetime as dt
import logging
from types import ModuleType, TracebackType
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Iterable
from homeassistant.components.sun import STATE_ABOVE_HORIZON, STATE_BELOW_HORIZON
from homeassistant.const import (
@ -44,19 +46,19 @@ class AsyncTrackStates:
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize a TrackStates block."""
self.hass = hass
self.states: List[State] = []
self.states: list[State] = []
# pylint: disable=attribute-defined-outside-init
def __enter__(self) -> List[State]:
def __enter__(self) -> list[State]:
"""Record time from which to track changes."""
self.now = dt_util.utcnow()
return self.states
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Add changes states to changes list."""
self.states.extend(get_changed_since(self.hass.states.async_all(), self.now))
@ -64,7 +66,7 @@ class AsyncTrackStates:
def get_changed_since(
states: Iterable[State], utc_point_in_time: dt.datetime
) -> List[State]:
) -> list[State]:
"""Return list of states that have been changed since utc_point_in_time.
Deprecated. Remove after June 2021.
@ -76,21 +78,21 @@ def get_changed_since(
@bind_hass
async def async_reproduce_state(
hass: HomeAssistantType,
states: Union[State, Iterable[State]],
states: State | Iterable[State],
*,
context: Optional[Context] = None,
reproduce_options: Optional[Dict[str, Any]] = None,
context: Context | None = None,
reproduce_options: dict[str, Any] | None = None,
) -> None:
"""Reproduce a list of states on multiple domains."""
if isinstance(states, State):
states = [states]
to_call: Dict[str, List[State]] = defaultdict(list)
to_call: dict[str, list[State]] = defaultdict(list)
for state in states:
to_call[state.domain].append(state)
async def worker(domain: str, states_by_domain: List[State]) -> None:
async def worker(domain: str, states_by_domain: list[State]) -> None:
try:
integration = await async_get_integration(hass, domain)
except IntegrationNotFound:
@ -100,7 +102,7 @@ async def async_reproduce_state(
return
try:
platform: Optional[ModuleType] = integration.get_platform("reproduce_state")
platform: ModuleType | None = integration.get_platform("reproduce_state")
except ImportError:
_LOGGER.warning("Integration %s does not support reproduce state", domain)
return

View File

@ -1,9 +1,11 @@
"""Helper to help store data."""
from __future__ import annotations
import asyncio
from json import JSONEncoder
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback
@ -71,18 +73,18 @@ class Store:
key: str,
private: bool = False,
*,
encoder: Optional[Type[JSONEncoder]] = None,
encoder: type[JSONEncoder] | None = None,
):
"""Initialize storage class."""
self.version = version
self.key = key
self.hass = hass
self._private = private
self._data: Optional[Dict[str, Any]] = None
self._unsub_delay_listener: Optional[CALLBACK_TYPE] = None
self._unsub_final_write_listener: Optional[CALLBACK_TYPE] = None
self._data: dict[str, Any] | None = None
self._unsub_delay_listener: CALLBACK_TYPE | None = None
self._unsub_final_write_listener: CALLBACK_TYPE | None = None
self._write_lock = asyncio.Lock()
self._load_task: Optional[asyncio.Future] = None
self._load_task: asyncio.Future | None = None
self._encoder = encoder
@property
@ -90,7 +92,7 @@ class Store:
"""Return the config path."""
return self.hass.config.path(STORAGE_DIR, self.key)
async def async_load(self) -> Union[Dict, List, None]:
async def async_load(self) -> dict | list | None:
"""Load data.
If the expected version does not match the given version, the migrate
@ -140,7 +142,7 @@ class Store:
return stored
async def async_save(self, data: Union[Dict, List]) -> None:
async def async_save(self, data: dict | list) -> None:
"""Save data."""
self._data = {"version": self.version, "key": self.key, "data": data}
@ -151,7 +153,7 @@ class Store:
await self._async_handle_write_data()
@callback
def async_delay_save(self, data_func: Callable[[], Dict], delay: float = 0) -> None:
def async_delay_save(self, data_func: Callable[[], dict], delay: float = 0) -> None:
"""Save data with an optional delay."""
self._data = {"version": self.version, "key": self.key, "data_func": data_func}
@ -224,7 +226,7 @@ class Store:
except (json_util.SerializationError, json_util.WriteError) as err:
_LOGGER.error("Error writing config for %s: %s", self.key, err)
def _write_data(self, path: str, data: Dict) -> None:
def _write_data(self, path: str, data: dict) -> None:
"""Write the data."""
if not os.path.isdir(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING
from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
from homeassistant.core import callback
@ -44,8 +44,8 @@ def get_astral_location(hass: HomeAssistantType) -> astral.Location:
def get_astral_event_next(
hass: HomeAssistantType,
event: str,
utc_point_in_time: Optional[datetime.datetime] = None,
offset: Optional[datetime.timedelta] = None,
utc_point_in_time: datetime.datetime | None = None,
offset: datetime.timedelta | None = None,
) -> datetime.datetime:
"""Calculate the next specified solar event."""
location = get_astral_location(hass)
@ -56,8 +56,8 @@ def get_astral_event_next(
def get_location_astral_event_next(
location: "astral.Location",
event: str,
utc_point_in_time: Optional[datetime.datetime] = None,
offset: Optional[datetime.timedelta] = None,
utc_point_in_time: datetime.datetime | None = None,
offset: datetime.timedelta | None = None,
) -> datetime.datetime:
"""Calculate the next specified solar event."""
from astral import AstralError # pylint: disable=import-outside-toplevel
@ -91,8 +91,8 @@ def get_location_astral_event_next(
def get_astral_event_date(
hass: HomeAssistantType,
event: str,
date: Union[datetime.date, datetime.datetime, None] = None,
) -> Optional[datetime.datetime]:
date: datetime.date | datetime.datetime | None = None,
) -> datetime.datetime | None:
"""Calculate the astral event time for the specified date."""
from astral import AstralError # pylint: disable=import-outside-toplevel
@ -114,7 +114,7 @@ def get_astral_event_date(
@callback
@bind_hass
def is_up(
hass: HomeAssistantType, utc_point_in_time: Optional[datetime.datetime] = None
hass: HomeAssistantType, utc_point_in_time: datetime.datetime | None = None
) -> bool:
"""Calculate if the sun is currently up."""
if utc_point_in_time is None:

View File

@ -1,7 +1,9 @@
"""Helper to gather system info."""
from __future__ import annotations
import os
import platform
from typing import Any, Dict
from typing import Any
from homeassistant.const import __version__ as current_version
from homeassistant.loader import bind_hass
@ -11,7 +13,7 @@ from .typing import HomeAssistantType
@bind_hass
async def async_get_system_info(hass: HomeAssistantType) -> Dict[str, Any]:
async def async_get_system_info(hass: HomeAssistantType) -> dict[str, Any]:
"""Return info about the system."""
info_object = {
"installation_type": "Unknown",

View File

@ -1,6 +1,7 @@
"""Temperature helpers for Home Assistant."""
from __future__ import annotations
from numbers import Number
from typing import Optional
from homeassistant.const import PRECISION_HALVES, PRECISION_TENTHS
from homeassistant.core import HomeAssistant
@ -8,8 +9,8 @@ from homeassistant.util.temperature import convert as convert_temperature
def display_temp(
hass: HomeAssistant, temperature: Optional[float], unit: str, precision: float
) -> Optional[float]:
hass: HomeAssistant, temperature: float | None, unit: str, precision: float
) -> float | None:
"""Convert temperature into preferred units/precision for display."""
temperature_unit = unit
ha_unit = hass.config.units.temperature_unit

View File

@ -13,7 +13,7 @@ import math
from operator import attrgetter
import random
import re
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union, cast
from typing import Any, Generator, Iterable, cast
from urllib.parse import urlencode as urllib_urlencode
import weakref
@ -125,7 +125,7 @@ def is_template_string(maybe_template: str) -> bool:
class ResultWrapper:
"""Result wrapper class to store render result."""
render_result: Optional[str]
render_result: str | None
def gen_result_wrapper(kls):
@ -134,7 +134,7 @@ def gen_result_wrapper(kls):
class Wrapper(kls, ResultWrapper):
"""Wrapper of a kls that can store render_result."""
def __init__(self, *args: tuple, render_result: Optional[str] = None) -> None:
def __init__(self, *args: tuple, render_result: str | None = None) -> None:
super().__init__(*args)
self.render_result = render_result
@ -156,15 +156,13 @@ class TupleWrapper(tuple, ResultWrapper):
# This is all magic to be allowed to subclass a tuple.
def __new__(
cls, value: tuple, *, render_result: Optional[str] = None
) -> TupleWrapper:
def __new__(cls, value: tuple, *, render_result: str | None = None) -> TupleWrapper:
"""Create a new tuple class."""
return super().__new__(cls, tuple(value))
# pylint: disable=super-init-not-called
def __init__(self, value: tuple, *, render_result: Optional[str] = None):
def __init__(self, value: tuple, *, render_result: str | None = None):
"""Initialize a new tuple class."""
self.render_result = render_result
@ -176,7 +174,7 @@ class TupleWrapper(tuple, ResultWrapper):
return self.render_result
RESULT_WRAPPERS: Dict[Type, Type] = {
RESULT_WRAPPERS: dict[type, type] = {
kls: gen_result_wrapper(kls) # type: ignore[no-untyped-call]
for kls in (list, dict, set)
}
@ -200,15 +198,15 @@ class RenderInfo:
# Will be set sensibly once frozen.
self.filter_lifecycle = _true
self.filter = _true
self._result: Optional[str] = None
self._result: str | None = None
self.is_static = False
self.exception: Optional[TemplateError] = None
self.exception: TemplateError | None = None
self.all_states = False
self.all_states_lifecycle = False
self.domains = set()
self.domains_lifecycle = set()
self.entities = set()
self.rate_limit: Optional[timedelta] = None
self.rate_limit: timedelta | None = None
self.has_time = False
def __repr__(self) -> str:
@ -294,7 +292,7 @@ class Template:
self.template: str = template.strip()
self._compiled_code = None
self._compiled: Optional[Template] = None
self._compiled: Template | None = None
self.hass = hass
self.is_static = not is_template_string(template)
self._limited = None
@ -304,7 +302,7 @@ class Template:
if self.hass is None:
return _NO_HASS_ENV
wanted_env = _ENVIRONMENT_LIMITED if self._limited else _ENVIRONMENT
ret: Optional[TemplateEnvironment] = self.hass.data.get(wanted_env)
ret: TemplateEnvironment | None = self.hass.data.get(wanted_env)
if ret is None:
ret = self.hass.data[wanted_env] = TemplateEnvironment(self.hass, self._limited) # type: ignore[no-untyped-call]
return ret
@ -776,7 +774,7 @@ def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
entity_collect.entities.add(entity_id)
def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generator:
def _state_generator(hass: HomeAssistantType, domain: str | None) -> Generator:
"""State generator for a domain or all states."""
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")):
yield TemplateState(hass, state, collect=False)
@ -784,20 +782,20 @@ def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generato
def _get_state_if_valid(
hass: HomeAssistantType, entity_id: str
) -> Optional[TemplateState]:
) -> TemplateState | None:
state = hass.states.get(entity_id)
if state is None and not valid_entity_id(entity_id):
raise TemplateError(f"Invalid entity ID '{entity_id}'") # type: ignore
return _get_template_state_from_state(hass, entity_id, state)
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:
def _get_state(hass: HomeAssistantType, entity_id: str) -> TemplateState | None:
return _get_template_state_from_state(hass, entity_id, hass.states.get(entity_id))
def _get_template_state_from_state(
hass: HomeAssistantType, entity_id: str, state: Optional[State]
) -> Optional[TemplateState]:
hass: HomeAssistantType, entity_id: str, state: State | None
) -> TemplateState | None:
if state is None:
# Only need to collect if none, if not none collect first actual
# access to the state properties in the state wrapper.
@ -808,7 +806,7 @@ def _get_template_state_from_state(
def _resolve_state(
hass: HomeAssistantType, entity_id_or_state: Any
) -> Union[State, TemplateState, None]:
) -> State | TemplateState | None:
"""Return state or entity_id if given."""
if isinstance(entity_id_or_state, State):
return entity_id_or_state
@ -817,7 +815,7 @@ def _resolve_state(
return None
def result_as_boolean(template_result: Optional[str]) -> bool:
def result_as_boolean(template_result: str | None) -> bool:
"""Convert the template result to a boolean.
True/not 0/'1'/'true'/'yes'/'on'/'enable' are considered truthy

View File

@ -1,8 +1,10 @@
"""Helpers for script and condition tracing."""
from __future__ import annotations
from collections import deque
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Deque, Dict, Generator, List, Optional, Tuple, Union, cast
from typing import Any, Deque, Generator, cast
from homeassistant.helpers.typing import TemplateVarsType
import homeassistant.util.dt as dt_util
@ -13,9 +15,9 @@ class TraceElement:
def __init__(self, variables: TemplateVarsType, path: str):
"""Container for trace data."""
self._error: Optional[Exception] = None
self._error: Exception | None = None
self.path: str = path
self._result: Optional[dict] = None
self._result: dict | None = None
self._timestamp = dt_util.utcnow()
if variables is None:
@ -41,9 +43,9 @@ class TraceElement:
"""Set result."""
self._result = {**kwargs}
def as_dict(self) -> Dict[str, Any]:
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this TraceElement."""
result: Dict[str, Any] = {"path": self.path, "timestamp": self._timestamp}
result: dict[str, Any] = {"path": self.path, "timestamp": self._timestamp}
if self._variables:
result["changed_variables"] = self._variables
if self._error is not None:
@ -55,31 +57,31 @@ class TraceElement:
# Context variables for tracing
# Current trace
trace_cv: ContextVar[Optional[Dict[str, Deque[TraceElement]]]] = ContextVar(
trace_cv: ContextVar[dict[str, Deque[TraceElement]] | None] = ContextVar(
"trace_cv", default=None
)
# Stack of TraceElements
trace_stack_cv: ContextVar[Optional[List[TraceElement]]] = ContextVar(
trace_stack_cv: ContextVar[list[TraceElement] | None] = ContextVar(
"trace_stack_cv", default=None
)
# Current location in config tree
trace_path_stack_cv: ContextVar[Optional[List[str]]] = ContextVar(
trace_path_stack_cv: ContextVar[list[str] | None] = ContextVar(
"trace_path_stack_cv", default=None
)
# Copy of last variables
variables_cv: ContextVar[Optional[Any]] = ContextVar("variables_cv", default=None)
variables_cv: ContextVar[Any | None] = ContextVar("variables_cv", default=None)
# Automation ID + Run ID
trace_id_cv: ContextVar[Optional[Tuple[str, str]]] = ContextVar(
trace_id_cv: ContextVar[tuple[str, str] | None] = ContextVar(
"trace_id_cv", default=None
)
def trace_id_set(trace_id: Tuple[str, str]) -> None:
def trace_id_set(trace_id: tuple[str, str]) -> None:
"""Set id of the current trace."""
trace_id_cv.set(trace_id)
def trace_id_get() -> Optional[Tuple[str, str]]:
def trace_id_get() -> tuple[str, str] | None:
"""Get id if the current trace."""
return trace_id_cv.get()
@ -99,13 +101,13 @@ def trace_stack_pop(trace_stack_var: ContextVar) -> None:
trace_stack.pop()
def trace_stack_top(trace_stack_var: ContextVar) -> Optional[Any]:
def trace_stack_top(trace_stack_var: ContextVar) -> Any | None:
"""Return the element at the top of a trace stack."""
trace_stack = trace_stack_var.get()
return trace_stack[-1] if trace_stack else None
def trace_path_push(suffix: Union[str, List[str]]) -> int:
def trace_path_push(suffix: str | list[str]) -> int:
"""Go deeper in the config tree."""
if isinstance(suffix, str):
suffix = [suffix]
@ -130,7 +132,7 @@ def trace_path_get() -> str:
def trace_append_element(
trace_element: TraceElement,
maxlen: Optional[int] = None,
maxlen: int | None = None,
) -> None:
"""Append a TraceElement to trace[path]."""
path = trace_element.path
@ -143,7 +145,7 @@ def trace_append_element(
trace[path].append(trace_element)
def trace_get(clear: bool = True) -> Optional[Dict[str, Deque[TraceElement]]]:
def trace_get(clear: bool = True) -> dict[str, Deque[TraceElement]] | None:
"""Return the current trace."""
if clear:
trace_clear()
@ -165,7 +167,7 @@ def trace_set_result(**kwargs: Any) -> None:
@contextmanager
def trace_path(suffix: Union[str, List[str]]) -> Generator:
def trace_path(suffix: str | list[str]) -> Generator:
"""Go deeper in the config tree."""
count = trace_path_push(suffix)
try:

View File

@ -1,8 +1,10 @@
"""Translation string lookup helpers."""
from __future__ import annotations
import asyncio
from collections import ChainMap
import logging
from typing import Any, Dict, List, Optional, Set
from typing import Any
from homeassistant.core import callback
from homeassistant.loader import (
@ -24,7 +26,7 @@ TRANSLATION_FLATTEN_CACHE = "translation_flatten_cache"
LOCALE_EN = "en"
def recursive_flatten(prefix: Any, data: Dict) -> Dict[str, Any]:
def recursive_flatten(prefix: Any, data: dict) -> dict[str, Any]:
"""Return a flattened representation of dict data."""
output = {}
for key, value in data.items():
@ -38,7 +40,7 @@ def recursive_flatten(prefix: Any, data: Dict) -> Dict[str, Any]:
@callback
def component_translation_path(
component: str, language: str, integration: Integration
) -> Optional[str]:
) -> str | None:
"""Return the translation json file location for a component.
For component:
@ -69,8 +71,8 @@ def component_translation_path(
def load_translations_files(
translation_files: Dict[str, str]
) -> Dict[str, Dict[str, Any]]:
translation_files: dict[str, str]
) -> dict[str, dict[str, Any]]:
"""Load and parse translation.json files."""
loaded = {}
for component, translation_file in translation_files.items():
@ -90,13 +92,13 @@ def load_translations_files(
def _merge_resources(
translation_strings: Dict[str, Dict[str, Any]],
components: Set[str],
translation_strings: dict[str, dict[str, Any]],
components: set[str],
category: str,
) -> Dict[str, Dict[str, Any]]:
) -> dict[str, dict[str, Any]]:
"""Build and merge the resources response for the given components and platforms."""
# Build response
resources: Dict[str, Dict[str, Any]] = {}
resources: dict[str, dict[str, Any]] = {}
for component in components:
if "." not in component:
domain = component
@ -131,10 +133,10 @@ def _merge_resources(
def _build_resources(
translation_strings: Dict[str, Dict[str, Any]],
components: Set[str],
translation_strings: dict[str, dict[str, Any]],
components: set[str],
category: str,
) -> Dict[str, Dict[str, Any]]:
) -> dict[str, dict[str, Any]]:
"""Build the resources response for the given components."""
# Build response
return {
@ -146,8 +148,8 @@ def _build_resources(
async def async_get_component_strings(
hass: HomeAssistantType, language: str, components: Set[str]
) -> Dict[str, Any]:
hass: HomeAssistantType, language: str, components: set[str]
) -> dict[str, Any]:
"""Load translations."""
domains = list({loaded.split(".")[-1] for loaded in components})
integrations = dict(
@ -160,7 +162,7 @@ async def async_get_component_strings(
)
)
translations: Dict[str, Any] = {}
translations: dict[str, Any] = {}
# Determine paths of missing components/platforms
files_to_load = {}
@ -205,15 +207,15 @@ class _TranslationCache:
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the cache."""
self.hass = hass
self.loaded: Dict[str, Set[str]] = {}
self.cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
self.loaded: dict[str, set[str]] = {}
self.cache: dict[str, dict[str, dict[str, Any]]] = {}
async def async_fetch(
self,
language: str,
category: str,
components: Set,
) -> List[Dict[str, Dict[str, Any]]]:
components: set,
) -> list[dict[str, dict[str, Any]]]:
"""Load resources into the cache."""
components_to_load = components - self.loaded.setdefault(language, set())
@ -224,7 +226,7 @@ class _TranslationCache:
return [cached.get(component, {}).get(category, {}) for component in components]
async def _async_load(self, language: str, components: Set) -> None:
async def _async_load(self, language: str, components: set) -> None:
"""Populate the cache for a given set of components."""
_LOGGER.debug(
"Cache miss for %s: %s",
@ -247,12 +249,12 @@ class _TranslationCache:
def _build_category_cache(
self,
language: str,
components: Set,
translation_strings: Dict[str, Dict[str, Any]],
components: set,
translation_strings: dict[str, dict[str, Any]],
) -> None:
"""Extract resources into the cache."""
cached = self.cache.setdefault(language, {})
categories: Set[str] = set()
categories: set[str] = set()
for resource in translation_strings.values():
categories.update(resource)
@ -263,7 +265,7 @@ class _TranslationCache:
new_resources = resource_func(translation_strings, components, category)
for component, resource in new_resources.items():
category_cache: Dict[str, Any] = cached.setdefault(
category_cache: dict[str, Any] = cached.setdefault(
component, {}
).setdefault(category, {})
@ -283,9 +285,9 @@ async def async_get_translations(
hass: HomeAssistantType,
language: str,
category: str,
integration: Optional[str] = None,
config_flow: Optional[bool] = None,
) -> Dict[str, Any]:
integration: str | None = None,
config_flow: bool | None = None,
) -> dict[str, Any]:
"""Return all backend translations.
If integration specified, load it for that one.

View File

@ -1,8 +1,10 @@
"""Triggers."""
from __future__ import annotations
import asyncio
import logging
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable
import voluptuous as vol
@ -39,8 +41,8 @@ async def _async_get_trigger_platform(
async def async_validate_trigger_config(
hass: HomeAssistantType, trigger_config: List[ConfigType]
) -> List[ConfigType]:
hass: HomeAssistantType, trigger_config: list[ConfigType]
) -> list[ConfigType]:
"""Validate triggers."""
config = []
for conf in trigger_config:
@ -55,14 +57,14 @@ async def async_validate_trigger_config(
async def async_initialize_triggers(
hass: HomeAssistantType,
trigger_config: List[ConfigType],
trigger_config: list[ConfigType],
action: Callable,
domain: str,
name: str,
log_cb: Callable,
home_assistant_start: bool = False,
variables: Optional[Union[Dict[str, Any], MappingProxyType]] = None,
) -> Optional[CALLBACK_TYPE]:
variables: dict[str, Any] | MappingProxyType | None = None,
) -> CALLBACK_TYPE | None:
"""Initialize triggers."""
info = {
"domain": domain,

View File

@ -1,9 +1,11 @@
"""Helpers to help coordinate updates."""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import logging
from time import monotonic
from typing import Any, Awaitable, Callable, Generic, List, Optional, TypeVar
from typing import Any, Awaitable, Callable, Generic, TypeVar
import urllib.error
import aiohttp
@ -37,9 +39,9 @@ class DataUpdateCoordinator(Generic[T]):
logger: logging.Logger,
*,
name: str,
update_interval: Optional[timedelta] = None,
update_method: Optional[Callable[[], Awaitable[T]]] = None,
request_refresh_debouncer: Optional[Debouncer] = None,
update_interval: timedelta | None = None,
update_method: Callable[[], Awaitable[T]] | None = None,
request_refresh_debouncer: Debouncer | None = None,
):
"""Initialize global data updater."""
self.hass = hass
@ -48,12 +50,12 @@ class DataUpdateCoordinator(Generic[T]):
self.update_method = update_method
self.update_interval = update_interval
self.data: Optional[T] = None
self.data: T | None = None
self._listeners: List[CALLBACK_TYPE] = []
self._listeners: list[CALLBACK_TYPE] = []
self._job = HassJob(self._handle_refresh_interval)
self._unsub_refresh: Optional[CALLBACK_TYPE] = None
self._request_refresh_task: Optional[asyncio.TimerHandle] = None
self._unsub_refresh: CALLBACK_TYPE | None = None
self._request_refresh_task: asyncio.TimerHandle | None = None
self.last_update_success = True
if request_refresh_debouncer is None:
@ -132,7 +134,7 @@ class DataUpdateCoordinator(Generic[T]):
"""
await self._debounced_refresh.async_call()
async def _async_update_data(self) -> Optional[T]:
async def _async_update_data(self) -> T | None:
"""Fetch the latest data from the source."""
if self.update_method is None:
raise NotImplementedError("Update method not implemented")