Recreate aiohttp ClientSession after DNS plug-in load (#5862)

* Recreate aiohttp ClientSession after DNS plug-in load

Create a temporary ClientSession early in case we need to load version
information from the internet. This doesn't use the final DNS setup
and hence might fail to load in certain situations since we don't have
the fallback mechanims in place yet. But if the DNS container image
is present, we'll continue the setup and load the DNS plug-in. We then
can recreate the ClientSession such that it uses the DNS plug-in.

This works around an issue with aiodns, which today doesn't reload
`resolv.conf` automatically when it changes. This lead to Supervisor
using the initial `resolv.conf` as created by Docker. It meant that
we did not use the DNS plug-in (and its fallback capabilities) in
Supervisor. Also it meant that changes to the DNS setup at runtime
did not propagate to the aiohttp ClientSession (as observed in #5332).

* Mock aiohttp.ClientSession for all tests

Currently in several places pytest actually uses the aiohttp
ClientSession and reaches out to the internet. This is not ideal
for unit tests and should be avoided.

This creates several new fixtures to aid this effort: The `websession`
fixture simply returns a mocked aiohttp.ClientSession, which can be
used whenever a function is tested which needs the global websession.

A separate new fixture to mock the connectivity check named
`supervisor_internet` since this is often used through the Job
decorator which require INTERNET_SYSTEM.

And the `mock_update_data` uses the already existing update json
test data from the fixture directory instead of loading the data
from the internet.

* Log ClientSession nameserver information

When recreating the aiohttp ClientSession, log information what
nameservers exactly are going to be used.

* Refuse ClientSession initialization when API is available

Previous attempts to reinitialize the ClientSession have shown
use of the ClientSession after it was closed due to API requets
being handled in parallel to the reinitialization (see #5851).
Make sure this is not possible by refusing to reinitialize the
ClientSession when the API is available.

* Fix pytests

Also sure we don't create aiohttp ClientSession objects unnecessarily.

* Apply suggestions from code review

Co-authored-by: Jan Čermák <sairon@users.noreply.github.com>

---------

Co-authored-by: Jan Čermák <sairon@users.noreply.github.com>
This commit is contained in:
Stefan Agner 2025-05-06 16:23:40 +02:00 committed by GitHub
parent 2e44e6494f
commit 85f8107b60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 256 additions and 102 deletions

View File

@ -230,6 +230,9 @@ filterwarnings = [
"ignore:pkg_resources is deprecated as an API:DeprecationWarning:dirhash", "ignore:pkg_resources is deprecated as an API:DeprecationWarning:dirhash",
"ignore::pytest.PytestUnraisableExceptionWarning", "ignore::pytest.PytestUnraisableExceptionWarning",
] ]
markers = [
"no_mock_init_websession: disable the autouse mock of init_websession for this test",
]
[tool.ruff] [tool.ruff]
lint.select = [ lint.select = [

View File

@ -20,7 +20,7 @@ from ...const import (
ROLE_DEFAULT, ROLE_DEFAULT,
ROLE_HOMEASSISTANT, ROLE_HOMEASSISTANT,
ROLE_MANAGER, ROLE_MANAGER,
CoreState, VALID_API_STATES,
) )
from ...coresys import CoreSys, CoreSysAttributes from ...coresys import CoreSys, CoreSysAttributes
from ...utils import version_is_new_enough from ...utils import version_is_new_enough
@ -200,11 +200,7 @@ class SecurityMiddleware(CoreSysAttributes):
@middleware @middleware
async def system_validation(self, request: Request, handler: Callable) -> Response: async def system_validation(self, request: Request, handler: Callable) -> Response:
"""Check if core is ready to response.""" """Check if core is ready to response."""
if self.sys_core.state not in ( if self.sys_core.state not in VALID_API_STATES:
CoreState.STARTUP,
CoreState.RUNNING,
CoreState.FREEZE,
):
return api_return_error( return api_return_error(
message=f"System is not ready with state: {self.sys_core.state}" message=f"System is not ready with state: {self.sys_core.state}"
) )

View File

@ -552,3 +552,12 @@ STARTING_STATES = [
CoreState.STARTUP, CoreState.STARTUP,
CoreState.SETUP, CoreState.SETUP,
] ]
# States in which the API can be used (enforced by system_validation())
VALID_API_STATES = frozenset(
{
CoreState.STARTUP,
CoreState.RUNNING,
CoreState.FREEZE,
}
)

View File

@ -124,6 +124,19 @@ class Core(CoreSysAttributes):
"""Start setting up supervisor orchestration.""" """Start setting up supervisor orchestration."""
await self.set_state(CoreState.SETUP) await self.set_state(CoreState.SETUP)
# Initialize websession early. At this point we'll use the Docker DNS proxy
# at 127.0.0.11, which does not have the fallback feature and hence might
# fail in certain environments. But a websession is required to get the
# initial version information after a device wipe or otherwise empty state
# (e.g. CI environment, Supervised).
#
# An OS installation has the plug-in container images pre-installed, so we
# setup can continue even if this early websession fails to connect to the
# internet. We'll reinitialize the websession when the DNS plug-in is up to
# make sure the DNS plug-in along with its fallback capabilities is used
# (see #5857).
await self.coresys.init_websession()
# Check internet on startup # Check internet on startup
await self.sys_supervisor.check_connectivity() await self.sys_supervisor.check_connectivity()

View File

@ -21,6 +21,7 @@ from .const import (
ENV_SUPERVISOR_MACHINE, ENV_SUPERVISOR_MACHINE,
MACHINE_ID, MACHINE_ID,
SERVER_SOFTWARE, SERVER_SOFTWARE,
VALID_API_STATES,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -68,7 +69,6 @@ class CoreSys:
# External objects # External objects
self._loop: asyncio.BaseEventLoop = asyncio.get_running_loop() self._loop: asyncio.BaseEventLoop = asyncio.get_running_loop()
self._websession: aiohttp.ClientSession = aiohttp.ClientSession()
# Global objects # Global objects
self._config: CoreConfig = CoreConfig() self._config: CoreConfig = CoreConfig()
@ -100,11 +100,7 @@ class CoreSys:
self._security: Security | None = None self._security: Security | None = None
self._bus: Bus | None = None self._bus: Bus | None = None
self._mounts: MountManager | None = None self._mounts: MountManager | None = None
self._websession: aiohttp.ClientSession | None = None
# Set default header for aiohttp
self._websession._default_headers = MappingProxyType(
{aiohttp.hdrs.USER_AGENT: SERVER_SOFTWARE}
)
# Task factory attributes # Task factory attributes
self._set_task_context: list[Callable[[Context], Context]] = [] self._set_task_context: list[Callable[[Context], Context]] = []
@ -114,6 +110,33 @@ class CoreSys:
await self.config.read_data() await self.config.read_data()
return self return self
async def init_websession(self) -> None:
"""Initialize global aiohttp ClientSession."""
if self.core.state in VALID_API_STATES:
# Make sure we don't reinitialize the session if the API is running (see #5851)
raise RuntimeError(
"Initializing ClientSession is not safe when API is running"
)
if self._websession:
await self._websession.close()
resolver = aiohttp.AsyncResolver()
# pylint: disable=protected-access
_LOGGER.debug(
"Initializing ClientSession with AsyncResolver. Using nameservers %s",
resolver._resolver.nameservers,
)
connector = aiohttp.TCPConnector(loop=self.loop, resolver=resolver)
session = aiohttp.ClientSession(
headers=MappingProxyType({aiohttp.hdrs.USER_AGENT: SERVER_SOFTWARE}),
connector=connector,
)
self._websession = session
async def init_machine(self): async def init_machine(self):
"""Initialize machine information.""" """Initialize machine information."""
@ -165,6 +188,8 @@ class CoreSys:
@property @property
def websession(self) -> aiohttp.ClientSession: def websession(self) -> aiohttp.ClientSession:
"""Return websession object.""" """Return websession object."""
if self._websession is None:
raise RuntimeError("WebSession not setup yet")
return self._websession return self._websession
@property @property

View File

@ -177,7 +177,13 @@ class PluginDns(PluginBase):
# Update supervisor # Update supervisor
await self._write_resolv(HOST_RESOLV) await self._write_resolv(HOST_RESOLV)
await self.sys_supervisor.check_connectivity()
# Reinitializing aiohttp.ClientSession after DNS setup makes sure that
# aiodns is using the right DNS servers (see #5857).
# At this point it should be fairly safe to replace the session since
# we only use the session synchronously during setup and not thorugh the
# API which previously caused issues (see #5851).
await self.coresys.init_websession()
async def install(self) -> None: async def install(self) -> None:
"""Install CoreDNS.""" """Install CoreDNS."""

View File

@ -190,7 +190,7 @@ async def test_addon_shutdown_error(
async def test_addon_uninstall_removes_discovery( async def test_addon_uninstall_removes_discovery(
coresys: CoreSys, install_addon_ssh: Addon coresys: CoreSys, install_addon_ssh: Addon, websession: MagicMock
): ):
"""Test discovery messages removed when addon uninstalled.""" """Test discovery messages removed when addon uninstalled."""
assert coresys.discovery.list_messages == [] assert coresys.discovery.list_messages == []
@ -203,7 +203,6 @@ async def test_addon_uninstall_removes_discovery(
assert coresys.discovery.list_messages == [message] assert coresys.discovery.list_messages == [message]
coresys.homeassistant.api.ensure_access_token = AsyncMock() coresys.homeassistant.api.ensure_access_token = AsyncMock()
coresys.websession.delete = MagicMock()
await coresys.addons.uninstall(TEST_ADDON_SLUG) await coresys.addons.uninstall(TEST_ADDON_SLUG)
await asyncio.sleep(0) await asyncio.sleep(0)

View File

@ -1,7 +1,7 @@
"""Test auth API.""" """Test auth API."""
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
import pytest import pytest
@ -9,6 +9,7 @@ import pytest
from supervisor.addons.addon import Addon from supervisor.addons.addon import Addon
from supervisor.coresys import CoreSys from supervisor.coresys import CoreSys
from tests.common import MockResponse
from tests.const import TEST_ADDON_SLUG from tests.const import TEST_ADDON_SLUG
LIST_USERS_RESPONSE = [ LIST_USERS_RESPONSE = [
@ -78,7 +79,10 @@ def fixture_mock_check_login(coresys: CoreSys):
async def test_password_reset( async def test_password_reset(
api_client: TestClient, coresys: CoreSys, caplog: pytest.LogCaptureFixture api_client: TestClient,
coresys: CoreSys,
caplog: pytest.LogCaptureFixture,
websession: MagicMock,
): ):
"""Test password reset api.""" """Test password reset api."""
coresys.homeassistant.api.access_token = "abc123" coresys.homeassistant.api.access_token = "abc123"
@ -87,15 +91,12 @@ async def test_password_reset(
days=1 days=1
) )
mock_websession = AsyncMock() websession.post = MagicMock(return_value=MockResponse(status=200))
mock_websession.post.return_value.__aenter__.return_value.status = 200 resp = await api_client.post(
with patch("supervisor.coresys.aiohttp.ClientSession.post") as post: "/auth/reset", json={"username": "john", "password": "doe"}
post.return_value.__aenter__.return_value.status = 200 )
resp = await api_client.post( assert resp.status == 200
"/auth/reset", json={"username": "john", "password": "doe"} assert "Successful password reset for 'john'" in caplog.text
)
assert resp.status == 200
assert "Successful password reset for 'john'" in caplog.text
async def test_list_users( async def test_list_users(

View File

@ -276,6 +276,7 @@ async def test_api_backup_restore_background(
backup_type: str, backup_type: str,
options: dict[str, Any], options: dict[str, Any],
tmp_supervisor_data: Path, tmp_supervisor_data: Path,
supervisor_internet: AsyncMock,
): ):
"""Test background option on backup/restore APIs.""" """Test background option on backup/restore APIs."""
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
@ -472,6 +473,7 @@ async def test_restore_immediate_errors(
api_client: TestClient, api_client: TestClient,
coresys: CoreSys, coresys: CoreSys,
mock_partial_backup: Backup, mock_partial_backup: Backup,
supervisor_internet: AsyncMock,
): ):
"""Test restore errors that return immediately even in background mode.""" """Test restore errors that return immediately even in background mode."""
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
@ -1010,6 +1012,7 @@ async def test_restore_backup_from_location(
coresys: CoreSys, coresys: CoreSys,
tmp_supervisor_data: Path, tmp_supervisor_data: Path,
local_location: str | None, local_location: str | None,
supervisor_internet: AsyncMock,
): ):
"""Test restoring a backup from a specific location.""" """Test restoring a backup from a specific location."""
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
@ -1059,6 +1062,7 @@ async def test_restore_backup_from_location(
async def test_restore_backup_unencrypted_after_encrypted( async def test_restore_backup_unencrypted_after_encrypted(
api_client: TestClient, api_client: TestClient,
coresys: CoreSys, coresys: CoreSys,
supervisor_internet: AsyncMock,
): ):
"""Test restoring an unencrypted backup after an encrypted backup and vis-versa.""" """Test restoring an unencrypted backup after an encrypted backup and vis-versa."""
enc_tar = copy(get_fixture_path("test_consolidate.tar"), coresys.config.path_backup) enc_tar = copy(get_fixture_path("test_consolidate.tar"), coresys.config.path_backup)
@ -1131,6 +1135,7 @@ async def test_restore_homeassistant_adds_env(
docker: DockerAPI, docker: DockerAPI,
backup_type: str, backup_type: str,
postbody: dict[str, Any], postbody: dict[str, Any],
supervisor_internet: AsyncMock,
): ):
"""Test restoring home assistant from backup adds env to container.""" """Test restoring home assistant from backup adds env to container."""
event = asyncio.Event() event = asyncio.Event()
@ -1328,6 +1333,7 @@ async def test_missing_file_removes_location_from_cache(
url_path: str, url_path: str,
body: dict[str, Any] | None, body: dict[str, Any] | None,
backup_file: str, backup_file: str,
supervisor_internet: AsyncMock,
): ):
"""Test finding a missing file removes the location from cache.""" """Test finding a missing file removes the location from cache."""
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
@ -1387,6 +1393,7 @@ async def test_missing_file_removes_backup_from_cache(
url_path: str, url_path: str,
body: dict[str, Any] | None, body: dict[str, Any] | None,
backup_file: str, backup_file: str,
supervisor_internet: AsyncMock,
): ):
"""Test finding a missing file removes the backup from cache if its the only one.""" """Test finding a missing file removes the backup from cache if its the only one."""
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
@ -1412,7 +1419,9 @@ async def test_missing_file_removes_backup_from_cache(
@pytest.mark.usefixtures("tmp_supervisor_data") @pytest.mark.usefixtures("tmp_supervisor_data")
async def test_immediate_list_after_missing_file_restore( async def test_immediate_list_after_missing_file_restore(
api_client: TestClient, coresys: CoreSys api_client: TestClient,
coresys: CoreSys,
supervisor_internet: AsyncMock,
): ):
"""Test race with reload for missing file on restore does not error.""" """Test race with reload for missing file on restore does not error."""
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)

View File

@ -84,12 +84,14 @@ async def test_api_list_discovery(
@pytest.mark.parametrize("api_client", [TEST_ADDON_SLUG], indirect=True) @pytest.mark.parametrize("api_client", [TEST_ADDON_SLUG], indirect=True)
async def test_api_send_del_discovery( async def test_api_send_del_discovery(
api_client: TestClient, coresys: CoreSys, install_addon_ssh: Addon api_client: TestClient,
coresys: CoreSys,
install_addon_ssh: Addon,
websession: MagicMock,
): ):
"""Test adding and removing discovery.""" """Test adding and removing discovery."""
install_addon_ssh.data["discovery"] = ["test"] install_addon_ssh.data["discovery"] = ["test"]
coresys.homeassistant.api.ensure_access_token = AsyncMock() coresys.homeassistant.api.ensure_access_token = AsyncMock()
coresys.websession.post = MagicMock()
resp = await api_client.post("/discovery", json={"service": "test", "config": {}}) resp = await api_client.post("/discovery", json={"service": "test", "config": {}})
assert resp.status == 200 assert resp.status == 200

View File

@ -1,5 +1,7 @@
"""Test Supervisor API.""" """Test Supervisor API."""
from unittest.mock import AsyncMock
import pytest import pytest
from supervisor.coresys import CoreSys from supervisor.coresys import CoreSys
@ -36,7 +38,9 @@ async def test_api_security_options_pwned(api_client, coresys: CoreSys):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_integrity_check(api_client, coresys: CoreSys): async def test_api_integrity_check(
api_client, coresys: CoreSys, supervisor_internet: AsyncMock
):
"""Test security integrity check.""" """Test security integrity check."""
coresys.security.content_trust = False coresys.security.content_trust = False

View File

@ -2,7 +2,7 @@
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
from aiohttp import ClientResponse from aiohttp import ClientResponse
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
@ -92,7 +92,9 @@ async def test_api_store_repositories_repository(
assert result["data"]["slug"] == repository.slug assert result["data"]["slug"] == repository.slug
async def test_api_store_add_repository(api_client: TestClient, coresys: CoreSys): async def test_api_store_add_repository(
api_client: TestClient, coresys: CoreSys, supervisor_internet: AsyncMock
) -> None:
"""Test POST /store/repositories REST API.""" """Test POST /store/repositories REST API."""
with ( with (
patch("supervisor.store.repository.Repository.load", return_value=None), patch("supervisor.store.repository.Repository.load", return_value=None),

View File

@ -2,7 +2,7 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import time import time
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
from blockbuster import BlockingError from blockbuster import BlockingError
@ -34,7 +34,7 @@ async def test_api_supervisor_options_debug(api_client: TestClient, coresys: Cor
async def test_api_supervisor_options_add_repository( async def test_api_supervisor_options_add_repository(
api_client: TestClient, coresys: CoreSys api_client: TestClient, coresys: CoreSys, supervisor_internet: AsyncMock
): ):
"""Test add a repository via POST /supervisor/options REST API.""" """Test add a repository via POST /supervisor/options REST API."""
assert REPO_URL not in coresys.store.repository_urls assert REPO_URL not in coresys.store.repository_urls
@ -231,7 +231,9 @@ async def test_api_supervisor_fallback_log_capture(
capture_exception.assert_called_once() capture_exception.assert_called_once()
async def test_api_supervisor_reload(api_client: TestClient): async def test_api_supervisor_reload(
api_client: TestClient, supervisor_internet: AsyncMock, websession: MagicMock
):
"""Test supervisor reload.""" """Test supervisor reload."""
resp = await api_client.post("/supervisor/reload") resp = await api_client.post("/supervisor/reload")
assert resp.status == 200 assert resp.status == 200

View File

@ -103,3 +103,31 @@ def get_job_decorator(func) -> Job:
def reset_last_call(func, group: str | None = None) -> None: def reset_last_call(func, group: str | None = None) -> None:
"""Reset last call for a function using the Job decorator.""" """Reset last call for a function using the Job decorator."""
get_job_decorator(func).set_last_call(datetime.min, group) get_job_decorator(func).set_last_call(datetime.min, group)
class MockResponse:
"""Mock response for aiohttp requests."""
def __init__(self, *, status=200, text=""):
"""Initialize mock response."""
self.status = status
self._text = text
def update_text(self, text: str):
"""Update the text of the response."""
self._text = text
async def read(self):
"""Read the response body."""
return self._text.encode("utf-8")
async def text(self) -> str:
"""Return the response body as text."""
return self._text
async def __aenter__(self):
"""Enter the context manager."""
return self
async def __aexit__(self, exc_type, exc, tb):
"""Exit the context manager."""

View File

@ -9,7 +9,7 @@ import subprocess
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
from uuid import uuid4 from uuid import uuid4
from aiohttp import web from aiohttp import ClientSession, web
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
from blockbuster import BlockBuster, blockbuster_ctx from blockbuster import BlockBuster, blockbuster_ctx
@ -53,7 +53,13 @@ from supervisor.store.addon import AddonStore
from supervisor.store.repository import Repository from supervisor.store.repository import Repository
from supervisor.utils.dt import utcnow from supervisor.utils.dt import utcnow
from .common import load_binary_fixture, load_json_fixture, mock_dbus_services from .common import (
MockResponse,
load_binary_fixture,
load_fixture,
load_json_fixture,
mock_dbus_services,
)
from .const import TEST_ADDON_SLUG from .const import TEST_ADDON_SLUG
from .dbus_service_mocks.base import DBusServiceMock from .dbus_service_mocks.base import DBusServiceMock
from .dbus_service_mocks.network_connection_settings import ( from .dbus_service_mocks.network_connection_settings import (
@ -329,6 +335,7 @@ async def coresys(
aiohttp_client, aiohttp_client,
run_supervisor_state, run_supervisor_state,
supervisor_name, supervisor_name,
request: pytest.FixtureRequest,
) -> CoreSys: ) -> CoreSys:
"""Create a CoreSys Mock.""" """Create a CoreSys Mock."""
with ( with (
@ -397,12 +404,14 @@ async def coresys(
ha_version=AwesomeVersion("2021.2.4") ha_version=AwesomeVersion("2021.2.4")
) )
if not request.node.get_closest_marker("no_mock_init_websession"):
coresys_obj.init_websession = AsyncMock()
# Don't remove files/folders related to addons and stores # Don't remove files/folders related to addons and stores
with patch("supervisor.store.git.GitRepo._remove"): with patch("supervisor.store.git.GitRepo._remove"):
yield coresys_obj yield coresys_obj
await coresys_obj.dbus.unload() await coresys_obj.dbus.unload()
await coresys_obj.websession.close()
@pytest.fixture @pytest.fixture
@ -512,6 +521,31 @@ async def api_client(
yield await aiohttp_client(api.webapp) yield await aiohttp_client(api.webapp)
@pytest.fixture
def supervisor_internet(coresys: CoreSys) -> Generator[AsyncMock]:
"""Fixture which simluate Supervsior internet connection."""
connectivity_check = AsyncMock(return_value=True)
coresys.supervisor.check_connectivity = connectivity_check
yield connectivity_check
@pytest.fixture
def websession(coresys: CoreSys) -> Generator[MagicMock]:
"""Fixture for global aiohttp SessionClient."""
coresys._websession = MagicMock(spec_set=ClientSession)
yield coresys._websession
@pytest.fixture
def mock_update_data(websession: MagicMock) -> Generator[MockResponse]:
"""Mock updater JSON data."""
version_data = load_fixture("version_stable.json")
client_response = MockResponse(text=version_data)
client_response.status = 200
websession.get = MagicMock(return_value=client_response)
yield client_response
@pytest.fixture @pytest.fixture
def store_manager(coresys: CoreSys): def store_manager(coresys: CoreSys):
"""Fixture for the store manager.""" """Fixture for the store manager."""

View File

@ -1,7 +1,6 @@
"""Test scheduled tasks.""" """Test scheduled tasks."""
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from shutil import copy from shutil import copy
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
@ -18,7 +17,7 @@ from supervisor.homeassistant.core import HomeAssistantCore
from supervisor.misc.tasks import Tasks from supervisor.misc.tasks import Tasks
from supervisor.supervisor import Supervisor from supervisor.supervisor import Supervisor
from tests.common import get_fixture_path, load_fixture from tests.common import MockResponse, get_fixture_path
# pylint: disable=protected-access # pylint: disable=protected-access
@ -173,25 +172,17 @@ async def test_watchdog_homeassistant_api_reanimation_limit(
@pytest.mark.usefixtures("no_job_throttle") @pytest.mark.usefixtures("no_job_throttle")
async def test_reload_updater_triggers_supervisor_update( async def test_reload_updater_triggers_supervisor_update(
tasks: Tasks, coresys: CoreSys tasks: Tasks,
coresys: CoreSys,
mock_update_data: MockResponse,
supervisor_internet: AsyncMock,
): ):
"""Test an updater reload triggers a supervisor update if there is one.""" """Test an updater reload triggers a supervisor update if there is one."""
coresys.hardware.disk.get_disk_free_space = lambda x: 5000 coresys.hardware.disk.get_disk_free_space = lambda x: 5000
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
coresys.security.content_trust = False coresys.security.content_trust = False
version_data = load_fixture("version_stable.json")
version_resp = AsyncMock()
version_resp.status = 200
version_resp.read.return_value = version_data
@asynccontextmanager
async def mock_get_for_version(*args, **kwargs) -> AsyncGenerator[AsyncMock]:
"""Mock get call for version information."""
yield version_resp
with ( with (
patch("supervisor.coresys.aiohttp.ClientSession.get", new=mock_get_for_version),
patch.object( patch.object(
Supervisor, Supervisor,
"version", "version",
@ -208,7 +199,8 @@ async def test_reload_updater_triggers_supervisor_update(
update.assert_not_called() update.assert_not_called()
# Version change causes an update # Version change causes an update
version_resp.read.return_value = version_data.replace("2024.10.0", "2024.10.1") version_data = await mock_update_data.text()
mock_update_data.update_text(version_data.replace("2024.10.0", "2024.10.1"))
await tasks._reload_updater() await tasks._reload_updater()
update.assert_called_once() update.assert_called_once()

View File

@ -1,6 +1,6 @@
"""Test Home Assistant OS functionality.""" """Test Home Assistant OS functionality."""
from unittest.mock import PropertyMock, patch from unittest.mock import AsyncMock, PropertyMock, patch
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
from dbus_fast import Variant from dbus_fast import Variant
@ -10,6 +10,7 @@ from supervisor.const import CoreState
from supervisor.coresys import CoreSys from supervisor.coresys import CoreSys
from supervisor.exceptions import HassOSJobError from supervisor.exceptions import HassOSJobError
from tests.common import MockResponse
from tests.dbus_service_mocks.base import DBusServiceMock from tests.dbus_service_mocks.base import DBusServiceMock
from tests.dbus_service_mocks.rauc import Rauc as RaucService from tests.dbus_service_mocks.rauc import Rauc as RaucService
@ -17,7 +18,9 @@ from tests.dbus_service_mocks.rauc import Rauc as RaucService
@pytest.mark.usefixtures("no_job_throttle") @pytest.mark.usefixtures("no_job_throttle")
async def test_ota_url_generic_x86_64_rename(coresys: CoreSys) -> None: async def test_ota_url_generic_x86_64_rename(
coresys: CoreSys, mock_update_data: MockResponse, supervisor_internet: AsyncMock
) -> None:
"""Test download URL generated.""" """Test download URL generated."""
coresys.os._board = "intel-nuc" coresys.os._board = "intel-nuc"
coresys.os._version = AwesomeVersion("5.13") coresys.os._version = AwesomeVersion("5.13")
@ -65,7 +68,9 @@ def test_ota_url_os_name_rel_5_downgrade(coresys: CoreSys) -> None:
assert url == url_formatted assert url == url_formatted
async def test_update_fails_if_out_of_date(coresys: CoreSys) -> None: async def test_update_fails_if_out_of_date(
coresys: CoreSys, supervisor_internet: AsyncMock
) -> None:
"""Test update of OS fails if Supervisor is out of date.""" """Test update of OS fails if Supervisor is out of date."""
await coresys.core.set_state(CoreState.RUNNING) await coresys.core.set_state(CoreState.RUNNING)
with ( with (

View File

@ -1,6 +1,6 @@
"""Test plugin manager.""" """Test plugin manager."""
from unittest.mock import PropertyMock, patch from unittest.mock import AsyncMock, PropertyMock, patch
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
import pytest import pytest
@ -10,6 +10,8 @@ from supervisor.docker.interface import DockerInterface
from supervisor.plugins.base import PluginBase from supervisor.plugins.base import PluginBase
from supervisor.supervisor import Supervisor from supervisor.supervisor import Supervisor
from tests.common import MockResponse
def mock_awaitable_bool(value: bool): def mock_awaitable_bool(value: bool):
"""Return a mock of an awaitable bool.""" """Return a mock of an awaitable bool."""
@ -37,7 +39,9 @@ async def test_repair(coresys: CoreSys):
@pytest.mark.usefixtures("no_job_throttle") @pytest.mark.usefixtures("no_job_throttle")
async def test_load(coresys: CoreSys): async def test_load(
coresys: CoreSys, mock_update_data: MockResponse, supervisor_internet: AsyncMock
):
"""Test plugin manager load.""" """Test plugin manager load."""
coresys.hardware.disk.get_disk_free_space = lambda x: 5000 coresys.hardware.disk.get_disk_free_space = lambda x: 5000
await coresys.updater.load() await coresys.updater.load()

View File

@ -52,13 +52,14 @@ async def test_check(
docker.containers.get = _make_mock_container_get( docker.containers.get = _make_mock_container_get(
["homeassistant", "hassio_audio", "addon_local_ssh"], folder ["homeassistant", "hassio_audio", "addon_local_ssh"], folder
) )
# Use state used in setup()
await coresys.core.set_state(CoreState.SETUP)
with patch.object(DockerInterface, "is_running", return_value=True): with patch.object(DockerInterface, "is_running", return_value=True):
await coresys.plugins.load() await coresys.plugins.load()
await coresys.homeassistant.load() await coresys.homeassistant.load()
await coresys.addons.load() await coresys.addons.load()
docker_config = CheckDockerConfig(coresys) docker_config = CheckDockerConfig(coresys)
await coresys.core.set_state(CoreState.RUNNING)
assert not coresys.resolution.issues assert not coresys.resolution.issues
assert not coresys.resolution.suggestions assert not coresys.resolution.suggestions

View File

@ -16,7 +16,7 @@ from supervisor.security.const import ContentTrustResult, IntegrityResult
from supervisor.utils.dt import utcnow from supervisor.utils.dt import utcnow
async def test_fixup(coresys: CoreSys): async def test_fixup(coresys: CoreSys, supervisor_internet: AsyncMock):
"""Test fixup.""" """Test fixup."""
system_execute_integrity = FixupSystemExecuteIntegrity(coresys) system_execute_integrity = FixupSystemExecuteIntegrity(coresys)
@ -42,7 +42,7 @@ async def test_fixup(coresys: CoreSys):
assert len(coresys.resolution.issues) == 0 assert len(coresys.resolution.issues) == 0
async def test_fixup_error(coresys: CoreSys): async def test_fixup_error(coresys: CoreSys, supervisor_internet: AsyncMock):
"""Test fixup.""" """Test fixup."""
system_execute_integrity = FixupSystemExecuteIntegrity(coresys) system_execute_integrity = FixupSystemExecuteIntegrity(coresys)

View File

@ -34,7 +34,7 @@ async def test_write_state(run_supervisor_state: MagicMock, coresys: CoreSys):
) )
async def test_adjust_system_datetime(coresys: CoreSys): async def test_adjust_system_datetime(coresys: CoreSys, websession: MagicMock):
"""Test _adjust_system_datetime method with successful retrieve_whoami.""" """Test _adjust_system_datetime method with successful retrieve_whoami."""
utc_ts = datetime.datetime.now().replace(tzinfo=datetime.UTC) utc_ts = datetime.datetime.now().replace(tzinfo=datetime.UTC)
with patch( with patch(
@ -52,7 +52,9 @@ async def test_adjust_system_datetime(coresys: CoreSys):
mock_retrieve_whoami.assert_not_called() mock_retrieve_whoami.assert_not_called()
async def test_adjust_system_datetime_without_ssl(coresys: CoreSys): async def test_adjust_system_datetime_without_ssl(
coresys: CoreSys, websession: MagicMock
):
"""Test _adjust_system_datetime method when retrieve_whoami raises WhoamiSSLError.""" """Test _adjust_system_datetime method when retrieve_whoami raises WhoamiSSLError."""
utc_ts = datetime.datetime.now().replace(tzinfo=datetime.UTC) utc_ts = datetime.datetime.now().replace(tzinfo=datetime.UTC)
with patch( with patch(
@ -67,7 +69,9 @@ async def test_adjust_system_datetime_without_ssl(coresys: CoreSys):
assert coresys.core.sys_config.timezone == "Europe/Zurich" assert coresys.core.sys_config.timezone == "Europe/Zurich"
async def test_adjust_system_datetime_if_time_behind(coresys: CoreSys): async def test_adjust_system_datetime_if_time_behind(
coresys: CoreSys, websession: MagicMock
):
"""Test _adjust_system_datetime method when current time is ahead more than 3 days.""" """Test _adjust_system_datetime method when current time is ahead more than 3 days."""
utc_ts = datetime.datetime.now().replace(tzinfo=datetime.UTC) + datetime.timedelta( utc_ts = datetime.datetime.now().replace(tzinfo=datetime.UTC) + datetime.timedelta(
days=4 days=4

View File

@ -1,9 +1,12 @@
"""Testing handling with CoreState.""" """Testing handling with CoreState."""
from datetime import timedelta from datetime import timedelta
from unittest.mock import MagicMock, patch
from aiohttp.hdrs import USER_AGENT from aiohttp.hdrs import USER_AGENT
import pytest
from supervisor.const import CoreState
from supervisor.coresys import CoreSys from supervisor.coresys import CoreSys
from supervisor.dbus.timedate import TimeDate from supervisor.dbus.timedate import TimeDate
from supervisor.utils.dt import utcnow from supervisor.utils.dt import utcnow
@ -36,9 +39,25 @@ async def test_now(coresys: CoreSys):
assert zurich - utc <= timedelta(hours=2) assert zurich - utc <= timedelta(hours=2)
def test_custom_user_agent(coresys: CoreSys): @pytest.mark.no_mock_init_websession
async def test_custom_user_agent(coresys: CoreSys):
"""Test custom useragent.""" """Test custom useragent."""
assert ( with patch(
"HomeAssistantSupervisor/9999.09.9.dev9999" "supervisor.coresys.aiohttp.ClientSession", return_value=MagicMock()
in coresys.websession._default_headers[USER_AGENT] # pylint: disable=protected-access ) as mock_session:
) await coresys.init_websession()
assert (
"HomeAssistantSupervisor/9999.09.9.dev9999"
in mock_session.call_args_list[0][1]["headers"][USER_AGENT]
)
@pytest.mark.no_mock_init_websession
async def test_no_init_when_api_running(coresys: CoreSys):
"""Test ClientSession reinitialization is refused when API is running."""
with patch("supervisor.coresys.aiohttp.ClientSession"):
await coresys.init_websession()
await coresys.core.set_state(CoreState.RUNNING)
# Reinitialize websession should not be possible while running
with pytest.raises(RuntimeError):
await coresys.init_websession()

View File

@ -2,7 +2,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import errno import errno
from unittest.mock import AsyncMock, Mock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
from aiohttp import ClientTimeout from aiohttp import ClientTimeout
from aiohttp.client_exceptions import ClientError from aiohttp.client_exceptions import ClientError
@ -23,17 +23,7 @@ from supervisor.resolution.const import ContextType, IssueType
from supervisor.resolution.data import Issue from supervisor.resolution.data import Issue
from supervisor.supervisor import Supervisor from supervisor.supervisor import Supervisor
from tests.common import reset_last_call from tests.common import MockResponse, reset_last_call
@pytest.fixture(name="websession", scope="function")
async def fixture_webession(coresys: CoreSys) -> AsyncMock:
"""Mock of websession."""
mock_websession = AsyncMock()
with patch.object(
type(coresys), "websession", new=PropertyMock(return_value=mock_websession)
):
yield mock_websession
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -42,14 +32,14 @@ async def fixture_webession(coresys: CoreSys) -> AsyncMock:
@pytest.mark.usefixtures("no_job_throttle") @pytest.mark.usefixtures("no_job_throttle")
async def test_connectivity_check( async def test_connectivity_check(
coresys: CoreSys, coresys: CoreSys,
websession: AsyncMock, websession: MagicMock,
side_effect: Exception | None, side_effect: Exception | None,
connectivity: bool, connectivity: bool,
): ):
"""Test connectivity check.""" """Test connectivity check."""
assert coresys.supervisor.connectivity is True assert coresys.supervisor.connectivity is True
websession.head.side_effect = side_effect websession.head = AsyncMock(side_effect=side_effect)
await coresys.supervisor.check_connectivity() await coresys.supervisor.check_connectivity()
assert coresys.supervisor.connectivity is connectivity assert coresys.supervisor.connectivity is connectivity
@ -66,14 +56,14 @@ async def test_connectivity_check(
) )
async def test_connectivity_check_throttling( async def test_connectivity_check_throttling(
coresys: CoreSys, coresys: CoreSys,
websession: AsyncMock, websession: MagicMock,
side_effect: Exception | None, side_effect: Exception | None,
call_interval: timedelta, call_interval: timedelta,
throttled: bool, throttled: bool,
): ):
"""Test connectivity check throttled when checks succeed.""" """Test connectivity check throttled when checks succeed."""
coresys.supervisor.connectivity = None coresys.supervisor.connectivity = None
websession.head.side_effect = side_effect websession.head = AsyncMock(side_effect=side_effect)
reset_last_call(Supervisor.check_connectivity) reset_last_call(Supervisor.check_connectivity)
with travel(datetime.now(), tick=False) as traveller: with travel(datetime.now(), tick=False) as traveller:
@ -105,35 +95,32 @@ async def test_update_failed(coresys: CoreSys, capture_exception: Mock):
"channel", [UpdateChannel.STABLE, UpdateChannel.BETA, UpdateChannel.DEV] "channel", [UpdateChannel.STABLE, UpdateChannel.BETA, UpdateChannel.DEV]
) )
async def test_update_apparmor( async def test_update_apparmor(
coresys: CoreSys, channel: UpdateChannel, tmp_supervisor_data coresys: CoreSys, channel: UpdateChannel, websession: MagicMock, tmp_supervisor_data
): ):
"""Test updating apparmor.""" """Test updating apparmor."""
websession.get = Mock(return_value=MockResponse())
coresys.updater.channel = channel coresys.updater.channel = channel
with ( with (
patch("supervisor.coresys.aiohttp.ClientSession.get") as get,
patch.object(AppArmorControl, "load_profile") as load_profile, patch.object(AppArmorControl, "load_profile") as load_profile,
): ):
get.return_value.__aenter__.return_value.status = 200
get.return_value.__aenter__.return_value.text = AsyncMock(return_value="")
await coresys.supervisor.update_apparmor() await coresys.supervisor.update_apparmor()
get.assert_called_once_with( websession.get.assert_called_once_with(
f"https://version.home-assistant.io/apparmor_{channel}.txt", f"https://version.home-assistant.io/apparmor_{channel}.txt",
timeout=ClientTimeout(total=10), timeout=ClientTimeout(total=10),
) )
load_profile.assert_called_once() load_profile.assert_called_once()
async def test_update_apparmor_error(coresys: CoreSys, tmp_supervisor_data): async def test_update_apparmor_error(
coresys: CoreSys, websession: MagicMock, tmp_supervisor_data
):
"""Test error updating apparmor profile.""" """Test error updating apparmor profile."""
websession.get = Mock(return_value=MockResponse())
with ( with (
patch("supervisor.coresys.aiohttp.ClientSession.get") as get,
patch.object(AppArmorControl, "load_profile"), patch.object(AppArmorControl, "load_profile"),
patch("supervisor.supervisor.Path.write_text", side_effect=(err := OSError())), patch("supervisor.supervisor.Path.write_text", side_effect=(err := OSError())),
): ):
get.return_value.__aenter__.return_value.status = 200
get.return_value.__aenter__.return_value.text = AsyncMock(return_value="")
err.errno = errno.EBUSY err.errno = errno.EBUSY
with pytest.raises(SupervisorAppArmorError): with pytest.raises(SupervisorAppArmorError):
await coresys.supervisor.update_apparmor() await coresys.supervisor.update_apparmor()

View File

@ -1,6 +1,7 @@
"""Test updater files.""" """Test updater files."""
import asyncio import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
@ -11,7 +12,7 @@ from supervisor.coresys import CoreSys
from supervisor.dbus.const import ConnectivityState from supervisor.dbus.const import ConnectivityState
from supervisor.jobs import SupervisorJob from supervisor.jobs import SupervisorJob
from tests.common import load_binary_fixture from tests.common import MockResponse, load_binary_fixture
from tests.dbus_service_mocks.network_manager import ( from tests.dbus_service_mocks.network_manager import (
NetworkManager as NetworkManagerService, NetworkManager as NetworkManagerService,
) )
@ -20,15 +21,15 @@ URL_TEST = "https://version.home-assistant.io/stable.json"
@pytest.mark.usefixtures("no_job_throttle") @pytest.mark.usefixtures("no_job_throttle")
async def test_fetch_versions(coresys: CoreSys) -> None: async def test_fetch_versions(
coresys: CoreSys, mock_update_data: MockResponse, supervisor_internet: AsyncMock
) -> None:
"""Test download and sync version.""" """Test download and sync version."""
coresys.security.force = True coresys.security.force = True
await coresys.updater.fetch_data() await coresys.updater.fetch_data()
async with coresys.websession.get(URL_TEST) as request: data = json.loads(await mock_update_data.text())
data = await request.json()
assert coresys.updater.version_supervisor == data["supervisor"] assert coresys.updater.version_supervisor == data["supervisor"]
assert coresys.updater.version_homeassistant == data["homeassistant"]["default"] assert coresys.updater.version_homeassistant == data["homeassistant"]["default"]
@ -73,7 +74,13 @@ async def test_fetch_versions(coresys: CoreSys) -> None:
("4.20", "5.13"), ("4.20", "5.13"),
], ],
) )
async def test_os_update_path(coresys: CoreSys, version: str, expected: str): async def test_os_update_path(
coresys: CoreSys,
version: str,
expected: str,
mock_update_data: AsyncMock,
supervisor_internet: AsyncMock,
):
"""Test OS upgrade path across major versions.""" """Test OS upgrade path across major versions."""
coresys.os._board = "rpi4" # pylint: disable=protected-access coresys.os._board = "rpi4" # pylint: disable=protected-access
coresys.os._version = AwesomeVersion(version) # pylint: disable=protected-access coresys.os._version = AwesomeVersion(version) # pylint: disable=protected-access
@ -85,7 +92,9 @@ async def test_os_update_path(coresys: CoreSys, version: str, expected: str):
@pytest.mark.usefixtures("no_job_throttle") @pytest.mark.usefixtures("no_job_throttle")
async def test_delayed_fetch_for_connectivity( async def test_delayed_fetch_for_connectivity(
coresys: CoreSys, network_manager_service: NetworkManagerService coresys: CoreSys,
network_manager_service: NetworkManagerService,
websession: MagicMock,
): ):
"""Test initial version fetch waits for connectivity on load.""" """Test initial version fetch waits for connectivity on load."""
coresys.websession.get = MagicMock() coresys.websession.get = MagicMock()