diff --git a/tests/common.py b/tests/common.py index 40745a1df9e..55f7cadfd4b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -8,6 +8,8 @@ from collections.abc import ( Callable, Coroutine, Generator, + Iterable, + Iterator, Mapping, Sequence, ) @@ -30,6 +32,7 @@ from unittest.mock import AsyncMock, Mock, patch from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401 import pytest from syrupy import SnapshotAssertion +from typing_extensions import TypeVar import voluptuous as vol from homeassistant import auth, bootstrap, config_entries, loader @@ -90,6 +93,7 @@ from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder, jso from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.util.async_ import run_callback_threadsafe import homeassistant.util.dt as dt_util +from homeassistant.util.event_type import EventType from homeassistant.util.json import ( JsonArrayType, JsonObjectType, @@ -107,6 +111,8 @@ from .testing_config.custom_components.test_constant_deprecation import ( import_deprecated_constant, ) +_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=dict[str, Any]) + _LOGGER = logging.getLogger(__name__) INSTANCES = [] CLIENT_ID = "https://example.com/app" @@ -1434,7 +1440,7 @@ async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str, @contextmanager -def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> None: +def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> Iterator[None]: """Mock a config flow handler.""" original_handler = config_entries.HANDLERS.get(domain) config_entries.HANDLERS[domain] = config_flow @@ -1502,12 +1508,14 @@ def mock_platform( module_cache[platform_path] = module or Mock() -def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]: +def async_capture_events( + hass: HomeAssistant, event_name: EventType[_DataT] | str +) -> list[Event[_DataT]]: """Create a helper that captures events.""" - events = [] + events: list[Event[_DataT]] = [] @callback - def capture_events(event: Event) -> None: + def capture_events(event: Event[_DataT]) -> None: events.append(event) hass.bus.async_listen(event_name, capture_events) @@ -1516,14 +1524,14 @@ def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]: @callback -def async_mock_signal( - hass: HomeAssistant, signal: SignalType[Any] | str -) -> list[tuple[Any]]: +def async_mock_signal[*_Ts]( + hass: HomeAssistant, signal: SignalType[*_Ts] | str +) -> list[tuple[*_Ts]]: """Catch all dispatches to a signal.""" - calls = [] + calls: list[tuple[*_Ts]] = [] @callback - def mock_signal_handler(*args: Any) -> None: + def mock_signal_handler(*args: *_Ts) -> None: """Mock service call.""" calls.append(args) @@ -1723,7 +1731,7 @@ def extract_stack_to_frame(extract_stack: list[Mock]) -> FrameType: def setup_test_component_platform( hass: HomeAssistant, domain: str, - entities: Sequence[Entity], + entities: Iterable[Entity], from_config_entry: bool = False, built_in: bool = True, ) -> MockPlatform: