From 4c2b20db6892f6cf755bcdafd72d073b2819a273 Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Fri, 13 Jan 2023 15:12:11 +0100 Subject: [PATCH] Collection of typing improvements in common test helpers (#85509) Co-authored-by: Martin Hjelmare --- tests/common.py | 124 +++++++++++++++++++++++++++--------------------- 1 file changed, 70 insertions(+), 54 deletions(-) diff --git a/tests/common.py b/tests/common.py index eaa31851f0c..3e947b60d3a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from collections import OrderedDict -from collections.abc import Awaitable, Callable, Collection +from collections.abc import Awaitable, Callable, Collection, Mapping, Sequence from contextlib import contextmanager from datetime import datetime, timedelta, timezone import functools as ft @@ -16,13 +16,13 @@ import threading import time from time import monotonic import types -from typing import Any +from typing import Any, NoReturn from unittest.mock import AsyncMock, Mock, patch from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401 import voluptuous as vol -from homeassistant import auth, bootstrap, config_entries, core as ha, loader +from homeassistant import auth, bootstrap, config_entries, loader from homeassistant.auth import ( auth_store, models as auth_models, @@ -42,7 +42,15 @@ from homeassistant.const import ( STATE_OFF, STATE_ON, ) -from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, ServiceCall, State +from homeassistant.core import ( + BLOCK_LOG_TIMEOUT, + CoreState, + Event, + HomeAssistant, + ServiceCall, + State, + callback, +) from homeassistant.helpers import ( area_registry, device_registry, @@ -57,7 +65,7 @@ from homeassistant.helpers import ( ) from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.json import JSONEncoder -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, StateType from homeassistant.setup import setup_component from homeassistant.util.async_ import run_callback_threadsafe import homeassistant.util.dt as date_util @@ -161,7 +169,7 @@ def get_test_home_assistant(): # pylint: disable=protected-access async def async_test_home_assistant(event_loop, load_registries=True): """Return a Home Assistant object pointing at test config dir.""" - hass = ha.HomeAssistant() + hass = HomeAssistant() store = auth_store.AuthStore(hass) hass.auth = auth.AuthManager(hass, store, {}, {}) ensure_auth_manager_loaded(hass.auth) @@ -308,7 +316,7 @@ async def async_test_home_assistant(event_loop, load_registries=True): await hass.async_block_till_done() hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None - hass.state = ha.CoreState.running + hass.state = CoreState.running # Mock async_start orig_start = hass.async_start @@ -321,7 +329,7 @@ async def async_test_home_assistant(event_loop, load_registries=True): hass.async_start = mock_async_start - @ha.callback + @callback def clear_instance(event): """Clear global instance.""" INSTANCES.remove(hass) @@ -337,7 +345,7 @@ def async_mock_service( """Set up a fake service & return a calls log list to this service.""" calls = [] - @ha.callback + @callback def mock_service_log(call): # pylint: disable=unnecessary-lambda """Mock service call.""" calls.append(call) @@ -350,7 +358,7 @@ def async_mock_service( mock_service = threadsafe_callback_factory(async_mock_service) -@ha.callback +@callback def async_mock_intent(hass, intent_typ): """Set up a fake intent handler.""" intents = [] @@ -368,7 +376,7 @@ def async_mock_intent(hass, intent_typ): return intents -@ha.callback +@callback def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False): """Fire the MQTT message.""" # Local import to avoid processing MQTT modules when running a testcase @@ -384,7 +392,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False): fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) -@ha.callback +@callback def async_fire_time_changed_exact( hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False ) -> None: @@ -403,7 +411,7 @@ def async_fire_time_changed_exact( _async_fire_time_changed(hass, utc_datetime, fire_all) -@ha.callback +@callback def async_fire_time_changed( hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False ) -> None: @@ -432,7 +440,7 @@ def async_fire_time_changed( _async_fire_time_changed(hass, utc_datetime, fire_all) -@ha.callback +@callback def _async_fire_time_changed( hass: HomeAssistant, utc_datetime: datetime | None, fire_all: bool ) -> None: @@ -491,7 +499,7 @@ def mock_state_change_event( hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context) -@ha.callback +@callback def mock_component(hass: HomeAssistant, component: str) -> None: """Mock a component is setup.""" if component in hass.config.components: @@ -624,7 +632,7 @@ async def register_auth_provider( return provider -@ha.callback +@callback def ensure_auth_manager_loaded(auth_mgr): """Ensure an auth manager is considered loaded.""" store = auth_mgr._store @@ -995,7 +1003,7 @@ def init_recorder_component(hass, add_config=None, db_url="sqlite://"): ) -def mock_restore_cache(hass, states): +def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None: """Mock the DATA_RESTORE_CACHE.""" key = restore_state.DATA_RESTORE_STATE_TASK data = restore_state.RestoreStateData(hass) @@ -1020,7 +1028,9 @@ def mock_restore_cache(hass, states): hass.data[key] = data -def mock_restore_cache_with_extra_data(hass, states): +def mock_restore_cache_with_extra_data( + hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]] +) -> None: """Mock the DATA_RESTORE_CACHE.""" key = restore_state.DATA_RESTORE_STATE_TASK data = restore_state.RestoreStateData(hass) @@ -1048,7 +1058,7 @@ def mock_restore_cache_with_extra_data(hass, states): class MockEntity(entity.Entity): """Mock Entity class.""" - def __init__(self, **values): + def __init__(self, **values: Any) -> None: """Initialize an entity.""" self._values = values @@ -1056,86 +1066,86 @@ class MockEntity(entity.Entity): self.entity_id = values["entity_id"] @property - def available(self): + def available(self) -> bool: """Return True if entity is available.""" return self._handle("available") @property - def capability_attributes(self): + def capability_attributes(self) -> Mapping[str, Any] | None: """Info about capabilities.""" return self._handle("capability_attributes") @property - def device_class(self): + def device_class(self) -> str | None: """Info how device should be classified.""" return self._handle("device_class") @property - def device_info(self): + def device_info(self) -> entity.DeviceInfo | None: """Info how it links to a device.""" return self._handle("device_info") @property - def entity_category(self): + def entity_category(self) -> entity.EntityCategory | None: """Return the entity category.""" return self._handle("entity_category") @property - def has_entity_name(self): + def has_entity_name(self) -> bool: """Return the has_entity_name name flag.""" return self._handle("has_entity_name") @property - def entity_registry_enabled_default(self): + def entity_registry_enabled_default(self) -> bool: """Return if the entity should be enabled when first added to the entity registry.""" return self._handle("entity_registry_enabled_default") @property - def entity_registry_visible_default(self): + def entity_registry_visible_default(self) -> bool: """Return if the entity should be visible when first added to the entity registry.""" return self._handle("entity_registry_visible_default") @property - def icon(self): + def icon(self) -> str | None: """Return the suggested icon.""" return self._handle("icon") @property - def name(self): + def name(self) -> str | None: """Return the name of the entity.""" return self._handle("name") @property - def should_poll(self): + def should_poll(self) -> bool: """Return the ste of the polling.""" return self._handle("should_poll") @property - def state(self): + def state(self) -> StateType: """Return the state of the entity.""" return self._handle("state") @property - def supported_features(self): + def supported_features(self) -> int | None: """Info about supported features.""" return self._handle("supported_features") @property - def translation_key(self): + def translation_key(self) -> str | None: """Return the translation key.""" return self._handle("translation_key") @property - def unique_id(self): + def unique_id(self) -> str | None: """Return the unique ID of the entity.""" return self._handle("unique_id") @property - def unit_of_measurement(self): + def unit_of_measurement(self) -> str | None: """Info on the units the entity state is in.""" return self._handle("unit_of_measurement") - def _handle(self, attr): + def _handle(self, attr: str) -> Any: """Return attribute value.""" if attr in self._values: return self._values[attr] @@ -1202,7 +1212,7 @@ def mock_storage(data=None): yield data -async def flush_store(store): +async def flush_store(store: storage.Store) -> None: """Make sure all delayed writes of a store are written.""" if store._data is None: return @@ -1212,12 +1222,14 @@ async def flush_store(store): await store._async_handle_write_data() -async def get_system_health_info(hass, domain): +async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str, Any]: """Get system health info.""" return await hass.data["system_health"][domain].info_callback(hass) -def mock_integration(hass, module, built_in=True): +def mock_integration( + hass: HomeAssistant, module: MockModule, built_in: bool = True +) -> loader.Integration: """Mock an integration.""" integration = loader.Integration( hass, @@ -1228,7 +1240,7 @@ def mock_integration(hass, module, built_in=True): module.mock_manifest(), ) - def mock_import_platform(platform_name): + def mock_import_platform(platform_name: str) -> NoReturn: raise ImportError( f"Mocked unable to import platform '{platform_name}'", name=f"{integration.pkg_path}.{platform_name}", @@ -1243,7 +1255,9 @@ def mock_integration(hass, module, built_in=True): return integration -def mock_entity_platform(hass, platform_path, module): +def mock_entity_platform( + hass: HomeAssistant, platform_path: str, module: MockPlatform | None +) -> None: """Mock a entity platform. platform_path is in form light.hue. Will create platform @@ -1253,7 +1267,9 @@ def mock_entity_platform(hass, platform_path, module): mock_platform(hass, f"{platform_name}.{domain}", module) -def mock_platform(hass, platform_path, module=None): +def mock_platform( + hass: HomeAssistant, platform_path: str, module: Mock | MockPlatform | None = None +) -> None: """Mock a platform. platform_path is in form hue.config_flow. @@ -1269,12 +1285,12 @@ def mock_platform(hass, platform_path, module=None): module_cache[platform_path] = module or Mock() -def async_capture_events(hass, event_name): +def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]: """Create a helper that captures events.""" events = [] - @ha.callback - def capture_events(event): + @callback + def capture_events(event: Event) -> None: events.append(event) hass.bus.async_listen(event_name, capture_events) @@ -1282,13 +1298,13 @@ def async_capture_events(hass, event_name): return events -@ha.callback -def async_mock_signal(hass, signal): +@callback +def async_mock_signal(hass: HomeAssistant, signal: str) -> list[tuple[Any]]: """Catch all dispatches to a signal.""" calls = [] - @ha.callback - def mock_signal_handler(*args): + @callback + def mock_signal_handler(*args: Any) -> None: """Mock service call.""" calls.append(args) @@ -1297,7 +1313,7 @@ def async_mock_signal(hass, signal): return calls -def assert_lists_same(a, b): +def assert_lists_same(a: list[Any], b: list[Any]) -> None: """Compare two lists, ignoring order. Check both that all items in a are in b and that all items in b are in a, @@ -1322,17 +1338,17 @@ class _HA_ANY: _other = _SENTINEL - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Test equal.""" self._other = other return True - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Test not equal.""" self._other = other return False - def __repr__(self): + def __repr__(self) -> str: """Return repr() other to not show up in pytest quality diffs.""" if self._other is _SENTINEL: return "" @@ -1342,7 +1358,7 @@ class _HA_ANY: ANY = _HA_ANY() -def raise_contains_mocks(val): +def raise_contains_mocks(val: Any) -> None: """Raise for mocks.""" if isinstance(val, Mock): raise ValueError