Compare commits

...

37 Commits

Author SHA1 Message Date
Erik
b590cd1e6c Update clear_exception_traceback fixture 2025-05-08 09:08:46 +02:00
Erik
25cafd7b43 Adjust unifiprotect tests 2025-05-08 09:06:22 +02:00
Erik
9a4bcd88db Adjust tessie tests 2025-05-08 09:05:47 +02:00
Erik
8a7123b880 Adjust tessie tests 2025-05-08 09:04:14 +02:00
Erik
8e2011a100 Reset esphome DomainData cache inbetween tests 2025-05-08 09:02:46 +02:00
Erik
26d48e20dd Update clear_exception_traceback fixture 2025-05-07 17:57:05 +02:00
Erik
7a340fb676 Update clear_exception_traceback fixture 2025-05-07 17:14:33 +02:00
Erik
c300e1e376 Reset more caches in 2025-05-07 17:13:55 +02:00
Erik
8aac4777b1 Adjust nest tests 2025-05-07 17:13:05 +02:00
Erik
fdcaf2897a Adjust imap tests 2025-05-07 17:12:24 +02:00
Erik
473b77279f Adjust analytics tests 2025-05-07 17:11:50 +02:00
Erik
1351304343 Fix patching in network helper tests 2025-05-07 14:22:47 +02:00
Erik
980b3023e9 Remove reuse of exception object from bang_olufsen tests 2025-05-07 14:20:15 +02:00
Erik
58fa6a06a7 Update clear_exception_traceback fixture 2025-05-07 14:19:14 +02:00
Erik
9db63ca774 Revert patching of httpx mocker 2025-05-07 14:18:48 +02:00
Erik
c0d867d0c4 Update clear_exception_traceback fixture 2025-05-07 09:29:59 +02:00
Erik
61d64d2d59 Reset exceptions in update_coordinator tests 2025-05-07 09:17:02 +02:00
Erik
402fb8e53a Update clear_exception_traceback fixture 2025-05-07 09:12:42 +02:00
Erik
3be9553508 Update clear_exception_traceback fixture 2025-05-07 08:46:56 +02:00
Erik
8cd72586ec Reset cache in emulated_hue 2025-05-07 08:41:44 +02:00
Erik
751f97a462 Modify require_admin decorator 2025-05-06 14:42:18 +02:00
Erik
12909c1877 Update clear_exception_traceback fixture 2025-05-06 14:42:18 +02:00
Erik
9e01c14b16 Update clear_exception_traceback fixture 2025-05-06 14:42:18 +02:00
Erik
956fbce7d8 Fix httpx monkey patch 2025-05-06 14:42:18 +02:00
Erik
3b86a1a2b6 Clear exception traceback after each test 2025-05-06 14:42:18 +02:00
Erik
a4bd6754df Reduce scope of mock_network fixture 2025-05-06 14:42:18 +02:00
Erik
7faf4bfd72 Patch httpx mocker 2025-05-06 14:42:18 +02:00
Erik
ea92047502 Make name a non cached property 2025-05-06 14:42:18 +02:00
Erik
2e74a2ad28 Disable cached_property in entity helper 2025-05-06 14:42:18 +02:00
Erik
ab5f20aa69 Tear down evict_faked_translations before counting hass objects 2025-05-06 14:42:18 +02:00
Erik
156ce39202 Format log records early 2025-05-06 14:42:17 +02:00
Erik
09cb358015 Reset template cache 2025-05-06 14:42:17 +02:00
Erik
8beddd2481 Reset template state cache 2025-05-06 14:42:17 +02:00
Erik
4d5e809e9b Tweak 2025-05-06 14:42:17 +02:00
Erik
71af693569 Reset aiohttp route cache 2025-05-06 14:42:17 +02:00
Erik
2396fe2245 Tweak singleton helper 2025-05-06 14:42:17 +02:00
Erik
aa19dfacfc Fail tests which leak hass instances 2025-05-06 14:42:17 +02:00
32 changed files with 440 additions and 231 deletions

View File

