mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Collection of typing improvements in common test helpers (#85509)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
6baa905448
commit
4c2b20db68
124
tests/common.py
124
tests/common.py
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Awaitable, Callable, Collection
|
||||
from collections.abc import Awaitable, Callable, Collection, Mapping, Sequence
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import functools as ft
|
||||
@ -16,13 +16,13 @@ import threading
|
||||
import time
|
||||
from time import monotonic
|
||||
import types
|
||||
from typing import Any
|
||||
from typing import Any, NoReturn
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import auth, bootstrap, config_entries, core as ha, loader
|
||||
from homeassistant import auth, bootstrap, config_entries, loader
|
||||
from homeassistant.auth import (
|
||||
auth_store,
|
||||
models as auth_models,
|
||||
@ -42,7 +42,15 @@ from homeassistant.const import (
|
||||
STATE_OFF,
|
||||
STATE_ON,
|
||||
)
|
||||
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, ServiceCall, State
|
||||
from homeassistant.core import (
|
||||
BLOCK_LOG_TIMEOUT,
|
||||
CoreState,
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
State,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.helpers import (
|
||||
area_registry,
|
||||
device_registry,
|
||||
@ -57,7 +65,7 @@ from homeassistant.helpers import (
|
||||
)
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.typing import ConfigType, StateType
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
import homeassistant.util.dt as date_util
|
||||
@ -161,7 +169,7 @@ def get_test_home_assistant():
|
||||
# pylint: disable=protected-access
|
||||
async def async_test_home_assistant(event_loop, load_registries=True):
|
||||
"""Return a Home Assistant object pointing at test config dir."""
|
||||
hass = ha.HomeAssistant()
|
||||
hass = HomeAssistant()
|
||||
store = auth_store.AuthStore(hass)
|
||||
hass.auth = auth.AuthManager(hass, store, {}, {})
|
||||
ensure_auth_manager_loaded(hass.auth)
|
||||
@ -308,7 +316,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
|
||||
await hass.async_block_till_done()
|
||||
hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None
|
||||
|
||||
hass.state = ha.CoreState.running
|
||||
hass.state = CoreState.running
|
||||
|
||||
# Mock async_start
|
||||
orig_start = hass.async_start
|
||||
@ -321,7 +329,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
|
||||
|
||||
hass.async_start = mock_async_start
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def clear_instance(event):
|
||||
"""Clear global instance."""
|
||||
INSTANCES.remove(hass)
|
||||
@ -337,7 +345,7 @@ def async_mock_service(
|
||||
"""Set up a fake service & return a calls log list to this service."""
|
||||
calls = []
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def mock_service_log(call): # pylint: disable=unnecessary-lambda
|
||||
"""Mock service call."""
|
||||
calls.append(call)
|
||||
@ -350,7 +358,7 @@ def async_mock_service(
|
||||
mock_service = threadsafe_callback_factory(async_mock_service)
|
||||
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def async_mock_intent(hass, intent_typ):
|
||||
"""Set up a fake intent handler."""
|
||||
intents = []
|
||||
@ -368,7 +376,7 @@ def async_mock_intent(hass, intent_typ):
|
||||
return intents
|
||||
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
|
||||
"""Fire the MQTT message."""
|
||||
# Local import to avoid processing MQTT modules when running a testcase
|
||||
@ -384,7 +392,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
|
||||
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
|
||||
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def async_fire_time_changed_exact(
|
||||
hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
|
||||
) -> None:
|
||||
@ -403,7 +411,7 @@ def async_fire_time_changed_exact(
|
||||
_async_fire_time_changed(hass, utc_datetime, fire_all)
|
||||
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def async_fire_time_changed(
|
||||
hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
|
||||
) -> None:
|
||||
@ -432,7 +440,7 @@ def async_fire_time_changed(
|
||||
_async_fire_time_changed(hass, utc_datetime, fire_all)
|
||||
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def _async_fire_time_changed(
|
||||
hass: HomeAssistant, utc_datetime: datetime | None, fire_all: bool
|
||||
) -> None:
|
||||
@ -491,7 +499,7 @@ def mock_state_change_event(
|
||||
hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)
|
||||
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def mock_component(hass: HomeAssistant, component: str) -> None:
|
||||
"""Mock a component is setup."""
|
||||
if component in hass.config.components:
|
||||
@ -624,7 +632,7 @@ async def register_auth_provider(
|
||||
return provider
|
||||
|
||||
|
||||
@ha.callback
|
||||
@callback
|
||||
def ensure_auth_manager_loaded(auth_mgr):
|
||||
"""Ensure an auth manager is considered loaded."""
|
||||
store = auth_mgr._store
|
||||
@ -995,7 +1003,7 @@ def init_recorder_component(hass, add_config=None, db_url="sqlite://"):
|
||||
)
|
||||
|
||||
|
||||
def mock_restore_cache(hass, states):
|
||||
def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None:
|
||||
"""Mock the DATA_RESTORE_CACHE."""
|
||||
key = restore_state.DATA_RESTORE_STATE_TASK
|
||||
data = restore_state.RestoreStateData(hass)
|
||||
@ -1020,7 +1028,9 @@ def mock_restore_cache(hass, states):
|
||||
hass.data[key] = data
|
||||
|
||||
|
||||
def mock_restore_cache_with_extra_data(hass, states):
|
||||
def mock_restore_cache_with_extra_data(
|
||||
hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]]
|
||||
) -> None:
|
||||
"""Mock the DATA_RESTORE_CACHE."""
|
||||
key = restore_state.DATA_RESTORE_STATE_TASK
|
||||
data = restore_state.RestoreStateData(hass)
|
||||
@ -1048,7 +1058,7 @@ def mock_restore_cache_with_extra_data(hass, states):
|
||||
class MockEntity(entity.Entity):
|
||||
"""Mock Entity class."""
|
||||
|
||||
def __init__(self, **values):
|
||||
def __init__(self, **values: Any) -> None:
|
||||
"""Initialize an entity."""
|
||||
self._values = values
|
||||
|
||||
@ -1056,86 +1066,86 @@ class MockEntity(entity.Entity):
|
||||
self.entity_id = values["entity_id"]
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
def available(self) -> bool:
|
||||
"""Return True if entity is available."""
|
||||
return self._handle("available")
|
||||
|
||||
@property
|
||||
def capability_attributes(self):
|
||||
def capability_attributes(self) -> Mapping[str, Any] | None:
|
||||
"""Info about capabilities."""
|
||||
return self._handle("capability_attributes")
|
||||
|
||||
@property
|
||||
def device_class(self):
|
||||
def device_class(self) -> str | None:
|
||||
"""Info how device should be classified."""
|
||||
return self._handle("device_class")
|
||||
|
||||
@property
|
||||
def device_info(self):
|
||||
def device_info(self) -> entity.DeviceInfo | None:
|
||||
"""Info how it links to a device."""
|
||||
return self._handle("device_info")
|
||||
|
||||
@property
|
||||
def entity_category(self):
|
||||
def entity_category(self) -> entity.EntityCategory | None:
|
||||
"""Return the entity category."""
|
||||
return self._handle("entity_category")
|
||||
|
||||
@property
|
||||
def has_entity_name(self):
|
||||
def has_entity_name(self) -> bool:
|
||||
"""Return the has_entity_name name flag."""
|
||||
return self._handle("has_entity_name")
|
||||
|
||||
@property
|
||||
def entity_registry_enabled_default(self):
|
||||
def entity_registry_enabled_default(self) -> bool:
|
||||
"""Return if the entity should be enabled when first added to the entity registry."""
|
||||
return self._handle("entity_registry_enabled_default")
|
||||
|
||||
@property
|
||||
def entity_registry_visible_default(self):
|
||||
def entity_registry_visible_default(self) -> bool:
|
||||
"""Return if the entity should be visible when first added to the entity registry."""
|
||||
return self._handle("entity_registry_visible_default")
|
||||
|
||||
@property
|
||||
def icon(self):
|
||||
def icon(self) -> str | None:
|
||||
"""Return the suggested icon."""
|
||||
return self._handle("icon")
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str | None:
|
||||
"""Return the name of the entity."""
|
||||
return self._handle("name")
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
def should_poll(self) -> bool:
|
||||
"""Return the ste of the polling."""
|
||||
return self._handle("should_poll")
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
def state(self) -> StateType:
|
||||
"""Return the state of the entity."""
|
||||
return self._handle("state")
|
||||
|
||||
@property
|
||||
def supported_features(self):
|
||||
def supported_features(self) -> int | None:
|
||||
"""Info about supported features."""
|
||||
return self._handle("supported_features")
|
||||
|
||||
@property
|
||||
def translation_key(self):
|
||||
def translation_key(self) -> str | None:
|
||||
"""Return the translation key."""
|
||||
return self._handle("translation_key")
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
def unique_id(self) -> str | None:
|
||||
"""Return the unique ID of the entity."""
|
||||
return self._handle("unique_id")
|
||||
|
||||
@property
|
||||
def unit_of_measurement(self):
|
||||
def unit_of_measurement(self) -> str | None:
|
||||
"""Info on the units the entity state is in."""
|
||||
return self._handle("unit_of_measurement")
|
||||
|
||||
def _handle(self, attr):
|
||||
def _handle(self, attr: str) -> Any:
|
||||
"""Return attribute value."""
|
||||
if attr in self._values:
|
||||
return self._values[attr]
|
||||
@ -1202,7 +1212,7 @@ def mock_storage(data=None):
|
||||
yield data
|
||||
|
||||
|
||||
async def flush_store(store):
|
||||
async def flush_store(store: storage.Store) -> None:
|
||||
"""Make sure all delayed writes of a store are written."""
|
||||
if store._data is None:
|
||||
return
|
||||
@ -1212,12 +1222,14 @@ async def flush_store(store):
|
||||
await store._async_handle_write_data()
|
||||
|
||||
|
||||
async def get_system_health_info(hass, domain):
|
||||
async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str, Any]:
|
||||
"""Get system health info."""
|
||||
return await hass.data["system_health"][domain].info_callback(hass)
|
||||
|
||||
|
||||
def mock_integration(hass, module, built_in=True):
|
||||
def mock_integration(
|
||||
hass: HomeAssistant, module: MockModule, built_in: bool = True
|
||||
) -> loader.Integration:
|
||||
"""Mock an integration."""
|
||||
integration = loader.Integration(
|
||||
hass,
|
||||
@ -1228,7 +1240,7 @@ def mock_integration(hass, module, built_in=True):
|
||||
module.mock_manifest(),
|
||||
)
|
||||
|
||||
def mock_import_platform(platform_name):
|
||||
def mock_import_platform(platform_name: str) -> NoReturn:
|
||||
raise ImportError(
|
||||
f"Mocked unable to import platform '{platform_name}'",
|
||||
name=f"{integration.pkg_path}.{platform_name}",
|
||||
@ -1243,7 +1255,9 @@ def mock_integration(hass, module, built_in=True):
|
||||
return integration
|
||||
|
||||
|
||||
def mock_entity_platform(hass, platform_path, module):
|
||||
def mock_entity_platform(
|
||||
hass: HomeAssistant, platform_path: str, module: MockPlatform | None
|
||||
) -> None:
|
||||
"""Mock a entity platform.
|
||||
|
||||
platform_path is in form light.hue. Will create platform
|
||||
@ -1253,7 +1267,9 @@ def mock_entity_platform(hass, platform_path, module):
|
||||
mock_platform(hass, f"{platform_name}.{domain}", module)
|
||||
|
||||
|
||||
def mock_platform(hass, platform_path, module=None):
|
||||
def mock_platform(
|
||||
hass: HomeAssistant, platform_path: str, module: Mock | MockPlatform | None = None
|
||||
) -> None:
|
||||
"""Mock a platform.
|
||||
|
||||
platform_path is in form hue.config_flow.
|
||||
@ -1269,12 +1285,12 @@ def mock_platform(hass, platform_path, module=None):
|
||||
module_cache[platform_path] = module or Mock()
|
||||
|
||||
|
||||
def async_capture_events(hass, event_name):
|
||||
def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
|
||||
"""Create a helper that captures events."""
|
||||
events = []
|
||||
|
||||
@ha.callback
|
||||
def capture_events(event):
|
||||
@callback
|
||||
def capture_events(event: Event) -> None:
|
||||
events.append(event)
|
||||
|
||||
hass.bus.async_listen(event_name, capture_events)
|
||||
@ -1282,13 +1298,13 @@ def async_capture_events(hass, event_name):
|
||||
return events
|
||||
|
||||
|
||||
@ha.callback
|
||||
def async_mock_signal(hass, signal):
|
||||
@callback
|
||||
def async_mock_signal(hass: HomeAssistant, signal: str) -> list[tuple[Any]]:
|
||||
"""Catch all dispatches to a signal."""
|
||||
calls = []
|
||||
|
||||
@ha.callback
|
||||
def mock_signal_handler(*args):
|
||||
@callback
|
||||
def mock_signal_handler(*args: Any) -> None:
|
||||
"""Mock service call."""
|
||||
calls.append(args)
|
||||
|
||||
@ -1297,7 +1313,7 @@ def async_mock_signal(hass, signal):
|
||||
return calls
|
||||
|
||||
|
||||
def assert_lists_same(a, b):
|
||||
def assert_lists_same(a: list[Any], b: list[Any]) -> None:
|
||||
"""Compare two lists, ignoring order.
|
||||
|
||||
Check both that all items in a are in b and that all items in b are in a,
|
||||
@ -1322,17 +1338,17 @@ class _HA_ANY:
|
||||
|
||||
_other = _SENTINEL
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Test equal."""
|
||||
self._other = other
|
||||
return True
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
"""Test not equal."""
|
||||
self._other = other
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Return repr() other to not show up in pytest quality diffs."""
|
||||
if self._other is _SENTINEL:
|
||||
return "<ANY>"
|
||||
@ -1342,7 +1358,7 @@ class _HA_ANY:
|
||||
ANY = _HA_ANY()
|
||||
|
||||
|
||||
def raise_contains_mocks(val):
|
||||
def raise_contains_mocks(val: Any) -> None:
|
||||
"""Raise for mocks."""
|
||||
if isinstance(val, Mock):
|
||||
raise ValueError
|
||||
|
Loading…
x
Reference in New Issue
Block a user