diff --git a/tests/common.py b/tests/common.py index cc2bc454810..4b357fe7033 100644 --- a/tests/common.py +++ b/tests/common.py @@ -20,6 +20,7 @@ from typing import Any from unittest.mock import AsyncMock, Mock, patch from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401 +import voluptuous as vol from homeassistant import auth, config_entries, core as ha, loader from homeassistant.auth import ( @@ -42,7 +43,7 @@ from homeassistant.const import ( STATE_OFF, STATE_ON, ) -from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant +from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, ServiceCall, State from homeassistant.helpers import ( area_registry, device_registry, @@ -57,6 +58,7 @@ from homeassistant.helpers import ( ) from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.typing import ConfigType from homeassistant.setup import setup_component from homeassistant.util.async_ import run_callback_threadsafe import homeassistant.util.dt as date_util @@ -328,7 +330,9 @@ async def async_test_home_assistant(loop, load_registries=True): return hass -def async_mock_service(hass, domain, service, schema=None): +def async_mock_service( + hass: HomeAssistant, domain: str, service: str, schema: vol.Schema | None = None +) -> list[ServiceCall]: """Set up a fake service & return a calls log list to this service.""" calls = [] @@ -417,18 +421,20 @@ def get_fixture_path(filename: str, integration: str | None = None) -> pathlib.P if integration is None: return pathlib.Path(__file__).parent.joinpath("fixtures", filename) - else: - return pathlib.Path(__file__).parent.joinpath( - "components", integration, "fixtures", filename - ) + + return pathlib.Path(__file__).parent.joinpath( + "components", integration, "fixtures", filename + ) -def load_fixture(filename, integration=None): +def load_fixture(filename: str, integration: str | None = None) -> str: """Load a fixture.""" return get_fixture_path(filename, integration).read_text() -def mock_state_change_event(hass, new_state, old_state=None): +def mock_state_change_event( + hass: HomeAssistant, new_state: State, old_state: State | None = None +) -> None: """Mock state change envent.""" event_data = {"entity_id": new_state.entity_id, "new_state": new_state} @@ -439,7 +445,7 @@ def mock_state_change_event(hass, new_state, old_state=None): @ha.callback -def mock_component(hass, component): +def mock_component(hass: HomeAssistant, component: str) -> None: """Mock a component is setup.""" if component in hass.config.components: AssertionError(f"Integration {component} is already setup") @@ -447,7 +453,10 @@ def mock_component(hass, component): hass.config.components.add(component) -def mock_registry(hass, mock_entries=None): +def mock_registry( + hass: HomeAssistant, + mock_entries: dict[str, entity_registry.RegistryEntry] | None = None, +) -> entity_registry.EntityRegistry: """Mock the Entity Registry.""" registry = entity_registry.EntityRegistry(hass) if mock_entries is None: @@ -460,7 +469,9 @@ def mock_registry(hass, mock_entries=None): return registry -def mock_area_registry(hass, mock_entries=None): +def mock_area_registry( + hass: HomeAssistant, mock_entries: dict[str, area_registry.AreaEntry] | None = None +) -> area_registry.AreaRegistry: """Mock the Area Registry.""" registry = area_registry.AreaRegistry(hass) registry.areas = mock_entries or OrderedDict() @@ -469,7 +480,10 @@ def mock_area_registry(hass, mock_entries=None): return registry -def mock_device_registry(hass, mock_entries=None): +def mock_device_registry( + hass: HomeAssistant, + mock_entries: dict[str, device_registry.DeviceEntry] | None = None, +) -> device_registry.DeviceRegistry: """Mock the Device Registry.""" registry = device_registry.DeviceRegistry(hass) registry.devices = device_registry.DeviceRegistryItems() @@ -545,7 +559,9 @@ class MockUser(auth_models.User): self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup) -async def register_auth_provider(hass, config): +async def register_auth_provider( + hass: HomeAssistant, config: ConfigType +) -> auth_providers.AuthProvider: """Register an auth provider.""" provider = await auth_providers.auth_provider_from_config( hass, hass.auth._store, config