diff --git a/.strict-typing b/.strict-typing index 74a255a7f96..621c6c315fc 100644 --- a/.strict-typing +++ b/.strict-typing @@ -22,6 +22,7 @@ homeassistant.helpers.script_variables homeassistant.helpers.translation homeassistant.util.async_ homeassistant.util.color +homeassistant.util.decorator homeassistant.util.process homeassistant.util.unit_system diff --git a/homeassistant/auth/mfa_modules/__init__.py b/homeassistant/auth/mfa_modules/__init__.py index bb81d1fb04f..61c36da6e90 100644 --- a/homeassistant/auth/mfa_modules/__init__.py +++ b/homeassistant/auth/mfa_modules/__init__.py @@ -16,7 +16,7 @@ from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError from homeassistant.util.decorator import Registry -MULTI_FACTOR_AUTH_MODULES = Registry() +MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry() MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema( { @@ -129,7 +129,7 @@ async def auth_mfa_module_from_config( hass: HomeAssistant, config: dict[str, Any] ) -> MultiFactorAuthModule: """Initialize an auth module from a config.""" - module_name = config[CONF_TYPE] + module_name: str = config[CONF_TYPE] module = await _load_mfa_module(hass, module_name) try: @@ -142,7 +142,7 @@ async def auth_mfa_module_from_config( ) raise - return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config) # type: ignore[no-any-return] + return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config) async def _load_mfa_module(hass: HomeAssistant, module_name: str) -> types.ModuleType: diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index d80d7a5273b..63389059051 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -25,7 +25,7 @@ from ..models import Credentials, RefreshToken, User, UserMeta _LOGGER = logging.getLogger(__name__) DATA_REQS = "auth_prov_reqs_processed" -AUTH_PROVIDERS = Registry() +AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry() AUTH_PROVIDER_SCHEMA = vol.Schema( { @@ -136,7 +136,7 @@ async def auth_provider_from_config( hass: HomeAssistant, store: AuthStore, config: dict[str, Any] ) -> AuthProvider: """Initialize an auth provider from a config.""" - provider_name = config[CONF_TYPE] + provider_name: str = config[CONF_TYPE] module = await load_auth_provider_module(hass, provider_name) try: @@ -149,7 +149,7 @@ async def auth_provider_from_config( ) raise - return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore[no-any-return] + return AUTH_PROVIDERS[provider_name](hass, store, config) async def load_auth_provider_module( diff --git a/homeassistant/components/alexa/entities.py b/homeassistant/components/alexa/entities.py index 1ab24927bcb..5ecd326afb6 100644 --- a/homeassistant/components/alexa/entities.py +++ b/homeassistant/components/alexa/entities.py @@ -83,7 +83,7 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) -ENTITY_ADAPTERS = Registry() +ENTITY_ADAPTERS: Registry[str, type[AlexaEntity]] = Registry() TRANSLATION_TABLE = dict.fromkeys(map(ord, r"}{\/|\"()[]+~!><*%"), None) diff --git a/homeassistant/components/alexa/handlers.py b/homeassistant/components/alexa/handlers.py index f3f669de3b3..a27bc432b4f 100644 --- a/homeassistant/components/alexa/handlers.py +++ b/homeassistant/components/alexa/handlers.py @@ -73,7 +73,7 @@ from .errors import ( from .state_report import async_enable_proactive_mode _LOGGER = logging.getLogger(__name__) -HANDLERS = Registry() +HANDLERS = Registry() # type: ignore[var-annotated] @HANDLERS.register(("Alexa.Discovery", "Discover")) diff --git a/homeassistant/components/alexa/intent.py b/homeassistant/components/alexa/intent.py index 0b8bf55fcda..7352bbd995a 100644 --- a/homeassistant/components/alexa/intent.py +++ b/homeassistant/components/alexa/intent.py @@ -12,7 +12,7 @@ from .const import DOMAIN, SYN_RESOLUTION_MATCH _LOGGER = logging.getLogger(__name__) -HANDLERS = Registry() +HANDLERS = Registry() # type: ignore[var-annotated] INTENTS_API_ENDPOINT = "/api/alexa" diff --git a/homeassistant/components/filter/sensor.py b/homeassistant/components/filter/sensor.py index d2ad3ec313c..a5b54a621a7 100644 --- a/homeassistant/components/filter/sensor.py +++ b/homeassistant/components/filter/sensor.py @@ -52,7 +52,7 @@ FILTER_NAME_OUTLIER = "outlier" FILTER_NAME_THROTTLE = "throttle" FILTER_NAME_TIME_THROTTLE = "time_throttle" FILTER_NAME_TIME_SMA = "time_simple_moving_average" -FILTERS = Registry() +FILTERS: Registry[str, type[Filter]] = Registry() CONF_FILTERS = "filters" CONF_FILTER_NAME = "filter" diff --git a/homeassistant/components/google_assistant/smart_home.py b/homeassistant/components/google_assistant/smart_home.py index 80bc61cc61d..5f38194e3e3 100644 --- a/homeassistant/components/google_assistant/smart_home.py +++ b/homeassistant/components/google_assistant/smart_home.py @@ -19,7 +19,7 @@ from .helpers import GoogleEntity, RequestData, async_get_entities EXECUTE_LIMIT = 2 # Wait 2 seconds for execute to finish -HANDLERS = Registry() +HANDLERS = Registry() # type: ignore[var-annotated] _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/homekit/accessories.py b/homeassistant/components/homekit/accessories.py index 922c4c52568..d348b4c1f42 100644 --- a/homeassistant/components/homekit/accessories.py +++ b/homeassistant/components/homekit/accessories.py @@ -1,4 +1,6 @@ """Extend the basic Accessory and Bridge functions.""" +from __future__ import annotations + import logging from pyhap.accessory import Accessory, Bridge @@ -90,7 +92,7 @@ SWITCH_TYPES = { TYPE_SWITCH: "Switch", TYPE_VALVE: "Valve", } -TYPES = Registry() +TYPES: Registry[str, type[HomeAccessory]] = Registry() def get_accessory(hass, driver, state, aid, config): # noqa: C901 diff --git a/homeassistant/components/konnected/handlers.py b/homeassistant/components/konnected/handlers.py index ef878fc6f2b..af784750627 100644 --- a/homeassistant/components/konnected/handlers.py +++ b/homeassistant/components/konnected/handlers.py @@ -9,7 +9,7 @@ from homeassistant.util import decorator from .const import CONF_INVERSE, SIGNAL_DS18B20_NEW _LOGGER = logging.getLogger(__name__) -HANDLERS = decorator.Registry() +HANDLERS = decorator.Registry() # type: ignore[var-annotated] @HANDLERS.register("state") diff --git a/homeassistant/components/mobile_app/webhook.py b/homeassistant/components/mobile_app/webhook.py index d659d7625c1..221c4eef733 100644 --- a/homeassistant/components/mobile_app/webhook.py +++ b/homeassistant/components/mobile_app/webhook.py @@ -109,7 +109,7 @@ _LOGGER = logging.getLogger(__name__) DELAY_SAVE = 10 -WEBHOOK_COMMANDS = Registry() +WEBHOOK_COMMANDS = Registry() # type: ignore[var-annotated] COMBINED_CLASSES = set(BINARY_SENSOR_CLASSES + SENSOR_CLASSES) SENSOR_TYPES = [ATTR_SENSOR_TYPE_BINARY_SENSOR, ATTR_SENSOR_TYPE_SENSOR] diff --git a/homeassistant/components/mysensors/gateway.py b/homeassistant/components/mysensors/gateway.py index b167c8c58de..be0381ab74e 100644 --- a/homeassistant/components/mysensors/gateway.py +++ b/homeassistant/components/mysensors/gateway.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from collections import defaultdict -from collections.abc import Callable, Coroutine +from collections.abc import Callable import logging import socket import sys @@ -337,9 +337,7 @@ def _gw_callback_factory( _LOGGER.debug("Node update: node %s child %s", msg.node_id, msg.child_id) msg_type = msg.gateway.const.MessageType(msg.type) - msg_handler: Callable[ - [HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None] - ] | None = HANDLERS.get(msg_type.name) + msg_handler = HANDLERS.get(msg_type.name) if msg_handler is None: return diff --git a/homeassistant/components/mysensors/handler.py b/homeassistant/components/mysensors/handler.py index 4d61a2812ae..57ff12fc6f0 100644 --- a/homeassistant/components/mysensors/handler.py +++ b/homeassistant/components/mysensors/handler.py @@ -1,6 +1,9 @@ """Handle MySensors messages.""" from __future__ import annotations +from collections.abc import Callable, Coroutine +from typing import Any + from mysensors import Message from homeassistant.const import Platform @@ -12,7 +15,9 @@ from .const import CHILD_CALLBACK, NODE_CALLBACK, DevId, GatewayId from .device import get_mysensors_devices from .helpers import discover_mysensors_platform, validate_set_msg -HANDLERS = decorator.Registry() +HANDLERS: decorator.Registry[ + str, Callable[[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]] +] = decorator.Registry() @HANDLERS.register("set") diff --git a/homeassistant/components/mysensors/helpers.py b/homeassistant/components/mysensors/helpers.py index 5b6682393b1..a5f67111738 100644 --- a/homeassistant/components/mysensors/helpers.py +++ b/homeassistant/components/mysensors/helpers.py @@ -31,7 +31,9 @@ from .const import ( ) _LOGGER = logging.getLogger(__name__) -SCHEMAS = Registry() +SCHEMAS: Registry[ + tuple[str, str], Callable[[BaseAsyncGateway, ChildSensor, ValueType], vol.Schema] +] = Registry() @callback diff --git a/homeassistant/components/onvif/parsers.py b/homeassistant/components/onvif/parsers.py index 9574d44edea..b518dbbb451 100644 --- a/homeassistant/components/onvif/parsers.py +++ b/homeassistant/components/onvif/parsers.py @@ -1,10 +1,13 @@ """ONVIF event parsers.""" +from collections.abc import Callable, Coroutine +from typing import Any + from homeassistant.util import dt as dt_util from homeassistant.util.decorator import Registry from .models import Event -PARSERS = Registry() +PARSERS: Registry[str, Callable[[str, Any], Coroutine[Any, Any, Event]]] = Registry() @PARSERS.register("tns1:VideoSource/MotionAlarm") diff --git a/homeassistant/components/overkiz/coordinator.py b/homeassistant/components/overkiz/coordinator.py index 98031926cfd..d90a52ae409 100644 --- a/homeassistant/components/overkiz/coordinator.py +++ b/homeassistant/components/overkiz/coordinator.py @@ -1,8 +1,10 @@ """Helpers to help coordinate updates.""" from __future__ import annotations +from collections.abc import Callable, Coroutine from datetime import timedelta import logging +from typing import Any from aiohttp import ServerDisconnectedError from pyoverkiz.client import OverkizClient @@ -25,7 +27,9 @@ from homeassistant.util.decorator import Registry from .const import DOMAIN, LOGGER, UPDATE_INTERVAL -EVENT_HANDLERS = Registry() +EVENT_HANDLERS: Registry[ + str, Callable[[OverkizDataUpdateCoordinator, Event], Coroutine[Any, Any, None]] +] = Registry() class OverkizDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Device]]): diff --git a/homeassistant/components/owntracks/messages.py b/homeassistant/components/owntracks/messages.py index bd01284329b..b85a37dadf9 100644 --- a/homeassistant/components/owntracks/messages.py +++ b/homeassistant/components/owntracks/messages.py @@ -17,7 +17,7 @@ from .helper import supports_encryption _LOGGER = logging.getLogger(__name__) -HANDLERS = decorator.Registry() +HANDLERS = decorator.Registry() # type: ignore[var-annotated] def get_cipher(): diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index e22e06df7e2..157f20b5b37 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -245,7 +245,7 @@ class Stream: self, fmt: str, timeout: int = OUTPUT_IDLE_TIMEOUT ) -> StreamOutput: """Add provider output stream.""" - if not self._outputs.get(fmt): + if not (provider := self._outputs.get(fmt)): @callback def idle_callback() -> None: @@ -259,7 +259,7 @@ class Stream: self.hass, IdleTimer(self.hass, timeout, idle_callback) ) self._outputs[fmt] = provider - return self._outputs[fmt] + return provider def remove_provider(self, provider: StreamOutput) -> None: """Remove provider output stream.""" diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index 91414dd96d9..8db6a239818 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from . import Stream -PROVIDERS = Registry() +PROVIDERS: Registry[str, type[StreamOutput]] = Registry() @attr.s(slots=True) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index af04ee032dc..7bc37bcd305 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -62,7 +62,7 @@ SOURCE_UNIGNORE = "unignore" # This is used to signal that re-authentication is required by the user. SOURCE_REAUTH = "reauth" -HANDLERS = Registry() +HANDLERS: Registry[str, type[ConfigFlow]] = Registry() STORAGE_KEY = "core.config_entries" STORAGE_VERSION = 1 @@ -530,8 +530,10 @@ class ConfigEntry: ) return False # Handler may be a partial + # Keep for backwards compatibility + # https://github.com/home-assistant/core/pull/67087#discussion_r812559950 while isinstance(handler, functools.partial): - handler = handler.func + handler = handler.func # type: ignore[unreachable] if self.version == handler.VERSION: return True @@ -753,7 +755,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): if not context or "source" not in context: raise KeyError("Context not set or doesn't have a source set") - flow = cast(ConfigFlow, handler()) + flow = handler() flow.init_step = context["source"] return flow @@ -1496,7 +1498,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager): if entry.domain not in HANDLERS: raise data_entry_flow.UnknownHandler - return cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry)) + return HANDLERS[entry.domain].async_get_options_flow(entry) async def async_finish_flow( self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 38fe621f96c..f280feb83b2 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -13,7 +13,7 @@ from homeassistant.util import decorator from . import config_validation as cv -SELECTORS = decorator.Registry() +SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry() def _get_selector_class(config: Any) -> type[Selector]: @@ -24,12 +24,12 @@ def _get_selector_class(config: Any) -> type[Selector]: if len(config) != 1: raise vol.Invalid(f"Only one type can be specified. Found {', '.join(config)}") - selector_type = list(config)[0] + selector_type: str = list(config)[0] if (selector_class := SELECTORS.get(selector_type)) is None: raise vol.Invalid(f"Unknown selector type {selector_type} found") - return cast(type[Selector], selector_class) + return selector_class def selector(config: Any) -> Selector: diff --git a/homeassistant/util/decorator.py b/homeassistant/util/decorator.py index 602cdba5598..c648f6f1cab 100644 --- a/homeassistant/util/decorator.py +++ b/homeassistant/util/decorator.py @@ -2,18 +2,19 @@ from __future__ import annotations from collections.abc import Callable, Hashable -from typing import TypeVar +from typing import Any, TypeVar -CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name +_KT = TypeVar("_KT", bound=Hashable) +_VT = TypeVar("_VT", bound=Callable[..., Any]) -class Registry(dict): +class Registry(dict[_KT, _VT]): """Registry of items.""" - def register(self, name: Hashable) -> Callable[[CALLABLE_T], CALLABLE_T]: + def register(self, name: _KT) -> Callable[[_VT], _VT]: """Return decorator to register item with a specific name.""" - def decorator(func: CALLABLE_T) -> CALLABLE_T: + def decorator(func: _VT) -> _VT: """Register decorated function.""" self[name] = func return func diff --git a/mypy.ini b/mypy.ini index 3f8386a2a27..781bc5c199e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -76,6 +76,9 @@ disallow_any_generics = true [mypy-homeassistant.util.color] disallow_any_generics = true +[mypy-homeassistant.util.decorator] +disallow_any_generics = true + [mypy-homeassistant.util.process] disallow_any_generics = true