Restrict stopping core during migrations with force option (#5205)

This commit is contained in:
Mike Degatano 2024-07-25 11:14:45 -04:00 committed by GitHub
parent 591b9a4d87
commit 0bbd15bfda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 155 additions and 10 deletions

View File

@ -36,6 +36,7 @@ ATTR_DT_UTC = "dt_utc"
ATTR_EJECTABLE = "ejectable" ATTR_EJECTABLE = "ejectable"
ATTR_FALLBACK = "fallback" ATTR_FALLBACK = "fallback"
ATTR_FILESYSTEMS = "filesystems" ATTR_FILESYSTEMS = "filesystems"
ATTR_FORCE = "force"
ATTR_GROUP_IDS = "group_ids" ATTR_GROUP_IDS = "group_ids"
ATTR_IDENTIFIERS = "identifiers" ATTR_IDENTIFIERS = "identifiers"
ATTR_IS_ACTIVE = "is_active" ATTR_IS_ACTIVE = "is_active"

View File

@ -1,4 +1,5 @@
"""Init file for Supervisor Home Assistant RESTful API.""" """Init file for Supervisor Home Assistant RESTful API."""
import asyncio import asyncio
from collections.abc import Awaitable from collections.abc import Awaitable
import logging import logging
@ -34,9 +35,9 @@ from ..const import (
ATTR_WATCHDOG, ATTR_WATCHDOG,
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import APIError from ..exceptions import APIDBMigrationInProgress, APIError
from ..validate import docker_image, network_port, version_tag from ..validate import docker_image, network_port, version_tag
from .const import ATTR_SAFE_MODE from .const import ATTR_FORCE, ATTR_SAFE_MODE
from .utils import api_process, api_validate from .utils import api_process, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -66,6 +67,13 @@ SCHEMA_UPDATE = vol.Schema(
SCHEMA_RESTART = vol.Schema( SCHEMA_RESTART = vol.Schema(
{ {
vol.Optional(ATTR_SAFE_MODE, default=False): vol.Boolean(), vol.Optional(ATTR_SAFE_MODE, default=False): vol.Boolean(),
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)
SCHEMA_STOP = vol.Schema(
{
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
} }
) )
@ -73,6 +81,17 @@ SCHEMA_RESTART = vol.Schema(
class APIHomeAssistant(CoreSysAttributes): class APIHomeAssistant(CoreSysAttributes):
"""Handle RESTful API for Home Assistant functions.""" """Handle RESTful API for Home Assistant functions."""
async def _check_offline_migration(self, force: bool = False) -> None:
"""Check and raise if there's an offline DB migration in progress."""
if (
not force
and (state := await self.sys_homeassistant.api.get_api_state())
and state.offline_db_migration
):
raise APIDBMigrationInProgress(
"Offline database migration in progress, try again after it has completed"
)
@api_process @api_process
async def info(self, request: web.Request) -> dict[str, Any]: async def info(self, request: web.Request) -> dict[str, Any]:
"""Return host information.""" """Return host information."""
@ -154,6 +173,7 @@ class APIHomeAssistant(CoreSysAttributes):
async def update(self, request: web.Request) -> None: async def update(self, request: web.Request) -> None:
"""Update Home Assistant.""" """Update Home Assistant."""
body = await api_validate(SCHEMA_UPDATE, request) body = await api_validate(SCHEMA_UPDATE, request)
await self._check_offline_migration()
await asyncio.shield( await asyncio.shield(
self.sys_homeassistant.core.update( self.sys_homeassistant.core.update(
@ -163,9 +183,12 @@ class APIHomeAssistant(CoreSysAttributes):
) )
@api_process @api_process
def stop(self, request: web.Request) -> Awaitable[None]: async def stop(self, request: web.Request) -> Awaitable[None]:
"""Stop Home Assistant.""" """Stop Home Assistant."""
return asyncio.shield(self.sys_homeassistant.core.stop()) body = await api_validate(SCHEMA_STOP, request)
await self._check_offline_migration(force=body[ATTR_FORCE])
return await asyncio.shield(self.sys_homeassistant.core.stop())
@api_process @api_process
def start(self, request: web.Request) -> Awaitable[None]: def start(self, request: web.Request) -> Awaitable[None]:
@ -176,6 +199,7 @@ class APIHomeAssistant(CoreSysAttributes):
async def restart(self, request: web.Request) -> None: async def restart(self, request: web.Request) -> None:
"""Restart Home Assistant.""" """Restart Home Assistant."""
body = await api_validate(SCHEMA_RESTART, request) body = await api_validate(SCHEMA_RESTART, request)
await self._check_offline_migration(force=body[ATTR_FORCE])
await asyncio.shield( await asyncio.shield(
self.sys_homeassistant.core.restart(safe_mode=body[ATTR_SAFE_MODE]) self.sys_homeassistant.core.restart(safe_mode=body[ATTR_SAFE_MODE])
@ -185,6 +209,7 @@ class APIHomeAssistant(CoreSysAttributes):
async def rebuild(self, request: web.Request) -> None: async def rebuild(self, request: web.Request) -> None:
"""Rebuild Home Assistant.""" """Rebuild Home Assistant."""
body = await api_validate(SCHEMA_RESTART, request) body = await api_validate(SCHEMA_RESTART, request)
await self._check_offline_migration(force=body[ATTR_FORCE])
await asyncio.shield( await asyncio.shield(
self.sys_homeassistant.core.rebuild(safe_mode=body[ATTR_SAFE_MODE]) self.sys_homeassistant.core.rebuild(safe_mode=body[ATTR_SAFE_MODE])

View File

@ -28,7 +28,7 @@ from ..const import (
ATTR_TIMEZONE, ATTR_TIMEZONE,
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import APIError, HostLogError from ..exceptions import APIDBMigrationInProgress, APIError, HostLogError
from ..host.const import ( from ..host.const import (
PARAM_BOOT_ID, PARAM_BOOT_ID,
PARAM_FOLLOW, PARAM_FOLLOW,
@ -46,6 +46,7 @@ from .const import (
ATTR_BROADCAST_MDNS, ATTR_BROADCAST_MDNS,
ATTR_DT_SYNCHRONIZED, ATTR_DT_SYNCHRONIZED,
ATTR_DT_UTC, ATTR_DT_UTC,
ATTR_FORCE,
ATTR_IDENTIFIERS, ATTR_IDENTIFIERS,
ATTR_LLMNR_HOSTNAME, ATTR_LLMNR_HOSTNAME,
ATTR_STARTUP_TIME, ATTR_STARTUP_TIME,
@ -64,10 +65,29 @@ DEFAULT_RANGE = 100
SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_HOSTNAME): str}) SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_HOSTNAME): str})
# pylint: disable=no-value-for-parameter
SCHEMA_SHUTDOWN = vol.Schema(
{
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)
# pylint: enable=no-value-for-parameter
class APIHost(CoreSysAttributes): class APIHost(CoreSysAttributes):
"""Handle RESTful API for host functions.""" """Handle RESTful API for host functions."""
async def _check_ha_offline_migration(self, force: bool) -> None:
"""Check if HA has an offline migration in progress and raise if not forced."""
if (
not force
and (state := await self.sys_homeassistant.api.get_api_state())
and state.offline_db_migration
):
raise APIDBMigrationInProgress(
"Home Assistant offline database migration in progress, please wait until complete before shutting down host"
)
@api_process @api_process
async def info(self, request): async def info(self, request):
"""Return host information.""" """Return host information."""
@ -109,14 +129,20 @@ class APIHost(CoreSysAttributes):
) )
@api_process @api_process
def reboot(self, request): async def reboot(self, request):
"""Reboot host.""" """Reboot host."""
return asyncio.shield(self.sys_host.control.reboot()) body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])
return await asyncio.shield(self.sys_host.control.reboot())
@api_process @api_process
def shutdown(self, request): async def shutdown(self, request):
"""Poweroff host.""" """Poweroff host."""
return asyncio.shield(self.sys_host.control.shutdown()) body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])
return await asyncio.shield(self.sys_host.control.shutdown())
@api_process @api_process
def reload(self, request): def reload(self, request):

View File

@ -1,4 +1,5 @@
"""Core Exceptions.""" """Core Exceptions."""
from collections.abc import Callable from collections.abc import Callable
@ -339,6 +340,12 @@ class APIAddonNotInstalled(APIError):
"""Not installed addon requested at addons API.""" """Not installed addon requested at addons API."""
class APIDBMigrationInProgress(APIError):
"""Service is unavailable due to an offline DB migration is in progress."""
status = 503
# Service / Discovery # Service / Discovery

View File

@ -8,6 +8,7 @@ from awesomeversion import AwesomeVersion
import pytest import pytest
from supervisor.coresys import CoreSys from supervisor.coresys import CoreSys
from supervisor.homeassistant.api import APIState
from supervisor.homeassistant.core import HomeAssistantCore from supervisor.homeassistant.core import HomeAssistantCore
from supervisor.homeassistant.module import HomeAssistant from supervisor.homeassistant.module import HomeAssistant
@ -142,3 +143,48 @@ async def test_api_rebuild(
assert container.remove.call_count == 4 assert container.remove.call_count == 4
assert container.start.call_count == 2 assert container.start.call_count == 2
assert safe_mode_marker.exists() assert safe_mode_marker.exists()
@pytest.mark.parametrize("action", ["rebuild", "restart", "stop", "update"])
async def test_migration_blocks_stopping_core(
api_client: TestClient,
coresys: CoreSys,
action: str,
):
"""Test that an offline db migration in progress stops users from stopping/restarting core."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
resp = await api_client.post(f"/homeassistant/{action}")
assert resp.status == 503
result = await resp.json()
assert (
result["message"]
== "Offline database migration in progress, try again after it has completed"
)
async def test_force_rebuild_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option rebuilds even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(HomeAssistantCore, "rebuild") as rebuild:
await api_client.post("/homeassistant/rebuild", json={"force": True})
rebuild.assert_called_once()
async def test_force_restart_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option restarts even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(HomeAssistantCore, "restart") as restart:
await api_client.post("/homeassistant/restart", json={"force": True})
restart.assert_called_once()
async def test_force_stop_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option stops even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(HomeAssistantCore, "stop") as stop:
await api_client.post("/homeassistant/stop", json={"force": True})
stop.assert_called_once()

View File

@ -1,13 +1,15 @@
"""Test Host API.""" """Test Host API."""
from unittest.mock import ANY, MagicMock from unittest.mock import ANY, MagicMock, patch
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
import pytest import pytest
from supervisor.coresys import CoreSys from supervisor.coresys import CoreSys
from supervisor.dbus.resolved import Resolved from supervisor.dbus.resolved import Resolved
from supervisor.homeassistant.api import APIState
from supervisor.host.const import LogFormat, LogFormatter from supervisor.host.const import LogFormat, LogFormatter
from supervisor.host.control import SystemControl
from tests.dbus_service_mocks.base import DBusServiceMock from tests.dbus_service_mocks.base import DBusServiceMock
from tests.dbus_service_mocks.systemd import Systemd as SystemdService from tests.dbus_service_mocks.systemd import Systemd as SystemdService
@ -324,3 +326,41 @@ async def test_advanced_logs_errors(api_client: TestClient):
content content
== "Invalid content type requested. Only text/plain and text/x-log supported for now." == "Invalid content type requested. Only text/plain and text/x-log supported for now."
) )
@pytest.mark.parametrize("action", ["reboot", "shutdown"])
async def test_migration_blocks_shutdown(
api_client: TestClient,
coresys: CoreSys,
action: str,
):
"""Test that an offline db migration in progress stops users from shuting down or rebooting system."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
resp = await api_client.post(f"/host/{action}")
assert resp.status == 503
result = await resp.json()
assert (
result["message"]
== "Home Assistant offline database migration in progress, please wait until complete before shutting down host"
)
async def test_force_reboot_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option reboots even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(SystemControl, "reboot") as reboot:
await api_client.post("/host/reboot", json={"force": True})
reboot.assert_called_once()
async def test_force_shutdown_during_migration(
api_client: TestClient, coresys: CoreSys
):
"""Test force option shutdown even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)
with patch.object(SystemControl, "shutdown") as shutdown:
await api_client.post("/host/shutdown", json={"force": True})
shutdown.assert_called_once()