@@ -165,9 +165,7 @@ class ConfigManagerFlowIndexView(
"""Not implemented."""
raise aiohttp.web_exceptions.HTTPMethodNotAllowed("GET", ["POST"])
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission="add")
@RequestDataValidator(
vol.Schema(
{
@@ -218,16 +216,12 @@ class ConfigManagerFlowResourceView(
url = "/api/config/config_entries/flow/{flow_id}"
name = "api:config:config_entries:flow:resource"
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission="add")
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow."""
return await super().get(request, flow_id)
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add")
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission="add")
async def post(self, request: web.Request, flow_id: str) -> web.Response:
"""Handle a POST request."""
return await super().post(request, flow_id)
@@ -262,9 +256,7 @@ class OptionManagerFlowIndexView(
url = "/api/config/config_entries/options/flow"
name = "api:config:config_entries:option:flow"
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
async def post(self, request: web.Request) -> web.Response:
"""Handle a POST request.
@@ -281,16 +273,12 @@ class OptionManagerFlowResourceView(
url = "/api/config/config_entries/options/flow/{flow_id}"
name = "api:config:config_entries:options:flow:resource"
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow."""
return await super().get(request, flow_id)
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
async def post(self, request: web.Request, flow_id: str) -> web.Response:
"""Handle a POST request."""
return await super().post(request, flow_id)
@@ -304,9 +292,7 @@ class SubentryManagerFlowIndexView(
url = "/api/config/config_entries/subentries/flow"
name = "api:config:config_entries:subentries:flow"
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
@RequestDataValidator(
vol.Schema(
{
@@ -341,16 +327,12 @@ class SubentryManagerFlowResourceView(
url = "/api/config/config_entries/subentries/flow/{flow_id}"
name = "api:config:config_entries:subentries:flow:resource"
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow."""
return await super().get(request, flow_id)
@require_admin(
error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
)
@require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
async def post(self, request: web.Request, flow_id: str) -> web.Response:
"""Handle a POST request."""
return await super().post(request, flow_id)

View File

@@ -27,7 +27,8 @@ def require_admin[
](
_func: None = None,
*,
error: Unauthorized | None = None,
perm_category: str | None = None,
permission: str | None = None,
) -> Callable[
[_FuncType[_HomeAssistantViewT, _P, _ResponseT]],
_FuncType[_HomeAssistantViewT, _P, _ResponseT],
@@ -51,7 +52,8 @@ def require_admin[
](
_func: _FuncType[_HomeAssistantViewT, _P, _ResponseT] | None = None,
*,
error: Unauthorized | None = None,
perm_category: str | None = None,
permission: str | None = None,
) -> (
Callable[
[_FuncType[_HomeAssistantViewT, _P, _ResponseT]],
@@ -76,7 +78,7 @@ def require_admin[
"""Check admin and call function."""
user: User = request["hass_user"]
if not user.is_admin:
raise error or Unauthorized()
raise Unauthorized(perm_category=perm_category, permission=permission)
return await func(self, request, *args, **kwargs)

View File

@@ -14,7 +14,6 @@ from homeassistant.components import websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.components.http.decorators import require_admin
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized
from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.data_entry_flow import (
FlowManagerIndexView,
@@ -114,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
url = "/api/repairs/issues/fix"
name = "api:repairs:issues:fix"
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
@require_admin(permission=POLICY_EDIT)
@RequestDataValidator(
vol.Schema(
{
@@ -149,12 +148,12 @@ class RepairsFlowResourceView(FlowManagerResourceView):
url = "/api/repairs/issues/fix/{flow_id}"
name = "api:repairs:issues:fix:resource"
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
@require_admin(permission=POLICY_EDIT)
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow."""
return await super().get(request, flow_id)
@require_admin(error=Unauthorized(permission=POLICY_EDIT))
@require_admin(permission=POLICY_EDIT)
async def post(self, request: web.Request, flow_id: str) -> web.Response:
"""Handle a POST request."""
return await super().post(request, flow_id)

View File

@@ -41,6 +41,7 @@ from typing import (
final,
overload,
)
import weakref
from propcache.api import cached_property, under_cached_property
import voluptuous as vol
@@ -409,6 +410,9 @@ class CoreState(enum.Enum):
return self.value
hass_instances: list[weakref.ref[HomeAssistant]] = []
class HomeAssistant:
"""Root object of the Home Assistant home automation."""
@@ -419,6 +423,7 @@ class HomeAssistant:
def __new__(cls, config_dir: str) -> Self:
"""Set the _hass thread local data."""
hass = super().__new__(cls)
hass_instances.append(weakref.ref(hass))
_hass.hass = hass
return hass

View File

@@ -413,7 +413,6 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
"extra_state_attributes",
"force_update",
"icon",
"name",
"should_poll",
"state",
"supported_features",
@@ -730,7 +729,7 @@ class Entity(
name = self.name
return None if name is UNDEFINED else name
@cached_property
@property
def name(self) -> str | UndefinedType | None:
"""Return the name of the entity."""
# The check for self.platform guards against integrations not using an

View File

@@ -49,7 +49,6 @@ def singleton[_S, _T, _U](
"""Wrap a function with caching logic."""
if not asyncio.iscoroutinefunction(func):
@functools.lru_cache(maxsize=1)
@bind_hass
@functools.wraps(func)
def wrapped(hass: HomeAssistant) -> _U:

View File

@@ -642,7 +642,8 @@ def mock_registry(
registry.entities[key] = entry
hass.data[er.DATA_REGISTRY] = registry
er.async_get.cache_clear()
with suppress(AttributeError):
er.async_get.cache_clear()
return registry
@@ -694,7 +695,8 @@ def mock_area_registry(
registry.areas[key] = entry
hass.data[ar.DATA_REGISTRY] = registry
ar.async_get.cache_clear()
with suppress(AttributeError):
ar.async_get.cache_clear()
return registry
@@ -723,7 +725,8 @@ def mock_device_registry(
registry.deleted_devices = dr.DeviceRegistryItems()
hass.data[dr.DATA_REGISTRY] = registry
dr.async_get.cache_clear()
with suppress(AttributeError):
dr.async_get.cache_clear()
return registry
@@ -1307,7 +1310,8 @@ def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None:
_LOGGER.debug("Restore cache: %s", data.last_states)
assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
rs.async_get.cache_clear()
with suppress(AttributeError):
rs.async_get.cache_clear()
hass.data[key] = data
@@ -1335,7 +1339,8 @@ def mock_restore_cache_with_extra_data(
_LOGGER.debug("Restore cache: %s", data.last_states)
assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
rs.async_get.cache_clear()
with suppress(AttributeError):
rs.async_get.cache_clear()
hass.data[key] = data

View File

@@ -178,10 +178,10 @@ async def test_send_base(
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -231,10 +231,10 @@ async def test_send_base_with_supervisor(
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -266,10 +266,10 @@ async def test_send_usage(
in caplog.text
)
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -328,10 +328,10 @@ async def test_send_usage_with_supervisor(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -356,10 +356,10 @@ async def test_send_statistics(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -419,10 +419,10 @@ async def test_send_statistics_disabled_integration(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -464,10 +464,10 @@ async def test_send_statistics_ignored_integration(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -547,10 +547,10 @@ async def test_send_statistics_with_supervisor(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -594,10 +594,10 @@ async def test_custom_integrations(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -693,11 +693,11 @@ async def test_send_with_no_energy(
get_recorder_instance.return_value = Mock(database_engine=Mock())
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert "energy" not in submitted_data
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -723,11 +723,11 @@ async def test_send_with_no_energy_config(
energy_is_configured.return_value = False
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data["energy"]["configured"] is False
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert (
snapshot(matcher=path_type({"recorder.version": (AwesomeVersion,)}))
== submitted_data
@@ -756,11 +756,11 @@ async def test_send_with_energy_config(
energy_is_configured.return_value = True
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data["energy"]["configured"] is True
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert (
snapshot(matcher=path_type({"recorder.version": (AwesomeVersion,)}))
== submitted_data
@@ -787,11 +787,11 @@ async def test_send_usage_with_certificate(
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data["certificate"] is True
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data
@@ -815,11 +815,11 @@ async def test_send_with_recorder(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data["recorder"]["engine"] == "sqlite"
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert (
snapshot(matcher=path_type({"recorder.version": (AwesomeVersion,)}))
== submitted_data
@@ -913,10 +913,10 @@ async def test_not_check_config_entries_if_yaml(
):
await analytics.send_analytics()
logged_data = caplog.records[-1].args
logged_data = caplog.records[-1].getMessage()
submitted_data = _last_call_payload(aioclient_mock)
assert submitted_data["integration_count"] == 1
assert submitted_data["integrations"] == ["default_config"]
assert submitted_data == logged_data
assert logged_data.endswith(str(submitted_data))
assert snapshot == submitted_data

View File

@@ -1,9 +1,7 @@
"""Constants used for testing the bang_olufsen integration."""
from ipaddress import IPv4Address, IPv6Address
from unittest.mock import Mock
from mozart_api.exceptions import ApiException
from mozart_api.models import (
Action,
ListeningModeRef,
@@ -200,16 +198,6 @@ TEST_DEEZER_TRACK = PlayQueueItem(
uri="1234567890",
)
# codespell can't see the escaped ', so it thinks the word is misspelled
TEST_DEEZER_INVALID_FLOW = ApiException(
status=400,
reason="Bad Request",
http_resp=Mock(
status=400,
reason="Bad Request",
data='{"message": "Couldn\'t start user flow for me"}', # codespell:ignore
),
)
TEST_SOUND_MODE = 123
TEST_SOUND_MODE_2 = 234
TEST_SOUND_MODE_NAME = "Test Listening Mode"

View File

@@ -2,9 +2,9 @@
from contextlib import AbstractContextManager, nullcontext as does_not_raise
import logging
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
from mozart_api.exceptions import NotFoundException
from mozart_api.exceptions import ApiException, NotFoundException
from mozart_api.models import (
BeolinkLeader,
BeolinkSelf,
@@ -81,7 +81,6 @@ from .const import (
TEST_ACTIVE_SOUND_MODE_NAME_2,
TEST_AUDIO_SOURCES,
TEST_DEEZER_FLOW,
TEST_DEEZER_INVALID_FLOW,
TEST_DEEZER_PLAYLIST,
TEST_DEEZER_TRACK,
TEST_FALLBACK_SOURCES,
@@ -1249,7 +1248,16 @@ async def test_async_play_media_invalid_deezer(
) -> None:
"""Test async_play_media with an invalid/no Deezer login."""
mock_mozart_client.start_deezer_flow.side_effect = TEST_DEEZER_INVALID_FLOW
# codespell can't see the escaped ', so it thinks the word is misspelled
mock_mozart_client.start_deezer_flow.side_effect = ApiException(
status=400,
reason="Bad Request",
http_resp=Mock(
status=400,
reason="Bad Request",
data='{"message": "Couldn\'t start user flow for me"}', # codespell:ignore
),
)
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)

View File

@@ -2,7 +2,15 @@
import pytest
from homeassistant.components.emulated_hue.config import Config
@pytest.fixture(autouse=True, name="stub_blueprint_populate")
def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None:
"""Stub copying the blueprints to the config folder."""
@pytest.fixture(autouse=True)
def reset_config_cache() -> None:
"""Reset config cache."""
Config.entity_id_to_number.cache_clear()

View File

@@ -27,7 +27,7 @@ from aioesphomeapi import (
import pytest
from zeroconf import Zeroconf
from homeassistant.components.esphome import dashboard
from homeassistant.components.esphome import dashboard, domain_data
from homeassistant.components.esphome.const import (
CONF_ALLOW_SERVICE_CALLS,
CONF_BLUETOOTH_MAC_ADDRESS,
@@ -112,6 +112,12 @@ def mock_tts(mock_tts_cache_dir: Path) -> None:
"""Auto mock the tts cache."""
@pytest.fixture(autouse=True)
def reset_domain_data_cache() -> None:
"""Reset the DomainData cache."""
domain_data.DomainData.get.cache_clear()
@pytest.fixture
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"""Return the default mocked config entry."""

View File

@@ -1,6 +1,9 @@
"""Test the imap entry initialization."""
from __future__ import annotations
import asyncio
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
@@ -422,8 +425,8 @@ async def test_late_folder_error(
@pytest.mark.parametrize(
"imap_close",
[
AsyncMock(side_effect=AioImapException("Something went wrong")),
AsyncMock(side_effect=TimeoutError),
lambda: AsyncMock(side_effect=AioImapException("Something went wrong")),
lambda: AsyncMock(side_effect=TimeoutError),
],
ids=["AioImapException", "TimeoutError"],
)
@@ -431,7 +434,7 @@ async def test_handle_cleanup_exception(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mock_imap_protocol: MagicMock,
imap_close: Exception,
imap_close: Callable[[], AsyncMock],
) -> None:
"""Test handling an excepton during cleaning up."""
config_entry = MockConfigEntry(domain=DOMAIN, data=MOCK_CONFIG)
@@ -448,7 +451,7 @@ async def test_handle_cleanup_exception(
assert state.state == "0"
# Fail cleaning up
mock_imap_protocol.close.side_effect = imap_close
mock_imap_protocol.close.side_effect = imap_close()
assert await hass.config_entries.async_unload(config_entry.entry_id)
await hass.async_block_till_done()

View File

@@ -8,7 +8,7 @@ mode (e.g. yaml, ConfigEntry, etc) however some tests override and just run in
relevant modes.
"""
from collections.abc import Generator
from collections.abc import Callable, Generator
import datetime
from http import HTTPStatus
import logging
@@ -146,23 +146,31 @@ async def test_setup_device_manager_failure(
@pytest.mark.parametrize("token_expiration_time", [EXPIRED_TOKEN_TIMESTAMP])
@pytest.mark.parametrize(
("token_response_args", "expected_state", "expected_steps"),
("token_response_args", "token_response_exc", "expected_state", "expected_steps"),
[
# Cases that retry integration setup
(
{"status": HTTPStatus.INTERNAL_SERVER_ERROR},
lambda: None,
ConfigEntryState.SETUP_RETRY,
[],
),
(
{},
lambda: aiohttp.ClientError("No internet"),
ConfigEntryState.SETUP_RETRY,
[],
),
({"exc": aiohttp.ClientError("No internet")}, ConfigEntryState.SETUP_RETRY, []),
# Cases that require the user to reauthenticate in a config flow
(
{"status": HTTPStatus.BAD_REQUEST},
lambda: None,
ConfigEntryState.SETUP_ERROR,
["reauth_confirm"],
),
(
{"status": HTTPStatus.FORBIDDEN},
lambda: None,
ConfigEntryState.SETUP_ERROR,
["reauth_confirm"],
),
@@ -173,6 +181,7 @@ async def test_expired_token_refresh_error(
setup_base_platform: PlatformSetup,
aioclient_mock: AiohttpClientMocker,
token_response_args: dict,
token_response_exc: Callable[[], Exception | None],
expected_state: ConfigEntryState,
expected_steps: list[str],
) -> None:
@@ -180,6 +189,7 @@ async def test_expired_token_refresh_error(
aioclient_mock.post(
OAUTH2_TOKEN,
exc=token_response_exc(),
**token_response_args,
)

View File

@@ -32,21 +32,41 @@ TEST_REQUEST_INFO = RequestInfo(
url=TESSIE_URL, method="GET", headers={}, real_url=TESSIE_URL
)
ERROR_AUTH = ClientResponseError(
request_info=TEST_REQUEST_INFO, history=None, status=HTTPStatus.UNAUTHORIZED
)
ERROR_TIMEOUT = ClientResponseError(
request_info=TEST_REQUEST_INFO, history=None, status=HTTPStatus.REQUEST_TIMEOUT
)
ERROR_UNKNOWN = ClientResponseError(
request_info=TEST_REQUEST_INFO, history=None, status=HTTPStatus.BAD_REQUEST
)
ERROR_VIRTUAL_KEY = ClientResponseError(
request_info=TEST_REQUEST_INFO,
history=None,
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
ERROR_CONNECTION = ClientConnectionError()
def error_auth() -> ClientResponseError:
"""Return an error."""
return ClientResponseError(
request_info=TEST_REQUEST_INFO, history=None, status=HTTPStatus.UNAUTHORIZED
)
def error_timeout() -> ClientResponseError:
"""Return an error."""
return ClientResponseError(
request_info=TEST_REQUEST_INFO, history=None, status=HTTPStatus.REQUEST_TIMEOUT
)
def error_unknown() -> ClientResponseError:
"""Return an error."""
return ClientResponseError(
request_info=TEST_REQUEST_INFO, history=None, status=HTTPStatus.BAD_REQUEST
)
def error_virtual_key() -> ClientResponseError:
"""Return an error."""
return ClientResponseError(
request_info=TEST_REQUEST_INFO,
history=None,
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
def error_connection() -> ClientResponseError:
"""Return an error."""
return ClientConnectionError()
# Fleet API library
PRODUCTS = load_json_object_fixture("products.json", DOMAIN)

View File

@@ -22,7 +22,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er
from .common import ERROR_UNKNOWN, TEST_RESPONSE, assert_entities, setup_platform
from .common import TEST_RESPONSE, assert_entities, error_unknown, setup_platform
async def test_climate(
@@ -115,10 +115,11 @@ async def test_errors(hass: HomeAssistant) -> None:
entity_id = "climate.test_climate"
# Test setting climate on with unknown error
exc = error_unknown()
with (
patch(
"homeassistant.components.tessie.climate.stop_climate",
side_effect=ERROR_UNKNOWN,
side_effect=exc,
) as mock_set,
pytest.raises(HomeAssistantError) as error,
):
@@ -129,4 +130,4 @@ async def test_errors(hass: HomeAssistant) -> None:
blocking=True,
)
mock_set.assert_called_once()
assert error.value.__cause__ == ERROR_UNKNOWN
assert error.value.__cause__ == exc

View File

@@ -11,11 +11,11 @@ from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from .common import (
ERROR_AUTH,
ERROR_CONNECTION,
ERROR_UNKNOWN,
TEST_CONFIG,
TEST_STATE_OF_ALL_VEHICLES,
error_auth,
error_connection,
error_unknown,
)
from tests.common import MockConfigEntry
@@ -97,9 +97,9 @@ async def test_abort(
@pytest.mark.parametrize(
("side_effect", "error"),
[
(ERROR_AUTH, {CONF_ACCESS_TOKEN: "invalid_access_token"}),
(ERROR_UNKNOWN, {"base": "unknown"}),
(ERROR_CONNECTION, {"base": "cannot_connect"}),
(error_auth(), {CONF_ACCESS_TOKEN: "invalid_access_token"}),
(error_unknown(), {"base": "unknown"}),
(error_connection(), {"base": "cannot_connect"}),
],
)
async def test_form_errors(
@@ -165,9 +165,9 @@ async def test_reauth(
@pytest.mark.parametrize(
("side_effect", "error"),
[
(ERROR_AUTH, {CONF_ACCESS_TOKEN: "invalid_access_token"}),
(ERROR_UNKNOWN, {"base": "unknown"}),
(ERROR_CONNECTION, {"base": "cannot_connect"}),
(error_auth(), {CONF_ACCESS_TOKEN: "invalid_access_token"}),
(error_unknown(), {"base": "unknown"}),
(error_connection(), {"base": "cannot_connect"}),
],
)
async def test_reauth_errors(

View File

@@ -15,10 +15,10 @@ from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, Platform
from homeassistant.core import HomeAssistant
from .common import (
ERROR_AUTH,
ERROR_CONNECTION,
ERROR_UNKNOWN,
TEST_VEHICLE_STATUS_ASLEEP,
error_auth,
error_connection,
error_unknown,
setup_platform,
)
@@ -62,7 +62,7 @@ async def test_coordinator_clienterror(
) -> None:
"""Tests that the coordinator handles client errors."""
mock_get_status.side_effect = ERROR_UNKNOWN
mock_get_status.side_effect = error_unknown()
await setup_platform(hass, [Platform.BINARY_SENSOR])
freezer.tick(WAIT)
@@ -77,7 +77,7 @@ async def test_coordinator_auth(
) -> None:
"""Tests that the coordinator handles auth errors."""
mock_get_status.side_effect = ERROR_AUTH
mock_get_status.side_effect = error_auth()
await setup_platform(hass, [Platform.BINARY_SENSOR])
freezer.tick(WAIT)
@@ -91,7 +91,7 @@ async def test_coordinator_connection(
) -> None:
"""Tests that the coordinator handles connection errors."""
mock_get_status.side_effect = ERROR_CONNECTION
mock_get_status.side_effect = error_connection()
await setup_platform(hass, [Platform.BINARY_SENSOR])
freezer.tick(WAIT)
async_fire_time_changed(hass)

View File

@@ -17,10 +17,10 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er
from .common import (
ERROR_UNKNOWN,
TEST_RESPONSE,
TEST_RESPONSE_ERROR,
assert_entities,
error_unknown,
setup_platform,
)
@@ -81,10 +81,11 @@ async def test_errors(hass: HomeAssistant) -> None:
entity_id = "cover.test_charge_port_door"
# Test setting cover open with unknown error
exc = error_unknown()
with (
patch(
"homeassistant.components.tessie.cover.open_unlock_charge_port",
side_effect=ERROR_UNKNOWN,
side_effect=exc,
) as mock_set,
pytest.raises(HomeAssistantError) as error,
):
@@ -95,7 +96,7 @@ async def test_errors(hass: HomeAssistant) -> None:
blocking=True,
)
mock_set.assert_called_once()
assert error.value.__cause__ == ERROR_UNKNOWN
assert error.value.__cause__ == exc
# Test setting cover open with unknown error
with (

View File

@@ -7,7 +7,7 @@ from tesla_fleet_api.exceptions import TeslaFleetError
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from .common import ERROR_AUTH, ERROR_CONNECTION, ERROR_UNKNOWN, setup_platform
from .common import error_auth, error_connection, error_unknown, setup_platform
async def test_load_unload(hass: HomeAssistant) -> None:
@@ -25,7 +25,7 @@ async def test_auth_failure(
) -> None:
"""Test init with an authentication error."""
mock_get_state_of_all_vehicles.side_effect = ERROR_AUTH
mock_get_state_of_all_vehicles.side_effect = error_auth()
entry = await setup_platform(hass)
assert entry.state is ConfigEntryState.SETUP_ERROR
@@ -35,7 +35,7 @@ async def test_unknown_failure(
) -> None:
"""Test init with an client response error."""
mock_get_state_of_all_vehicles.side_effect = ERROR_UNKNOWN
mock_get_state_of_all_vehicles.side_effect = error_unknown()
entry = await setup_platform(hass)
assert entry.state is ConfigEntryState.SETUP_ERROR
@@ -45,7 +45,7 @@ async def test_connection_failure(
) -> None:
"""Test init with a network connection error."""
mock_get_state_of_all_vehicles.side_effect = ERROR_CONNECTION
mock_get_state_of_all_vehicles.side_effect = error_connection()
entry = await setup_platform(hass)
assert entry.state is ConfigEntryState.SETUP_RETRY

View File

@@ -20,7 +20,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er
from .common import ERROR_UNKNOWN, TEST_RESPONSE, assert_entities, setup_platform
from .common import TEST_RESPONSE, assert_entities, error_unknown, setup_platform
async def test_select(
@@ -107,10 +107,11 @@ async def test_errors(hass: HomeAssistant) -> None:
await setup_platform(hass, [Platform.SELECT])
# Test changing vehicle select with unknown error
exc = error_unknown()
with (
patch(
"homeassistant.components.tessie.select.set_seat_heat",
side_effect=ERROR_UNKNOWN,
side_effect=exc,
) as mock_set,
pytest.raises(HomeAssistantError) as error,
):
@@ -124,7 +125,7 @@ async def test_errors(hass: HomeAssistant) -> None:
blocking=True,
)
mock_set.assert_called_once()
assert error.value.__cause__ == ERROR_UNKNOWN
assert error.value.__cause__ == exc
# Test changing energy select with unknown error
with (

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Callable, Generator
from http import HTTPStatus
from pathlib import Path
from typing import Any
@@ -207,7 +207,7 @@ class MockTTS(MockPlatform):
async def mock_setup(
hass: HomeAssistant,
mock_provider: MockTTSProvider,
mock_provider: Callable[[], MockTTSProvider],
) -> None:
"""Set up a test provider."""
mock_integration(hass, MockModule(domain=TEST_DOMAIN))

View File

@@ -3,7 +3,7 @@
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
"""
from collections.abc import Generator, Iterable
from collections.abc import Callable, Generator, Iterable
from contextlib import ExitStack
from pathlib import Path
from unittest.mock import MagicMock
@@ -60,22 +60,24 @@ async def internal_url_mock(hass: HomeAssistant) -> None:
@pytest.fixture
async def mock_tts(hass: HomeAssistant, mock_provider) -> None:
async def mock_tts(
hass: HomeAssistant, mock_provider: Callable[[], MockTTSProvider]
) -> None:
"""Mock TTS."""
mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS(mock_provider))
mock_platform(hass, "test.tts", MockTTS(mock_provider()))
@pytest.fixture
def mock_provider() -> MockTTSProvider:
def mock_provider() -> Callable[[], MockTTSProvider]:
"""Test TTS provider."""
return MockTTSProvider(DEFAULT_LANG)
return lambda: MockTTSProvider(DEFAULT_LANG)
@pytest.fixture
def mock_tts_entity() -> MockTTSEntity:
def mock_tts_entity() -> Callable[[], MockTTSEntity]:
"""Test TTS entity."""
return MockTTSEntity(DEFAULT_LANG)
return lambda: MockTTSEntity(DEFAULT_LANG)
class TTSFlow(ConfigFlow):
@@ -106,13 +108,13 @@ def config_flow_fixture(
async def setup_fixture(
hass: HomeAssistant,
request: pytest.FixtureRequest,
mock_provider: MockTTSProvider,
mock_tts_entity: MockTTSEntity,
mock_provider: Callable[[], MockTTSProvider],
mock_tts_entity: Callable[[], MockTTSEntity],
) -> None:
"""Set up the test environment."""
if request.param == "mock_setup":
await mock_setup(hass, mock_provider)
await mock_setup(hass, mock_provider())
elif request.param == "mock_config_entry_setup":
await mock_config_entry_setup(hass, mock_tts_entity)
await mock_config_entry_setup(hass, mock_tts_entity())
else:
raise RuntimeError("Invalid setup fixture")

View File

@@ -1,5 +1,7 @@
"""Tests for the TTS entity."""
from collections.abc import Callable
import pytest
from homeassistant.components import tts
@@ -38,14 +40,14 @@ async def test_default_entity_attributes() -> None:
async def test_restore_state(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
mock_tts_entity: Callable[[], MockTTSEntity],
) -> None:
"""Test we restore state in the integration."""
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
timestamp = "2023-01-01T23:59:59+00:00"
mock_restore_cache(hass, (State(entity_id, timestamp),))
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
config_entry = await mock_config_entry_setup(hass, mock_tts_entity())
await hass.async_block_till_done()
assert config_entry.state is ConfigEntryState.LOADED

View File

@@ -1,6 +1,7 @@
"""The tests for the TTS component."""
import asyncio
from collections.abc import Callable
from http import HTTPStatus
from pathlib import Path
from typing import Any
@@ -48,7 +49,7 @@ ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
async def test_config_entry_unload(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_entity: MockTTSEntity,
mock_tts_entity: Callable[[], MockTTSEntity],
freezer: FrozenDateTimeFactory,
) -> None:
"""Test we can unload config entry."""
@@ -56,7 +57,7 @@ async def test_config_entry_unload(
state = hass.states.get(entity_id)
assert state is None
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
config_entry = await mock_config_entry_setup(hass, mock_tts_entity())
assert config_entry.state is ConfigEntryState.LOADED
state = hass.states.get(entity_id)
assert state is not None
@@ -178,7 +179,7 @@ async def test_service(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
[(lambda: MockTTSProvider("de_DE"), lambda: MockTTSEntity("de_DE"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@@ -242,7 +243,7 @@ async def test_service_default_language(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockTTSProvider("en_US"), MockTTSEntity("en_US"))],
[(lambda: MockTTSProvider("en_US"), lambda: MockTTSEntity("en_US"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@@ -498,7 +499,12 @@ class MockEntityWithDefaults(MockTTSEntity):
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockProviderWithDefaults(DEFAULT_LANG), MockEntityWithDefaults(DEFAULT_LANG))],
[
(
lambda: MockProviderWithDefaults(DEFAULT_LANG),
lambda: MockEntityWithDefaults(DEFAULT_LANG),
)
],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@@ -567,7 +573,12 @@ async def test_service_default_options(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockProviderWithDefaults(DEFAULT_LANG), MockEntityWithDefaults(DEFAULT_LANG))],
[
(
lambda: MockProviderWithDefaults(DEFAULT_LANG),
lambda: MockEntityWithDefaults(DEFAULT_LANG),
)
],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@@ -830,7 +841,7 @@ async def test_service_receive_voice(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
[(lambda: MockTTSProvider("de_DE"), lambda: MockTTSEntity("de_DE"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@@ -1013,11 +1024,11 @@ class MockEntityBoom(MockTTSEntity):
raise Exception("Boom!") # noqa: TRY002
@pytest.mark.parametrize("mock_provider", [MockProviderBoom(DEFAULT_LANG)])
@pytest.mark.parametrize("mock_provider", [lambda: MockProviderBoom(DEFAULT_LANG)])
async def test_setup_legacy_cache_dir(
hass: HomeAssistant,
mock_tts_cache_dir: Path,
mock_provider: MockTTSProvider,
mock_provider: Callable[[], MockTTSProvider],
) -> None:
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@@ -1028,7 +1039,7 @@ async def test_setup_legacy_cache_dir(
)
await hass.async_add_executor_job(Path(cache_file).write_bytes, tts_data)
await mock_setup(hass, mock_provider)
await mock_setup(hass, mock_provider())
await hass.services.async_call(
tts.DOMAIN,
@@ -1051,11 +1062,11 @@ async def test_setup_legacy_cache_dir(
await hass.async_block_till_done()
@pytest.mark.parametrize("mock_tts_entity", [MockEntityBoom(DEFAULT_LANG)])
@pytest.mark.parametrize("mock_tts_entity", [lambda: MockEntityBoom(DEFAULT_LANG)])
async def test_setup_cache_dir(
hass: HomeAssistant,
mock_tts_cache_dir: Path,
mock_tts_entity: MockTTSEntity,
mock_tts_entity: Callable[[], MockTTSEntity],
) -> None:
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@@ -1066,7 +1077,7 @@ async def test_setup_cache_dir(
)
await hass.async_add_executor_job(Path(cache_file).write_bytes, tts_data)
await mock_config_entry_setup(hass, mock_tts_entity)
await mock_config_entry_setup(hass, mock_tts_entity())
await hass.services.async_call(
tts.DOMAIN,
@@ -1111,7 +1122,7 @@ class MockEntityEmpty(MockTTSEntity):
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockProviderEmpty(DEFAULT_LANG), MockEntityEmpty(DEFAULT_LANG))],
[(lambda: MockProviderEmpty(DEFAULT_LANG), lambda: MockEntityEmpty(DEFAULT_LANG))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
@@ -1161,7 +1172,7 @@ async def test_service_get_tts_error(
async def test_legacy_cannot_retrieve_without_token(
hass: HomeAssistant,
mock_provider: MockTTSProvider,
mock_provider: Callable[[], MockTTSProvider],
mock_tts_cache_dir: Path,
hass_client: ClientSessionGenerator,
) -> None:
@@ -1172,7 +1183,7 @@ async def test_legacy_cannot_retrieve_without_token(
)
await hass.async_add_executor_job(Path(cache_file).write_bytes, tts_data)
await mock_setup(hass, mock_provider)
await mock_setup(hass, mock_provider())
client = await hass_client()
@@ -1184,7 +1195,7 @@ async def test_legacy_cannot_retrieve_without_token(
async def test_cannot_retrieve_without_token(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
mock_tts_entity: Callable[[], MockTTSEntity],
mock_tts_cache_dir: Path,
hass_client: ClientSessionGenerator,
) -> None:
@@ -1195,7 +1206,7 @@ async def test_cannot_retrieve_without_token(
)
await hass.async_add_executor_job(Path(cache_file).write_bytes, tts_data)
await mock_config_entry_setup(hass, mock_tts_entity)
await mock_config_entry_setup(hass, mock_tts_entity())
client = await hass_client()
@@ -1651,7 +1662,7 @@ async def test_ws_list_engines(
async def test_ws_list_engines_deprecated(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_tts_entity: MockTTSEntity,
mock_tts_entity: Callable[[], MockTTSEntity],
) -> None:
"""Test listing tts engines.
@@ -1668,7 +1679,7 @@ async def test_ws_list_engines_deprecated(
await async_setup_component(
hass, "tts", {"tts": [{"platform": "test"}, {"platform": "test_2"}]}
)
await mock_config_entry_setup(hass, mock_tts_entity)
await mock_config_entry_setup(hass, mock_tts_entity())
client = await hass_ws_client()
@@ -1822,18 +1833,19 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
async def test_default_engine_prefer_entity(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
mock_provider: MockTTSProvider,
mock_tts_entity: Callable[[], MockTTSEntity],
mock_provider: Callable[[], MockTTSProvider],
) -> None:
"""Test async_default_engine.
In this tests there's an entity and a legacy provider.
The test asserts async_default_engine returns the entity.
"""
mock_tts_entity._attr_name = "New test"
tts_entity = mock_tts_entity()
tts_entity._attr_name = "New test"
await mock_setup(hass, mock_provider)
await mock_config_entry_setup(hass, mock_tts_entity)
await mock_setup(hass, mock_provider())
await mock_config_entry_setup(hass, tts_entity)
await hass.async_block_till_done()
entity_engine = tts.async_resolve_engine(hass, "tts.new_test")
@@ -1854,7 +1866,7 @@ async def test_default_engine_prefer_entity(
)
async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant,
mock_provider: MockTTSProvider,
mock_provider: Callable[[], MockTTSProvider],
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.
@@ -1863,7 +1875,7 @@ async def test_default_engine_prefer_cloud_entity(
and a legacy provider.
The test asserts async_default_engine returns the entity from domain cloud.
"""
await mock_setup(hass, mock_provider)
await mock_setup(hass, mock_provider())
for domain in config_flow_test_domains:
entity = MockTTSEntity(DEFAULT_LANG)
entity._attr_name = f"{domain} TTS entity"
@@ -1878,13 +1890,16 @@ async def test_default_engine_prefer_cloud_entity(
assert tts.async_default_engine(hass) == "tts.cloud_tts_entity"
async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None:
async def test_stream(
hass: HomeAssistant, mock_tts_entity: Callable[[], MockTTSEntity]
) -> None:
"""Test creating streams."""
await mock_config_entry_setup(hass, mock_tts_entity)
tts_entity = mock_tts_entity()
await mock_config_entry_setup(hass, tts_entity)
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
assert stream.language == mock_tts_entity.default_language
assert stream.options == (mock_tts_entity.default_options or {})
stream = tts.async_create_stream(hass, tts_entity.entity_id)
assert stream.language == tts_entity.default_language
assert stream.options == (tts_entity.default_options or {})
assert tts.async_get_stream(hass, stream.token) is stream
stream.async_set_message("beer")
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
@@ -1904,7 +1919,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
data_gen=gen_data(),
)
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
tts_entity.async_stream_tts_audio = async_stream_tts_audio
async def stream_message():
"""Mock stream message."""
@@ -1912,7 +1927,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
yield "ll"
yield "o"
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
stream = tts.async_create_stream(hass, tts_entity.entity_id)
stream.async_set_message_stream(stream_message())
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
assert result_data == b"hello"

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
import pytest
@@ -77,7 +78,7 @@ async def test_invalid_platform(
async def test_platform_setup_without_provider(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mock_provider: MockTTSProvider,
mock_provider: Callable[[], MockTTSProvider],
) -> None:
"""Test platform setup without provider returned."""
@@ -94,7 +95,7 @@ async def test_platform_setup_without_provider(
return None
mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider()))
await async_load_platform(
hass,
@@ -111,7 +112,7 @@ async def test_platform_setup_without_provider(
async def test_platform_setup_with_error(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mock_provider: MockTTSProvider,
mock_provider: Callable[[], MockTTSProvider],
) -> None:
"""Test platform setup with an error during setup."""
@@ -128,7 +129,7 @@ async def test_platform_setup_with_error(
raise Exception("Setup error") # noqa: TRY002
mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider()))
await async_load_platform(
hass,

View File

@@ -1,5 +1,6 @@
"""Tests for TTS media source."""
from collections.abc import Callable
from http import HTTPStatus
import re
from unittest.mock import MagicMock
@@ -47,7 +48,7 @@ async def setup_media_source(hass: HomeAssistant) -> None:
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
[(lambda: MSProvider(DEFAULT_LANG), lambda: MSEntity(DEFAULT_LANG))],
)
@pytest.mark.parametrize(
"setup",
@@ -101,24 +102,28 @@ async def test_browsing(hass: HomeAssistant, setup: str) -> None:
@pytest.mark.parametrize(
("mock_provider", "extra_options"),
[
(MSProvider(DEFAULT_LANG), "&tts_options=%7B%22voice%22%3A%22Paulus%22%7D"),
(MSProvider(DEFAULT_LANG), "&voice=Paulus"),
(
lambda: MSProvider(DEFAULT_LANG),
"&tts_options=%7B%22voice%22%3A%22Paulus%22%7D",
),
(lambda: MSProvider(DEFAULT_LANG), "&voice=Paulus"),
],
)
async def test_legacy_resolving(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_provider: MSProvider,
mock_provider: Callable[[], MSProvider],
extra_options: str,
) -> None:
"""Test resolving legacy provider."""
await mock_setup(hass, mock_provider)
mock_get_tts_audio = mock_provider.get_tts_audio
provider = mock_provider()
await mock_setup(hass, provider)
mock_get_tts_audio = provider.get_tts_audio
mock_provider.has_entity = True
provider.has_entity = True
root = await media_source.async_browse_media(hass, "media-source://tts")
assert len(root.children) == 0
mock_provider.has_entity = False
provider.has_entity = False
root = await media_source.async_browse_media(hass, "media-source://tts")
assert len(root.children) == 1
@@ -155,19 +160,23 @@ async def test_legacy_resolving(
@pytest.mark.parametrize(
("mock_tts_entity", "extra_options"),
[
(MSEntity(DEFAULT_LANG), "&tts_options=%7B%22voice%22%3A%22Paulus%22%7D"),
(MSEntity(DEFAULT_LANG), "&voice=Paulus"),
(
lambda: MSEntity(DEFAULT_LANG),
"&tts_options=%7B%22voice%22%3A%22Paulus%22%7D",
),
(lambda: MSEntity(DEFAULT_LANG), "&voice=Paulus"),
],
)
async def test_resolving(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_entity: MSEntity,
mock_tts_entity: Callable[[], MSEntity],
extra_options: str,
) -> None:
"""Test resolving entity."""
await mock_config_entry_setup(hass, mock_tts_entity)
mock_get_tts_audio = mock_tts_entity.get_tts_audio
tts_entity = mock_tts_entity()
await mock_config_entry_setup(hass, tts_entity)
mock_get_tts_audio = tts_entity.get_tts_audio
mock_get_tts_audio.reset_mock()
media_id = "media-source://tts/tts.test?message=Hello%20World"
@@ -201,7 +210,7 @@ async def test_resolving(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
[(lambda: MSProvider(DEFAULT_LANG), lambda: MSEntity(DEFAULT_LANG))],
)
@pytest.mark.parametrize(
("setup", "engine"),

View File

@@ -1,5 +1,6 @@
"""The tests for the TTS component."""
from collections.abc import Callable
from unittest.mock import patch
import pytest
@@ -127,7 +128,7 @@ async def test_setup_legacy_service(hass: HomeAssistant) -> None:
async def test_setup_service(
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
hass: HomeAssistant, mock_tts_entity: Callable[[], MockTTSEntity]
) -> None:
"""Set up platform and call service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@@ -142,7 +143,7 @@ async def test_setup_service(
},
}
await mock_config_entry_setup(hass, mock_tts_entity)
await mock_config_entry_setup(hass, mock_tts_entity())
with assert_setup_component(1, notify.DOMAIN):
assert await async_setup_component(hass, notify.DOMAIN, config)

View File

@@ -1,5 +1,6 @@
"""Tests for unifiprotect.media_source."""
from collections.abc import Callable
from datetime import datetime, timedelta
from ipaddress import IPv4Address
from unittest.mock import AsyncMock, Mock, patch
@@ -662,10 +663,10 @@ async def test_browse_media_recent_truncated(
@pytest.mark.parametrize(
("event", "expected_title"),
("make_event", "expected_title"),
[
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.RING,
@@ -679,7 +680,7 @@ async def test_browse_media_recent_truncated(
"Ring Event",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.MOTION,
@@ -693,7 +694,7 @@ async def test_browse_media_recent_truncated(
"Motion Event",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_DETECT,
@@ -716,7 +717,7 @@ async def test_browse_media_recent_truncated(
"Object Detection - Person",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_DETECT,
@@ -730,7 +731,7 @@ async def test_browse_media_recent_truncated(
"Object Detection - Person, Vehicle",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_DETECT,
@@ -744,7 +745,7 @@ async def test_browse_media_recent_truncated(
"Object Detection - License Plate, Vehicle",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_DETECT,
@@ -768,7 +769,7 @@ async def test_browse_media_recent_truncated(
"Object Detection - Vehicle: ABC1234",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_DETECT,
@@ -798,7 +799,7 @@ async def test_browse_media_recent_truncated(
"Object Detection - Car: ABC1234",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_DETECT,
@@ -833,7 +834,7 @@ async def test_browse_media_recent_truncated(
"Object Detection - Black Vehicle: ABC1234",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_DETECT,
@@ -866,7 +867,7 @@ async def test_browse_media_recent_truncated(
"Object Detection - Black Car",
),
(
Event(
lambda: Event(
model=ModelType.EVENT,
id="test_event_id",
type=EventType.SMART_AUDIO_DETECT,
@@ -886,7 +887,7 @@ async def test_browse_media_event(
ufp: MockUFPFixture,
doorbell: Camera,
fixed_now: datetime,
event: Event,
make_event: Callable[[], Event],
expected_title: str,
) -> None:
"""Test browsing specific event."""
@@ -894,6 +895,7 @@ async def test_browse_media_event(
ufp.api.get_bootstrap = AsyncMock(return_value=ufp.api.bootstrap)
await init_entry(hass, ufp, [doorbell], regenerate_ids=False)
event = make_event()
event.start = fixed_now - timedelta(seconds=20)
event.end = fixed_now
event.camera_id = doorbell.id

View File

@@ -21,7 +21,8 @@ import threading
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import AsyncMock, MagicMock, Mock, _patch, patch
from aiohttp import client
import _pytest.python_api
from aiohttp import client, web_app
from aiohttp.resolver import AsyncResolver
from aiohttp.test_utils import (
BaseTestServer,
@@ -58,6 +59,7 @@ from homeassistant import components, core as ha, loader, runner
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from homeassistant.auth.models import Credentials
from homeassistant.auth.providers import homeassistant
from homeassistant.components import api, mobile_app, websocket_api
from homeassistant.components.device_tracker.legacy import Device
# pylint: disable-next=hass-component-root-import
@@ -96,6 +98,7 @@ from homeassistant.helpers import (
issue_registry as ir,
label_registry as lr,
recorder as recorder_helper,
template,
translation as translation_helper,
)
from homeassistant.helpers.dispatcher import async_dispatcher_send
@@ -155,6 +158,40 @@ asyncio.set_event_loop_policy(runner.HassEventLoopPolicy(False))
asyncio.set_event_loop_policy = lambda policy: None
class HackLogRecord(logging.LogRecord):
"""Hack."""
def __init__(
self,
name,
level,
pathname,
lineno,
msg,
args,
exc_info,
func=None,
sinfo=None,
**kwargs,
) -> None:
"""Initialize the log record."""
super().__init__(
name, level, pathname, lineno, msg, args, exc_info, func, sinfo, **kwargs
)
msg = str(self.msg)
if self.args:
msg = msg % self.args
self.msg = msg
self.args = None
def getMessage(self):
"""Return the message for this LogRecord."""
return self.msg
logging.setLogRecordFactory(HackLogRecord)
def pytest_addoption(parser: pytest.Parser) -> None:
"""Register custom pytest options."""
parser.addoption("--dburl", action="store", default="sqlite://")
@@ -278,6 +315,67 @@ def caplog_fixture(caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture
return caplog
@pytest.fixture(autouse=True)
def clear_exception_traceback(request: pytest.FixtureRequest) -> Generator[None]:
"""Clear exception traceback after each test."""
exceptions: list[BaseException] = []
raises_ctx: list[_pytest.python_api.RaisesContext] = []
for fixture_name in request.fixturenames:
if fixture_name not in {
"addon_info_error",
"addon_store_info_error",
"api_exception",
"backup_info_side_effect",
"create_backup_error",
"doorbell_state_side_effect",
"error_type",
"error",
"exc",
"exception",
"expand_side_effect",
"expectation",
"expected_result",
"go2rtc_error",
"imap_wait_server_push_exception",
"init_tts_cache_dir_side_effect",
"install_addon_error",
"p_error",
"raise_error",
"remove_side_effect",
"raised",
"raises",
"set_active_program_option_side_effect",
"set_active_program_options_side_effect",
"set_addon_options_error",
"set_query_mock",
"set_selected_program_option_side_effect",
"set_selected_program_options_side_effect",
"side_eff",
"side_effect",
"sideeffect",
"start_addon_error",
"subscriber_side_effect",
"supervisor_error",
"test_exception",
"update_addon_error",
}:
continue
if isinstance(request.getfixturevalue(fixture_name), BaseException):
exceptions.append(request.getfixturevalue(fixture_name))
if isinstance(
request.getfixturevalue(fixture_name), _pytest.python_api.RaisesContext
):
raises_ctx.append(request.getfixturevalue(fixture_name))
yield
for ex in exceptions:
ex.__cause__ = None
ex.__context__ = None
ex.__traceback__ = None
for ctx in raises_ctx:
ctx.excinfo = None
@pytest.fixture(autouse=True, scope="module")
def garbage_collection() -> None:
"""Run garbage collection at known locations.
@@ -288,8 +386,25 @@ def garbage_collection() -> None:
handles the most common cases and let each module override
to run per test case if needed.
"""
start_live_hass_instances = len(
[hass() for hass in ha.hass_instances if hass() is not None]
)
yield
gc.collect()
gc.freeze()
end_live_hass_instances = len(
[hass() for hass in ha.hass_instances if hass() is not None]
)
if abs(start_live_hass_instances - end_live_hass_instances) > 1:
_LOGGER.error(
"Garbage collection did not clean up all Home Assistant instances. "
"Start: %s, End: %s",
start_live_hass_instances,
end_live_hass_instances,
)
pytest.fail(
f"Garbage collection did not clean up all Home Assistant instances. "
f"Start: {start_live_hass_instances}, End: {end_live_hass_instances}"
)
@pytest.fixture(autouse=True)
@@ -452,6 +567,20 @@ def reset_globals() -> Generator[None]:
frame.async_setup(None)
frame._REPORTED_INTEGRATIONS.clear()
# Reset the aiohttp cache
web_app._cached_build_middleware.cache_clear()
# Reset the recorder helper get_instance cache
recorder_helper.get_instance.cache_clear()
# Reset the template caches
api._cached_template.cache_clear()
mobile_app.webhook._cached_template.cache_clear()
websocket_api.commands._cached_template.cache_clear()
template.CACHED_TEMPLATE_LRU.clear()
template.CACHED_TEMPLATE_NO_COLLECT_LRU.clear()
template._domain_states.cache_clear()
# Reset patch_json
if patch_json.mock_objects:
obj = patch_json.mock_objects.pop()
@@ -1226,7 +1355,7 @@ async def mqtt_mock_entry(
yield _setup_mqtt_entry
@pytest.fixture(autouse=True, scope="session")
@pytest.fixture(autouse=True, scope="module")
def mock_network() -> Generator[None]:
"""Mock network."""
with (
@@ -1292,7 +1421,9 @@ def translations_once() -> Generator[_patch]:
@pytest.fixture(autouse=True, scope="module")
def evict_faked_translations(translations_once) -> Generator[_patch]:
def evict_faked_translations(
garbage_collection, translations_once
) -> Generator[_patch]:
"""Clear translations for mocked integrations from the cache after each module."""
real_component_strings = translation_helper._async_get_component_strings
with patch(
@@ -1325,7 +1456,8 @@ def disable_translations_once(
translations_once.start()
@pytest_asyncio.fixture(autouse=True, scope="session", loop_scope="session")
# @pytest_asyncio.fixture(autouse=True, scope="session", loop_scope="session")
@pytest_asyncio.fixture(autouse=True)
async def mock_zeroconf_resolver() -> AsyncGenerator[_patch]:
"""Mock out the zeroconf resolver."""
resolver = AsyncResolver()

View File

@@ -663,13 +663,16 @@ async def test_get_request_host_no_host_header(hass: HomeAssistant) -> None:
assert _get_request_host() is None
@patch("homeassistant.components.hassio.is_hassio", Mock(return_value=True))
@patch(
"homeassistant.components.hassio.is_hassio",
return_value=True,
)
@patch(
"homeassistant.components.hassio.get_host_info",
Mock(return_value={"hostname": "homeassistant"}),
return_value={"hostname": "homeassistant"},
)
async def test_get_current_request_url_with_known_host(
hass: HomeAssistant, current_request
get_host_info, is_hassio, hass: HomeAssistant, current_request
) -> None:
"""Test getting current request URL with known hosts addresses."""
hass.config.api = Mock(use_ssl=False, port=8123, local_ip="127.0.0.1")
@@ -728,13 +731,15 @@ async def test_get_current_request_url_with_known_host(
@patch(
"homeassistant.helpers.network.is_hassio",
Mock(return_value={"hostname": "homeassistant"}),
return_value={"hostname": "homeassistant"},
)
@patch(
"homeassistant.components.hassio.get_host_info",
Mock(return_value={"hostname": "hellohost"}),
return_value={"hostname": "hellohost"},
)
async def test_is_internal_request(hass: HomeAssistant, mock_current_request) -> None:
async def test_is_internal_request(
get_host_info, is_hassio, hass: HomeAssistant, mock_current_request
) -> None:
"""Test if accessing an instance on its internal URL."""
# Test with internal URL: http://example.local:8123
await async_process_ha_core_config(

View File

@@ -304,6 +304,7 @@ async def test_refresh_known_errors(
assert crd.last_update_success is False
assert isinstance(crd.last_exception, err_msg[1])
assert err_msg[2] in caplog.text
err_msg[0].__traceback__ = None
async def test_refresh_fail_unknown(
@@ -564,6 +565,7 @@ async def test_async_config_entry_first_refresh_failure(
assert crd.last_update_success is False
assert isinstance(crd.last_exception, err_msg[1])
assert err_msg[2] not in caplog.text
err_msg[0].__traceback__ = None
@pytest.mark.parametrize(
@@ -602,6 +604,7 @@ async def test_async_config_entry_first_refresh_failure_passed_through(
assert crd.last_update_success is False
assert isinstance(crd.last_exception, err_msg[1])
assert err_msg[2] not in caplog.text
err_msg[0].__traceback__ = None
async def test_async_config_entry_first_refresh_success(hass: HomeAssistant) -> None: