diff --git a/supervisor/api/__init__.py b/supervisor/api/__init__.py index 71f12a202..08bdcdfd8 100644 --- a/supervisor/api/__init__.py +++ b/supervisor/api/__init__.py @@ -43,7 +43,10 @@ class RestAPI(CoreSysAttributes): self.security: SecurityMiddleware = SecurityMiddleware(coresys) self.webapp: web.Application = web.Application( client_max_size=MAX_CLIENT_SIZE, - middlewares=[self.security.token_validation], + middlewares=[ + self.security.system_validation, + self.security.token_validation, + ], ) # service stuff diff --git a/supervisor/api/security.py b/supervisor/api/security.py index b59d98353..1ee49978b 100644 --- a/supervisor/api/security.py +++ b/supervisor/api/security.py @@ -2,7 +2,7 @@ import logging import re -from aiohttp.web import middleware +from aiohttp.web import Request, RequestHandler, Response, middleware from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized from ..const import ( @@ -12,9 +12,10 @@ from ..const import ( ROLE_DEFAULT, ROLE_HOMEASSISTANT, ROLE_MANAGER, + CoreState, ) -from ..coresys import CoreSysAttributes -from .utils import excract_supervisor_token +from ..coresys import CoreSys, CoreSysAttributes +from .utils import api_return_error, excract_supervisor_token _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -108,12 +109,30 @@ ADDONS_ROLE_ACCESS = { class SecurityMiddleware(CoreSysAttributes): """Security middleware functions.""" - def __init__(self, coresys): + def __init__(self, coresys: CoreSys): """Initialize security middleware.""" - self.coresys = coresys + self.coresys: CoreSys = coresys @middleware - async def token_validation(self, request, handler): + async def system_validation( + self, request: Request, handler: RequestHandler + ) -> Response: + """Check if core is ready to response.""" + if self.sys_core.state not in ( + CoreState.STARTUP, + CoreState.RUNNING, + CoreState.FREEZE, + ): + return api_return_error( + message=f"System is not ready with state: {self.sys_core.state.value}" + ) + + return await handler(request) + + @middleware + async def token_validation( + self, request: Request, handler: RequestHandler + ) -> Response: """Check security access of this layer.""" request_from = None supervisor_token = excract_supervisor_token(request) diff --git a/supervisor/api/utils.py b/supervisor/api/utils.py index e9eeedc0d..096e4db74 100644 --- a/supervisor/api/utils.py +++ b/supervisor/api/utils.py @@ -97,11 +97,15 @@ def api_process_raw(content): return wrap_method -def api_return_error(error: Optional[Any] = None) -> web.Response: +def api_return_error( + error: Optional[Exception] = None, message: Optional[str] = None +) -> web.Response: """Return an API error message.""" - message = get_message_from_exception_chain(error) - if check_exception_chain(error, DockerAPIError): - message = format_message(message) + if error and not message: + message = get_message_from_exception_chain(error) + if check_exception_chain(error, DockerAPIError): + message = format_message(message) + return web.json_response( { JSON_RESULT: RESULT_ERROR, diff --git a/supervisor/core.py b/supervisor/core.py index 322b9ef90..5ec073357 100644 --- a/supervisor/core.py +++ b/supervisor/core.py @@ -111,6 +111,10 @@ class Core(CoreSysAttributes): """Start setting up supervisor orchestration.""" self.state = CoreState.SETUP + # rest api views + await self.sys_api.load() + await self.sys_api.start() + # Load DBus await self.sys_dbus.load() @@ -138,9 +142,6 @@ class Core(CoreSysAttributes): # Load Add-ons await self.sys_addons.load() - # rest api views - await self.sys_api.load() - # load last available data await self.sys_snapshots.load() @@ -189,7 +190,6 @@ class Core(CoreSysAttributes): async def start(self): """Start Supervisor orchestration.""" self.state = CoreState.STARTUP - await self.sys_api.start() # Check if system is healthy if not self.supported: diff --git a/tests/api/test_security.py b/tests/api/test_security.py new file mode 100644 index 000000000..08a6fdbaf --- /dev/null +++ b/tests/api/test_security.py @@ -0,0 +1,61 @@ +"""Test API security layer.""" + +from aiohttp import web +import pytest + +from supervisor.api import RestAPI +from supervisor.const import CoreState +from supervisor.coresys import CoreSys + +# pylint: disable=redefined-outer-name + + +@pytest.fixture +async def api_system(aiohttp_client, run_dir, coresys: CoreSys): + """Fixture for RestAPI client.""" + api = RestAPI(coresys) + api.webapp = web.Application() + await api.load() + + api.webapp.middlewares.append(api.security.system_validation) + yield await aiohttp_client(api.webapp) + + +@pytest.mark.asyncio +async def test_api_security_system_initialize(api_system, coresys: CoreSys): + """Test security.""" + coresys.core.state = CoreState.INITIALIZE + + resp = await api_system.get("/supervisor/ping") + result = await resp.json() + assert resp.status == 400 + assert result["result"] == "error" + + +@pytest.mark.asyncio +async def test_api_security_system_setup(api_system, coresys: CoreSys): + """Test security.""" + coresys.core.state = CoreState.SETUP + + resp = await api_system.get("/supervisor/ping") + result = await resp.json() + assert resp.status == 400 + assert result["result"] == "error" + + +@pytest.mark.asyncio +async def test_api_security_system_running(api_system, coresys: CoreSys): + """Test security.""" + coresys.core.state = CoreState.RUNNING + + resp = await api_system.get("/supervisor/ping") + assert resp.status == 200 + + +@pytest.mark.asyncio +async def test_api_security_system_startup(api_system, coresys: CoreSys): + """Test security.""" + coresys.core.state = CoreState.STARTUP + + resp = await api_system.get("/supervisor/ping") + assert resp.status == 200 diff --git a/tests/conftest.py b/tests/conftest.py index 04d93b346..aa77654e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """Common test functions.""" +from pathlib import Path from unittest.mock import MagicMock, PropertyMock, patch from uuid import uuid4 @@ -164,3 +165,12 @@ def store_manager(coresys: CoreSys): sm_obj.repositories = set(coresys.config.addons_repositories) with patch("supervisor.store.data.StoreData.update", return_value=MagicMock()): yield sm_obj + + +@pytest.fixture +def run_dir(tmp_path): + """Fixture to inject hassio env.""" + with patch("supervisor.core.RUN_SUPERVISOR_STATE") as mock_run: + tmp_state = Path(tmp_path, "supervisor") + mock_run.write_text = tmp_state.write_text + yield tmp_state diff --git a/tests/test_core_state.py b/tests/test_core_state.py index 98d97232b..00a440b31 100644 --- a/tests/test_core_state.py +++ b/tests/test_core_state.py @@ -1,22 +1,7 @@ """Testing handling with CoreState.""" -from pathlib import Path -from unittest.mock import patch - -import pytest from supervisor.const import CoreState -# pylint: disable=redefined-outer-name - - -@pytest.fixture -def run_dir(tmp_path): - """Fixture to inject hassio env.""" - with patch("supervisor.core.RUN_SUPERVISOR_STATE") as mock_run: - tmp_state = Path(tmp_path, "supervisor") - mock_run.write_text = tmp_state.write_text - yield tmp_state - def test_write_state(run_dir, coresys): """Test write corestate to /run/supervisor."""