From a13ae85982084632df2f6d75031574a9d4a8716c Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Thu, 9 Dec 2021 01:49:35 +0100 Subject: [PATCH] Introduce only_supervisor for @websocket_api.ws_require_user() (#61298) --- homeassistant/components/hassio/__init__.py | 5 ++-- .../components/recorder/websocket_api.py | 4 ++-- .../components/websocket_api/decorators.py | 6 +++++ homeassistant/const.py | 3 +++ .../components/recorder/test_websocket_api.py | 20 +++++++++------- .../websocket_api/test_decorators.py | 23 +++++++++++++++++++ tests/conftest.py | 22 +++++++++++++++++- 7 files changed, 70 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/hassio/__init__.py b/homeassistant/components/hassio/__init__.py index 614ea928828..78927d2f322 100644 --- a/homeassistant/components/hassio/__init__.py +++ b/homeassistant/components/hassio/__init__.py @@ -20,6 +20,7 @@ from homeassistant.const import ( ATTR_MANUFACTURER, ATTR_NAME, EVENT_CORE_CONFIG_UPDATE, + HASSIO_USER_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP, Platform, @@ -440,11 +441,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa: # Migrate old name if user.name == "Hass.io": - await hass.auth.async_update_user(user, name="Supervisor") + await hass.auth.async_update_user(user, name=HASSIO_USER_NAME) if refresh_token is None: user = await hass.auth.async_create_system_user( - "Supervisor", group_ids=[GROUP_ID_ADMIN] + HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN] ) refresh_token = await hass.auth.async_create_refresh_token(user) data["hassio_user"] = user.id diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index f6d4d57a7e5..aec7905615f 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -113,7 +113,7 @@ def ws_info( connection.send_result(msg["id"], recorder_info) -@websocket_api.require_admin +@websocket_api.ws_require_user(only_supervisor=True) @websocket_api.websocket_command({vol.Required("type"): "backup/start"}) @websocket_api.async_response async def ws_backup_start( @@ -131,7 +131,7 @@ async def ws_backup_start( connection.send_result(msg["id"]) -@websocket_api.require_admin +@websocket_api.ws_require_user(only_supervisor=True) @websocket_api.websocket_command({vol.Required("type"): "backup/end"}) @websocket_api.async_response async def ws_backup_end( diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index eff82a8c71d..296271c7cfd 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -8,6 +8,7 @@ from typing import Any import voluptuous as vol +from homeassistant.const import HASSIO_USER_NAME from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import Unauthorized @@ -70,6 +71,7 @@ def ws_require_user( allow_system_user: bool = True, only_active_user: bool = True, only_inactive_user: bool = False, + only_supervisor: bool = False, ) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]: """Decorate function validating login user exist in current WS connection. @@ -111,6 +113,10 @@ def ws_require_user( output_error("only_inactive_user", "Not allowed as active user") return + if only_supervisor and connection.user.name != HASSIO_USER_NAME: + output_error("only_supervisor", "Only allowed as Supervisor") + return + return func(hass, connection, msg) return check_current_user diff --git a/homeassistant/const.py b/homeassistant/const.py index c4e90dc2d57..cb33c20b14b 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -756,3 +756,6 @@ ENTITY_CATEGORIES: Final[list[str]] = [ CAST_APP_ID_HOMEASSISTANT_MEDIA: Final = "B45F4572" # The ID of the Home Assistant Lovelace Cast App CAST_APP_ID_HOMEASSISTANT_LOVELACE: Final = "A078F6B0" + +# User used by Supervisor +HASSIO_USER_NAME = "Supervisor" diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index 994d1c677af..2a9f737e9a5 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -360,9 +360,11 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client): assert response["result"]["thread_running"] is True -async def test_backup_start_no_recorder(hass, hass_ws_client): +async def test_backup_start_no_recorder( + hass, hass_ws_client, hass_supervisor_access_token +): """Test getting backup start when recorder is not present.""" - client = await hass_ws_client() + client = await hass_ws_client(hass, hass_supervisor_access_token) await client.send_json({"id": 1, "type": "backup/start"}) response = await client.receive_json() @@ -370,9 +372,9 @@ async def test_backup_start_no_recorder(hass, hass_ws_client): assert response["error"]["code"] == "unknown_command" -async def test_backup_start_timeout(hass, hass_ws_client): +async def test_backup_start_timeout(hass, hass_ws_client, hass_supervisor_access_token): """Test getting backup start when recorder is not present.""" - client = await hass_ws_client() + client = await hass_ws_client(hass, hass_supervisor_access_token) await async_init_recorder_component(hass) # Ensure there are no queued events @@ -388,9 +390,9 @@ async def test_backup_start_timeout(hass, hass_ws_client): await client.send_json({"id": 2, "type": "backup/end"}) -async def test_backup_end(hass, hass_ws_client): +async def test_backup_end(hass, hass_ws_client, hass_supervisor_access_token): """Test backup start.""" - client = await hass_ws_client() + client = await hass_ws_client(hass, hass_supervisor_access_token) await async_init_recorder_component(hass) # Ensure there are no queued events @@ -405,9 +407,11 @@ async def test_backup_end(hass, hass_ws_client): assert response["success"] -async def test_backup_end_without_start(hass, hass_ws_client): +async def test_backup_end_without_start( + hass, hass_ws_client, hass_supervisor_access_token +): """Test backup start.""" - client = await hass_ws_client() + client = await hass_ws_client(hass, hass_supervisor_access_token) await async_init_recorder_component(hass) # Ensure there are no queued events diff --git a/tests/components/websocket_api/test_decorators.py b/tests/components/websocket_api/test_decorators.py index 45d761f6fed..4fbc1ae1a21 100644 --- a/tests/components/websocket_api/test_decorators.py +++ b/tests/components/websocket_api/test_decorators.py @@ -66,3 +66,26 @@ async def test_async_response_request_context(hass, websocket_client): assert msg["id"] == 7 assert not msg["success"] assert msg["error"]["code"] == "not_found" + + +async def test_supervisor_only(hass, websocket_client): + """Test that only the Supervisor can make requests.""" + + @websocket_api.ws_require_user(only_supervisor=True) + @websocket_api.websocket_command({"type": "test-require-supervisor-user"}) + def require_supervisor_request(hass, connection, msg): + connection.send_result(msg["id"]) + + websocket_api.async_register_command(hass, require_supervisor_request) + + await websocket_client.send_json( + { + "id": 5, + "type": "test-require-supervisor-user", + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 5 + assert not msg["success"] + assert msg["error"]["code"] == "only_supervisor" diff --git a/tests/conftest.py b/tests/conftest.py index 88651a0ec3f..56be04edeeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,7 @@ from homeassistant.components.websocket_api.auth import ( TYPE_AUTH_REQUIRED, ) from homeassistant.components.websocket_api.http import URL -from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED +from homeassistant.const import ATTR_NOW, EVENT_TIME_CHANGED, HASSIO_USER_NAME from homeassistant.helpers import config_entry_oauth2_flow, event from homeassistant.setup import async_setup_component from homeassistant.util import location @@ -405,6 +405,26 @@ def hass_read_only_access_token(hass, hass_read_only_user, local_auth): return hass.auth.async_create_access_token(refresh_token) +@pytest.fixture +def hass_supervisor_user(hass, local_auth): + """Return the Home Assistant Supervisor user.""" + admin_group = hass.loop.run_until_complete( + hass.auth.async_get_group(GROUP_ID_ADMIN) + ) + return MockUser( + name=HASSIO_USER_NAME, groups=[admin_group], system_generated=True + ).add_to_hass(hass) + + +@pytest.fixture +def hass_supervisor_access_token(hass, hass_supervisor_user, local_auth): + """Return a Home Assistant Supervisor access token.""" + refresh_token = hass.loop.run_until_complete( + hass.auth.async_create_refresh_token(hass_supervisor_user) + ) + return hass.auth.async_create_access_token(refresh_token) + + @pytest.fixture def legacy_auth(hass): """Load legacy API password provider."""