Restrict stopping core during migrations with force option (#5205)

This commit is contained in:
Mike Degatano 2024-07-25 11:14:45 -04:00 committed by GitHub
parent 591b9a4d87
commit 0bbd15bfda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 155 additions and 10 deletions

View File

@ -36,6 +36,7 @@ ATTR_DT_UTC = "dt_utc"
ATTR_EJECTABLE = "ejectable"
ATTR_FALLBACK = "fallback"
ATTR_FILESYSTEMS = "filesystems"
ATTR_FORCE = "force"
ATTR_GROUP_IDS = "group_ids"
ATTR_IDENTIFIERS = "identifiers"
ATTR_IS_ACTIVE = "is_active"

View File

@ -1,4 +1,5 @@
"""Init file for Supervisor Home Assistant RESTful API."""
import asyncio
from collections.abc import Awaitable
import logging
@ -34,9 +35,9 @@ from ..const import (
ATTR_WATCHDOG,
)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError
from ..exceptions import APIDBMigrationInProgress, APIError
from ..validate import docker_image, network_port, version_tag
from .const import ATTR_SAFE_MODE
from .const import ATTR_FORCE, ATTR_SAFE_MODE
from .utils import api_process, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__)
@ -66,6 +67,13 @@ SCHEMA_UPDATE = vol.Schema(
SCHEMA_RESTART = vol.Schema(
{
vol.Optional(ATTR_SAFE_MODE, default=False): vol.Boolean(),
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)
SCHEMA_STOP = vol.Schema(
{
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)
@ -73,6 +81,17 @@ SCHEMA_RESTART = vol.Schema(
class APIHomeAssistant(CoreSysAttributes):
"""Handle RESTful API for Home Assistant functions."""
async def _check_offline_migration(self, force: bool = False) -> None:
"""Check and raise if there's an offline DB migration in progress."""
if (
not force
and (state := await self.sys_homeassistant.api.get_api_state())
and state.offline_db_migration
):
raise APIDBMigrationInProgress(
"Offline database migration in progress, try again after it has completed"
)
@api_process
async def info(self, request: web.Request) -> dict[str, Any]:
"""Return host information."""
@ -154,6 +173,7 @@ class APIHomeAssistant(CoreSysAttributes):
async def update(self, request: web.Request) -> None:
"""Update Home Assistant."""
body = await api_validate(SCHEMA_UPDATE, request)
await self._check_offline_migration()
await asyncio.shield(
self.sys_homeassistant.core.update(
@ -163,9 +183,12 @@ class APIHomeAssistant(CoreSysAttributes):
)
@api_process
def stop(self, request: web.Request) -> Awaitable[None]:
async def stop(self, request: web.Request) -> Awaitable[None]:
"""Stop Home Assistant."""
return asyncio.shield(self.sys_homeassistant.core.stop())
body = await api_validate(SCHEMA_STOP, request)
await self._check_offline_migration(force=body[ATTR_FORCE])
return await asyncio.shield(self.sys_homeassistant.core.stop())
@api_process
def start(self, request: web.Request) -> Awaitable[None]:
@ -176,6 +199,7 @@ class APIHomeAssistant(CoreSysAttributes):
async def restart(self, request: web.Request) -> None:
"""Restart Home Assistant."""
body = await api_validate(SCHEMA_RESTART, request)
await self._check_offline_migration(force=body[ATTR_FORCE])
await asyncio.shield(
self.sys_homeassistant.core.restart(safe_mode=body[ATTR_SAFE_MODE])
@ -185,6 +209,7 @@ class APIHomeAssistant(CoreSysAttributes):
async def rebuild(self, request: web.Request) -> None:
"""Rebuild Home Assistant."""
body = await api_validate(SCHEMA_RESTART, request)
await self._check_offline_migration(force=body[ATTR_FORCE])
await asyncio.shield(
self.sys_homeassistant.core.rebuild(safe_mode=body[ATTR_SAFE_MODE])

View File

@ -28,7 +28,7 @@ from ..const import (
ATTR_TIMEZONE,
)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError, HostLogError
from ..exceptions import APIDBMigrationInProgress, APIError, HostLogError
from ..host.const import (
PARAM_BOOT_ID,
PARAM_FOLLOW,
@ -46,6 +46,7 @@ from .const import (
ATTR_BROADCAST_MDNS,
ATTR_DT_SYNCHRONIZED,
ATTR_DT_UTC,
ATTR_FORCE,
ATTR_IDENTIFIERS,
ATTR_LLMNR_HOSTNAME,
ATTR_STARTUP_TIME,
@ -64,10 +65,29 @@ DEFAULT_RANGE = 100
SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_HOSTNAME): str})
# pylint: disable=no-value-for-parameter
SCHEMA_SHUTDOWN = vol.Schema(
{
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)
# pylint: enable=no-value-for-parameter
class APIHost(CoreSysAttributes):
"""Handle RESTful API for host functions."""
async def _check_ha_offline_migration(self, force: bool) -> None:
"""Check if HA has an offline migration in progress and raise if not forced."""
if (
not force
and (state := await self.sys_homeassistant.api.get_api_state())
and state.offline_db_migration
):
raise APIDBMigrationInProgress(
"Home Assistant offline database migration in progress, please wait until complete before shutting down host"
)
@api_process
async def info(self, request):
"""Return host information."""
@ -109,14 +129,20 @@ class APIHost(CoreSysAttributes):
)
@api_process
def reboot(self, request):
async def reboot(self, request):
"""Reboot host."""
return asyncio.shield(self.sys_host.control.reboot())
body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])
return await asyncio.shield(self.sys_host.control.reboot())
@api_process
def shutdown(self, request):
async def shutdown(self, request):
"""Poweroff host."""
return asyncio.shield(self.sys_host.control.shutdown())
body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])
return await asyncio.shield(self.sys_host.control.shutdown())
@api_process
def reload(self, request):

View File

@ -1,4 +1,5 @@
"""Core Exceptions."""
from collections.abc import Callable
@ -339,6 +340,12 @@ class APIAddonNotInstalled(APIError):
"""Not installed addon requested at addons API."""
class APIDBMigrationInProgress(APIError):
"""Service is unavailable due to an offline DB migration is in progress."""
status = 503
# Service / Discovery

View File

@ -8,6 +8,7 @@ from awesomeversion import AwesomeVersion
import pytest
from supervisor.coresys import CoreSys
from supervisor.homeassistant.api import APIState
from supervisor.homeassistant.core import HomeAssistantCore
from supervisor.homeassistant.module import HomeAssistant
@ -142,3 +143,48 @@ async def test_api_rebuild(
assert container.remove.call_count == 4
assert container.start.call_count == 2
assert safe_mode_marker.exists()
@pytest.mark.parametrize("action", ["rebuild", "restart", "stop", "update"])
async def test_migration_blocks_stopping_core(
api_client: TestClient,
coresys: CoreSys,
action: str,
):
"""Test that an offline db migration in progress stops users from stopping/restarting core."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
resp = await api_client.post(f"/homeassistant/{action}")
assert resp.status == 503
result = await resp.json()
assert (
result["message"]
== "Offline database migration in progress, try again after it has completed"
)
async def test_force_rebuild_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option rebuilds even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(HomeAssistantCore, "rebuild") as rebuild:
await api_client.post("/homeassistant/rebuild", json={"force": True})
rebuild.assert_called_once()
async def test_force_restart_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option restarts even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(HomeAssistantCore, "restart") as restart:
await api_client.post("/homeassistant/restart", json={"force": True})
restart.assert_called_once()
async def test_force_stop_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option stops even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(HomeAssistantCore, "stop") as stop:
await api_client.post("/homeassistant/stop", json={"force": True})
stop.assert_called_once()

View File

@ -1,13 +1,15 @@
"""Test Host API."""
from unittest.mock import ANY, MagicMock
from unittest.mock import ANY, MagicMock, patch
from aiohttp.test_utils import TestClient
import pytest
from supervisor.coresys import CoreSys
from supervisor.dbus.resolved import Resolved
from supervisor.homeassistant.api import APIState
from supervisor.host.const import LogFormat, LogFormatter
from supervisor.host.control import SystemControl
from tests.dbus_service_mocks.base import DBusServiceMock
from tests.dbus_service_mocks.systemd import Systemd as SystemdService
@ -324,3 +326,41 @@ async def test_advanced_logs_errors(api_client: TestClient):
content
== "Invalid content type requested. Only text/plain and text/x-log supported for now."
)
@pytest.mark.parametrize("action", ["reboot", "shutdown"])
async def test_migration_blocks_shutdown(
api_client: TestClient,
coresys: CoreSys,
action: str,
):
"""Test that an offline db migration in progress stops users from shuting down or rebooting system."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
resp = await api_client.post(f"/host/{action}")
assert resp.status == 503
result = await resp.json()
assert (
result["message"]
== "Home Assistant offline database migration in progress, please wait until complete before shutting down host"
)
async def test_force_reboot_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option reboots even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(SystemControl, "reboot") as reboot:
await api_client.post("/host/reboot", json={"force": True})
reboot.assert_called_once()
async def test_force_shutdown_during_migration(
api_client: TestClient, coresys: CoreSys
):
"""Test force option shutdown even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(SystemControl, "shutdown") as shutdown:
await api_client.post("/host/shutdown", json={"force": True})
shutdown.assert_called_once()