Improve recorder type hints in tests (#87826)

* Improve recorder type hints in tests

* Add comment

* Adjust comment
This commit is contained in:
epenet 2023-02-10 11:11:39 +01:00 committed by GitHub
parent b5dfd83c46
commit fac746c974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 30 deletions

View File

@ -94,17 +94,24 @@ _TEST_FIXTURES: dict[str, list[str] | str] = {
"aioclient_mock": "AiohttpClientMocker", "aioclient_mock": "AiohttpClientMocker",
"aiohttp_client": "ClientSessionGenerator", "aiohttp_client": "ClientSessionGenerator",
"area_registry": "AreaRegistry", "area_registry": "AreaRegistry",
"async_setup_recorder_instance": "RecorderInstanceGenerator",
"caplog": "pytest.LogCaptureFixture", "caplog": "pytest.LogCaptureFixture",
"device_registry": "DeviceRegistry", "device_registry": "DeviceRegistry",
"enable_nightly_purge": "bool",
"enable_statistics": "bool",
"enable_statistics_table_validation": "bool",
"entity_registry": "EntityRegistry", "entity_registry": "EntityRegistry",
"hass_client": "ClientSessionGenerator", "hass_client": "ClientSessionGenerator",
"hass_client_no_auth": "ClientSessionGenerator", "hass_client_no_auth": "ClientSessionGenerator",
"hass_recorder": "Callable[..., HomeAssistant]",
"hass_ws_client": "WebSocketGenerator", "hass_ws_client": "WebSocketGenerator",
"issue_registry": "IssueRegistry", "issue_registry": "IssueRegistry",
"mqtt_client_mock": "MqttMockPahoClient", "mqtt_client_mock": "MqttMockPahoClient",
"mqtt_mock": "MqttMockHAClient", "mqtt_mock": "MqttMockHAClient",
"mqtt_mock_entry_no_yaml_config": "MqttMockHAClientGenerator", "mqtt_mock_entry_no_yaml_config": "MqttMockHAClientGenerator",
"mqtt_mock_entry_with_yaml_config": "MqttMockHAClientGenerator", "mqtt_mock_entry_with_yaml_config": "MqttMockHAClientGenerator",
"recorder_db_url": "str",
"recorder_mock": "Recorder",
} }
_TEST_FUNCTION_MATCH = TypeHintMatch( _TEST_FUNCTION_MATCH = TypeHintMatch(
function_name="test_*", function_name="test_*",

View File

@ -13,7 +13,7 @@ import logging
import sqlite3 import sqlite3
import ssl import ssl
import threading import threading
from typing import Any from typing import TYPE_CHECKING, Any, cast
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
from aiohttp import ClientWebSocketResponse, client from aiohttp import ClientWebSocketResponse, client
@ -65,9 +65,15 @@ from .typing import (
MqttMockHAClient, MqttMockHAClient,
MqttMockHAClientGenerator, MqttMockHAClientGenerator,
MqttMockPahoClient, MqttMockPahoClient,
RecorderInstanceGenerator,
WebSocketGenerator, WebSocketGenerator,
) )
if TYPE_CHECKING:
# Local import to avoid processing recorder and SQLite modules when running a
# testcase which does not use the recorder.
from homeassistant.components import recorder
pytest.register_assert_rewrite("tests.common") pytest.register_assert_rewrite("tests.common")
from .common import ( # noqa: E402, isort:skip from .common import ( # noqa: E402, isort:skip
@ -75,7 +81,6 @@ from .common import ( # noqa: E402, isort:skip
INSTANCES, INSTANCES,
MockConfigEntry, MockConfigEntry,
MockUser, MockUser,
SetupRecorderInstanceT,
async_fire_mqtt_message, async_fire_mqtt_message,
async_test_home_assistant, async_test_home_assistant,
get_test_home_assistant, get_test_home_assistant,
@ -994,7 +999,7 @@ def enable_custom_integrations(hass):
@pytest.fixture @pytest.fixture
def enable_statistics(): def enable_statistics() -> bool:
"""Fixture to control enabling of recorder's statistics compilation. """Fixture to control enabling of recorder's statistics compilation.
To enable statistics, tests can be marked with: To enable statistics, tests can be marked with:
@ -1004,7 +1009,7 @@ def enable_statistics():
@pytest.fixture @pytest.fixture
def enable_statistics_table_validation(): def enable_statistics_table_validation() -> bool:
"""Fixture to control enabling of recorder's statistics table validation. """Fixture to control enabling of recorder's statistics table validation.
To enable statistics table validation, tests can be marked with: To enable statistics table validation, tests can be marked with:
@ -1014,7 +1019,7 @@ def enable_statistics_table_validation():
@pytest.fixture @pytest.fixture
def enable_nightly_purge(): def enable_nightly_purge() -> bool:
"""Fixture to control enabling of recorder's nightly purge job. """Fixture to control enabling of recorder's nightly purge job.
To enable nightly purging, tests can be marked with: To enable nightly purging, tests can be marked with:
@ -1024,7 +1029,7 @@ def enable_nightly_purge():
@pytest.fixture @pytest.fixture
def recorder_config(): def recorder_config() -> dict[str, Any] | None:
"""Fixture to override recorder config. """Fixture to override recorder config.
To override the config, tests can be marked with: To override the config, tests can be marked with:
@ -1035,14 +1040,15 @@ def recorder_config():
@pytest.fixture @pytest.fixture
def recorder_db_url( def recorder_db_url(
pytestconfig, pytestconfig: pytest.Config,
hass_fixture_setup, hass_fixture_setup: list[bool],
): ) -> Generator[str, None, None]:
"""Prepare a default database for tests and return a connection URL.""" """Prepare a default database for tests and return a connection URL."""
assert not hass_fixture_setup assert not hass_fixture_setup
db_url: str = pytestconfig.getoption("dburl") db_url = cast(str, pytestconfig.getoption("dburl"))
if db_url.startswith(("postgresql://", "mysql://")): if db_url.startswith(("postgresql://", "mysql://")):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy_utils import sqlalchemy_utils
def _ha_orm_quote(mixed, ident): def _ha_orm_quote(mixed, ident):
@ -1060,18 +1066,21 @@ def recorder_db_url(
sqlalchemy_utils.functions.database.quote = _ha_orm_quote sqlalchemy_utils.functions.database.quote = _ha_orm_quote
if db_url.startswith("mysql://"): if db_url.startswith("mysql://"):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy_utils import sqlalchemy_utils
charset = "utf8mb4' COLLATE = 'utf8mb4_unicode_ci" charset = "utf8mb4' COLLATE = 'utf8mb4_unicode_ci"
assert not sqlalchemy_utils.database_exists(db_url) assert not sqlalchemy_utils.database_exists(db_url)
sqlalchemy_utils.create_database(db_url, encoding=charset) sqlalchemy_utils.create_database(db_url, encoding=charset)
elif db_url.startswith("postgresql://"): elif db_url.startswith("postgresql://"):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy_utils import sqlalchemy_utils
assert not sqlalchemy_utils.database_exists(db_url) assert not sqlalchemy_utils.database_exists(db_url)
sqlalchemy_utils.create_database(db_url, encoding="utf8") sqlalchemy_utils.create_database(db_url, encoding="utf8")
yield db_url yield db_url
if db_url.startswith("mysql://"): if db_url.startswith("mysql://"):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy as sa import sqlalchemy as sa
made_url = sa.make_url(db_url) made_url = sa.make_url(db_url)
@ -1096,15 +1105,14 @@ def recorder_db_url(
@pytest.fixture @pytest.fixture
def hass_recorder( def hass_recorder(
recorder_db_url, recorder_db_url: str,
enable_nightly_purge, enable_nightly_purge: bool,
enable_statistics, enable_statistics: bool,
enable_statistics_table_validation, enable_statistics_table_validation: bool,
hass_storage, hass_storage,
): ) -> Generator[Callable[..., HomeAssistant], None, None]:
"""Home Assistant fixture with in-memory recorder.""" """Home Assistant fixture with in-memory recorder."""
# Local import to avoid processing recorder and SQLite modules when running a # pylint: disable-next=import-outside-toplevel
# testcase which does not use the recorder.
from homeassistant.components import recorder from homeassistant.components import recorder
original_tz = dt_util.DEFAULT_TIME_ZONE original_tz = dt_util.DEFAULT_TIME_ZONE
@ -1131,7 +1139,7 @@ def hass_recorder(
autospec=True, autospec=True,
): ):
def setup_recorder(config=None): def setup_recorder(config: dict[str, Any] | None = None) -> HomeAssistant:
"""Set up with params.""" """Set up with params."""
init_recorder_component(hass, config, recorder_db_url) init_recorder_component(hass, config, recorder_db_url)
hass.start() hass.start()
@ -1146,10 +1154,13 @@ def hass_recorder(
dt_util.DEFAULT_TIME_ZONE = original_tz dt_util.DEFAULT_TIME_ZONE = original_tz
async def _async_init_recorder_component(hass, add_config=None, db_url=None): async def _async_init_recorder_component(
hass: HomeAssistant,
add_config: dict[str, Any] | None = None,
db_url: str | None = None,
) -> None:
"""Initialize the recorder asynchronously.""" """Initialize the recorder asynchronously."""
# Local import to avoid processing recorder and SQLite modules when running a # pylint: disable-next=import-outside-toplevel
# testcase which does not use the recorder.
from homeassistant.components import recorder from homeassistant.components import recorder
config = dict(add_config) if add_config else {} config = dict(add_config) if add_config else {}
@ -1173,16 +1184,16 @@ async def _async_init_recorder_component(hass, add_config=None, db_url=None):
@pytest.fixture @pytest.fixture
async def async_setup_recorder_instance( async def async_setup_recorder_instance(
recorder_db_url, recorder_db_url: str,
enable_nightly_purge, enable_nightly_purge: bool,
enable_statistics, enable_statistics: bool,
enable_statistics_table_validation, enable_statistics_table_validation: bool,
) -> AsyncGenerator[SetupRecorderInstanceT, None]: ) -> AsyncGenerator[RecorderInstanceGenerator, None]:
"""Yield callable to setup recorder instance.""" """Yield callable to setup recorder instance."""
# Local import to avoid processing recorder and SQLite modules when running a # pylint: disable-next=import-outside-toplevel
# testcase which does not use the recorder.
from homeassistant.components import recorder from homeassistant.components import recorder
# pylint: disable-next=import-outside-toplevel
from .components.recorder.common import async_recorder_block_till_done from .components.recorder.common import async_recorder_block_till_done
nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
@ -1222,7 +1233,11 @@ async def async_setup_recorder_instance(
@pytest.fixture @pytest.fixture
async def recorder_mock(recorder_config, async_setup_recorder_instance, hass): async def recorder_mock(
recorder_config: dict[str, Any] | None,
async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant,
) -> recorder.Recorder:
"""Fixture with in-memory recorder.""" """Fixture with in-memory recorder."""
return await async_setup_recorder_instance(hass, recorder_config) return await async_setup_recorder_instance(hass, recorder_config)

View File

@ -2,12 +2,17 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from typing import Any from typing import TYPE_CHECKING, Any, TypeAlias
from unittest.mock import MagicMock from unittest.mock import MagicMock
from aiohttp import ClientWebSocketResponse from aiohttp import ClientWebSocketResponse
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
if TYPE_CHECKING:
# Local import to avoid processing recorder module when running a
# testcase which does not use the recorder.
from homeassistant.components.recorder import Recorder
ClientSessionGenerator = Callable[..., Coroutine[Any, Any, TestClient]] ClientSessionGenerator = Callable[..., Coroutine[Any, Any, TestClient]]
MqttMockPahoClient = MagicMock MqttMockPahoClient = MagicMock
"""MagicMock for `paho.mqtt.client.Client`""" """MagicMock for `paho.mqtt.client.Client`"""
@ -15,4 +20,6 @@ MqttMockHAClient = MagicMock
"""MagicMock for `homeassistant.components.mqtt.MQTT`.""" """MagicMock for `homeassistant.components.mqtt.MQTT`."""
MqttMockHAClientGenerator = Callable[..., Coroutine[Any, Any, MqttMockHAClient]] MqttMockHAClientGenerator = Callable[..., Coroutine[Any, Any, MqttMockHAClient]]
"""MagicMock generator for `homeassistant.components.mqtt.MQTT`.""" """MagicMock generator for `homeassistant.components.mqtt.MQTT`."""
RecorderInstanceGenerator: TypeAlias = Callable[..., Coroutine[Any, Any, "Recorder"]]
"""Instance generator for `homeassistant.components.recorder.Recorder`."""
WebSocketGenerator = Callable[..., Coroutine[Any, Any, ClientWebSocketResponse]] WebSocketGenerator = Callable[..., Coroutine[Any, Any, ClientWebSocketResponse]]