Add more type hints to conftest.py (#87842)

* Add more type hints in conftest.py

* Adjust stop_hass

* Adjust mock_integration_frame

* Adjust pylint plugin
This commit is contained in:
epenet 2023-02-11 13:48:53 +01:00 committed by GitHub
parent 6d87ebc7de
commit b7b82b1e3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 31 deletions

View File

@ -96,16 +96,36 @@ _TEST_FIXTURES: dict[str, list[str] | str] = {
"area_registry": "AreaRegistry", "area_registry": "AreaRegistry",
"async_setup_recorder_instance": "RecorderInstanceGenerator", "async_setup_recorder_instance": "RecorderInstanceGenerator",
"caplog": "pytest.LogCaptureFixture", "caplog": "pytest.LogCaptureFixture",
"current_request_with_host": "None",
"device_registry": "DeviceRegistry", "device_registry": "DeviceRegistry",
"enable_bluetooth": "None",
"enable_custom_integrations": "None",
"enable_nightly_purge": "bool", "enable_nightly_purge": "bool",
"enable_statistics": "bool", "enable_statistics": "bool",
"enable_statistics_table_validation": "bool", "enable_statistics_table_validation": "bool",
"entity_registry": "EntityRegistry", "entity_registry": "EntityRegistry",
"hass_access_token": "str",
"hass_admin_credential": "Credentials",
"hass_admin_user": "MockUser",
"hass_client": "ClientSessionGenerator", "hass_client": "ClientSessionGenerator",
"hass_client_no_auth": "ClientSessionGenerator", "hass_client_no_auth": "ClientSessionGenerator",
"hass_owner_user": "MockUser",
"hass_read_only_access_token": "str",
"hass_read_only_user": "MockUser",
"hass_recorder": "Callable[..., HomeAssistant]", "hass_recorder": "Callable[..., HomeAssistant]",
"hass_supervisor_access_token": "str",
"hass_supervisor_user": "MockUser",
"hass_ws_client": "WebSocketGenerator", "hass_ws_client": "WebSocketGenerator",
"issue_registry": "IssueRegistry", "issue_registry": "IssueRegistry",
"legacy_auth": "LegacyApiPasswordAuthProvider",
"local_auth": "HassAuthProvider",
"mock_async_zeroconf": "None",
"mock_bleak_scanner_start": "MagicMock",
"mock_bluetooth": "None",
"mock_bluetooth_adapters": "None",
"mock_device_tracker_conf": "list[Device]",
"mock_get_source_ip": "None",
"mock_zeroconf": "None",
"mqtt_client_mock": "MqttMockPahoClient", "mqtt_client_mock": "MqttMockPahoClient",
"mqtt_mock": "MqttMockHAClient", "mqtt_mock": "MqttMockHAClient",
"mqtt_mock_entry_no_yaml_config": "MqttMockHAClientGenerator", "mqtt_mock_entry_no_yaml_config": "MqttMockHAClientGenerator",
@ -113,6 +133,7 @@ _TEST_FIXTURES: dict[str, list[str] | str] = {
"recorder_db_url": "str", "recorder_db_url": "str",
"recorder_mock": "Recorder", "recorder_mock": "Recorder",
"requests_mock": "requests_mock.Mocker", "requests_mock": "requests_mock.Mocker",
"tmp_path": "Path",
} }
_TEST_FUNCTION_MATCH = TypeHintMatch( _TEST_FUNCTION_MATCH = TypeHintMatch(
function_name="test_*", function_name="test_*",

View File

@ -8,12 +8,11 @@ import datetime
import functools import functools
import gc import gc
import itertools import itertools
from json import JSONDecoder
import logging import logging
import sqlite3 import sqlite3
import ssl import ssl
import threading import threading
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
from aiohttp import client from aiohttp import client
@ -23,6 +22,7 @@ from aiohttp.test_utils import (
TestServer, TestServer,
make_mocked_request, make_mocked_request,
) )
from aiohttp.typedefs import JSONDecoder
from aiohttp.web import Application from aiohttp.web import Application
import freezegun import freezegun
import multidict import multidict
@ -101,13 +101,13 @@ asyncio.set_event_loop_policy(runner.HassEventLoopPolicy(False))
asyncio.set_event_loop_policy = lambda policy: None asyncio.set_event_loop_policy = lambda policy: None
def _utcnow(): def _utcnow() -> datetime.datetime:
"""Make utcnow patchable by freezegun.""" """Make utcnow patchable by freezegun."""
return datetime.datetime.now(datetime.timezone.utc) return datetime.datetime.now(datetime.timezone.utc)
dt_util.utcnow = _utcnow dt_util.utcnow = _utcnow # type: ignore[assignment]
event.time_tracker_utcnow = _utcnow event.time_tracker_utcnow = _utcnow # type: ignore[assignment]
def pytest_addoption(parser: pytest.Parser) -> None: def pytest_addoption(parser: pytest.Parser) -> None:
@ -143,8 +143,8 @@ def pytest_runtest_setup() -> None:
pytest_socket.socket_allow_hosts(["127.0.0.1"]) pytest_socket.socket_allow_hosts(["127.0.0.1"])
pytest_socket.disable_socket(allow_unix_socket=True) pytest_socket.disable_socket(allow_unix_socket=True)
freezegun.api.datetime_to_fakedatetime = ha_datetime_to_fakedatetime freezegun.api.datetime_to_fakedatetime = ha_datetime_to_fakedatetime # type: ignore[attr-defined]
freezegun.api.FakeDatetime = HAFakeDatetime freezegun.api.FakeDatetime = HAFakeDatetime # type: ignore[attr-defined]
def adapt_datetime(val): def adapt_datetime(val):
return val.isoformat(" ") return val.isoformat(" ")
@ -154,6 +154,7 @@ def pytest_runtest_setup() -> None:
# Setup HAFakeDatetime converter for pymysql # Setup HAFakeDatetime converter for pymysql
try: try:
# pylint: disable-next=import-outside-toplevel
import MySQLdb.converters as MySQLdb_converters import MySQLdb.converters as MySQLdb_converters
except ImportError: except ImportError:
pass pass
@ -163,12 +164,12 @@ def pytest_runtest_setup() -> None:
] = MySQLdb_converters.DateTime2literal ] = MySQLdb_converters.DateTime2literal
def ha_datetime_to_fakedatetime(datetime): def ha_datetime_to_fakedatetime(datetime) -> freezegun.api.FakeDatetime: # type: ignore[name-defined]
"""Convert datetime to FakeDatetime. """Convert datetime to FakeDatetime.
Modified to include https://github.com/spulec/freezegun/pull/424. Modified to include https://github.com/spulec/freezegun/pull/424.
""" """
return freezegun.api.FakeDatetime( return freezegun.api.FakeDatetime( # type: ignore[attr-defined]
datetime.year, datetime.year,
datetime.month, datetime.month,
datetime.day, datetime.day,
@ -181,7 +182,7 @@ def ha_datetime_to_fakedatetime(datetime):
) )
class HAFakeDatetime(freezegun.api.FakeDatetime): class HAFakeDatetime(freezegun.api.FakeDatetime): # type: ignore[name-defined]
"""Modified to include https://github.com/spulec/freezegun/pull/424.""" """Modified to include https://github.com/spulec/freezegun/pull/424."""
@classmethod @classmethod
@ -200,16 +201,20 @@ class HAFakeDatetime(freezegun.api.FakeDatetime):
return ha_datetime_to_fakedatetime(result) return ha_datetime_to_fakedatetime(result)
def check_real(func): _R = TypeVar("_R")
_P = ParamSpec("_P")
def check_real(func: Callable[_P, Coroutine[Any, Any, _R]]):
"""Force a function to require a keyword _test_real to be passed in.""" """Force a function to require a keyword _test_real to be passed in."""
@functools.wraps(func) @functools.wraps(func)
async def guard_func(*args, **kwargs): async def guard_func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
real = kwargs.pop("_test_real", None) real = kwargs.pop("_test_real", None)
if not real: if not real:
raise Exception( raise RuntimeError(
'Forgot to mock or pass "_test_real=True" to %s', func.__name__ f'Forgot to mock or pass "_test_real=True" to {func.__name__}'
) )
return await func(*args, **kwargs) return await func(*args, **kwargs)
@ -268,7 +273,7 @@ def verify_cleanup(
if tasks: if tasks:
event_loop.run_until_complete(asyncio.wait(tasks)) event_loop.run_until_complete(asyncio.wait(tasks))
for handle in event_loop._scheduled: for handle in event_loop._scheduled: # type: ignore[attr-defined]
if not handle.cancelled(): if not handle.cancelled():
_LOGGER.warning("Lingering timer after test %r", handle) _LOGGER.warning("Lingering timer after test %r", handle)
handle.cancel() handle.cancel()
@ -382,6 +387,7 @@ def aiohttp_client(
else: else:
assert not args, "args should be empty" assert not args, "args should be empty"
client: TestClient
if isinstance(__param, Application): if isinstance(__param, Application):
server_kwargs = server_kwargs or {} server_kwargs = server_kwargs or {}
server = TestServer(__param, loop=loop, **server_kwargs) server = TestServer(__param, loop=loop, **server_kwargs)
@ -441,7 +447,7 @@ def hass(
) )
orig_exception_handler(loop, context) orig_exception_handler(loop, context)
exceptions = [] exceptions: list[Exception] = []
hass = loop.run_until_complete(async_test_home_assistant(loop, load_registries)) hass = loop.run_until_complete(async_test_home_assistant(loop, load_registries))
ha._cv_hass.set(hass) ha._cv_hass.set(hass)
@ -692,7 +698,7 @@ def hass_client_no_auth(
@pytest.fixture @pytest.fixture
def current_request(): def current_request() -> Generator[MagicMock, None, None]:
"""Mock current request.""" """Mock current request."""
with patch("homeassistant.components.http.current_request") as mock_request_context: with patch("homeassistant.components.http.current_request") as mock_request_context:
mocked_request = make_mocked_request( mocked_request = make_mocked_request(
@ -706,7 +712,7 @@ def current_request():
@pytest.fixture @pytest.fixture
def current_request_with_host(current_request): def current_request_with_host(current_request: MagicMock) -> None:
"""Mock current request with a host header.""" """Mock current request with a host header."""
new_headers = multidict.CIMultiDict(current_request.get.return_value.headers) new_headers = multidict.CIMultiDict(current_request.get.return_value.headers)
new_headers[config_entry_oauth2_flow.HEADER_FRONTEND_BASE] = "https://example.com" new_headers[config_entry_oauth2_flow.HEADER_FRONTEND_BASE] = "https://example.com"
@ -954,7 +960,7 @@ async def mqtt_mock_entry_with_yaml_config(
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_network(): def mock_network() -> Generator[None, None, None]:
"""Mock network.""" """Mock network."""
mock_adapter = Adapter( mock_adapter = Adapter(
name="eth0", name="eth0",
@ -973,7 +979,7 @@ def mock_network():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_get_source_ip(): def mock_get_source_ip() -> Generator[None, None, None]:
"""Mock network util's async_get_source_ip.""" """Mock network util's async_get_source_ip."""
with patch( with patch(
"homeassistant.components.network.util.async_get_source_ip", "homeassistant.components.network.util.async_get_source_ip",
@ -983,7 +989,7 @@ def mock_get_source_ip():
@pytest.fixture @pytest.fixture
def mock_zeroconf(): def mock_zeroconf() -> Generator[None, None, None]:
"""Mock zeroconf.""" """Mock zeroconf."""
with patch("homeassistant.components.zeroconf.HaZeroconf", autospec=True), patch( with patch("homeassistant.components.zeroconf.HaZeroconf", autospec=True), patch(
"homeassistant.components.zeroconf.HaAsyncServiceBrowser", autospec=True "homeassistant.components.zeroconf.HaAsyncServiceBrowser", autospec=True
@ -992,7 +998,7 @@ def mock_zeroconf():
@pytest.fixture @pytest.fixture
def mock_async_zeroconf(mock_zeroconf): def mock_async_zeroconf(mock_zeroconf: None) -> Generator[None, None, None]:
"""Mock AsyncZeroconf.""" """Mock AsyncZeroconf."""
with patch("homeassistant.components.zeroconf.HaAsyncZeroconf") as mock_aiozc: with patch("homeassistant.components.zeroconf.HaAsyncZeroconf") as mock_aiozc:
zc = mock_aiozc.return_value zc = mock_aiozc.return_value
@ -1007,7 +1013,7 @@ def mock_async_zeroconf(mock_zeroconf):
@pytest.fixture @pytest.fixture
def enable_custom_integrations(hass): def enable_custom_integrations(hass: HomeAssistant) -> None:
"""Enable custom integrations defined in the test dir.""" """Enable custom integrations defined in the test dir."""
hass.data.pop(loader.DATA_CUSTOM_COMPONENTS) hass.data.pop(loader.DATA_CUSTOM_COMPONENTS)
@ -1334,7 +1340,7 @@ def mock_bleak_scanner_start() -> Generator[MagicMock, None, None]:
# We need to drop the stop method from the object since we patched # We need to drop the stop method from the object since we patched
# out start and this fixture will expire before the stop method is called # out start and this fixture will expire before the stop method is called
# when EVENT_HOMEASSISTANT_STOP is fired. # when EVENT_HOMEASSISTANT_STOP is fired.
bluetooth_scanner.OriginalBleakScanner.stop = AsyncMock() bluetooth_scanner.OriginalBleakScanner.stop = AsyncMock() # type: ignore[assignment]
with patch( with patch(
"homeassistant.components.bluetooth.scanner.OriginalBleakScanner.start", "homeassistant.components.bluetooth.scanner.OriginalBleakScanner.start",
) as mock_bleak_scanner_start: ) as mock_bleak_scanner_start:
@ -1343,7 +1349,7 @@ def mock_bleak_scanner_start() -> Generator[MagicMock, None, None]:
@pytest.fixture @pytest.fixture
def mock_bluetooth( def mock_bluetooth(
mock_bleak_scanner_start: MagicMock, mock_bluetooth_adapters mock_bleak_scanner_start: MagicMock, mock_bluetooth_adapters: None
) -> None: ) -> None:
"""Mock out bluetooth from starting.""" """Mock out bluetooth from starting."""

View File

@ -7,7 +7,9 @@ import pytest
from homeassistant.helpers import frame from homeassistant.helpers import frame
async def test_extract_frame_integration(caplog, mock_integration_frame): async def test_extract_frame_integration(
caplog: pytest.LogCaptureFixture, mock_integration_frame: Mock
) -> None:
"""Test extracting the current frame from integration context.""" """Test extracting the current frame from integration context."""
found_frame, integration, path = frame.get_integration_frame() found_frame, integration, path = frame.get_integration_frame()

View File

@ -1,7 +1,7 @@
"""Test state helpers.""" """Test state helpers."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest import pytest
@ -25,7 +25,9 @@ from homeassistant.util import dt as dt_util
from tests.common import async_mock_service from tests.common import async_mock_service
async def test_async_track_states(hass, mock_integration_frame): async def test_async_track_states(
hass: HomeAssistant, mock_integration_frame: Mock
) -> None:
"""Test AsyncTrackStates context manager.""" """Test AsyncTrackStates context manager."""
point1 = dt_util.utcnow() point1 = dt_util.utcnow()
point2 = point1 + timedelta(seconds=5) point2 = point1 + timedelta(seconds=5)
@ -82,7 +84,9 @@ async def test_call_to_component(hass: HomeAssistant) -> None:
) )
async def test_get_changed_since(hass, mock_integration_frame): async def test_get_changed_since(
hass: HomeAssistant, mock_integration_frame: Mock
) -> None:
"""Test get_changed_since.""" """Test get_changed_since."""
point1 = dt_util.utcnow() point1 = dt_util.utcnow()
point2 = point1 + timedelta(seconds=5) point2 = point1 + timedelta(seconds=5)

View File

@ -23,7 +23,7 @@ BAD_CORE_CONFIG = "homeassistant:\n unit_system: bad\n\n\n"
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def apply_stop_hass(stop_hass): async def apply_stop_hass(stop_hass: None) -> None:
"""Make sure all hass are stopped.""" """Make sure all hass are stopped."""

View File

@ -32,7 +32,7 @@ def apply_mock_storage(hass_storage):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def apply_stop_hass(stop_hass): async def apply_stop_hass(stop_hass: None) -> None:
"""Make sure all hass are stopped.""" """Make sure all hass are stopped."""