From 0ca3f25c5784b5ba2549578689439e87ef6faf17 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 18 Jun 2024 16:15:42 +0200 Subject: [PATCH] Add WS command for subscribing to storage collection changes (#119481) --- homeassistant/helpers/collection.py | 64 +++++- tests/components/lovelace/test_resources.py | 100 ++++++++-- tests/helpers/test_collection.py | 211 ++++++++++++++++++++ 3 files changed, 361 insertions(+), 14 deletions(-) diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index 1ce4a9d092b..1dd94d85f9a 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -18,7 +18,7 @@ from voluptuous.humanize import humanize_error from homeassistant.components import websocket_api from homeassistant.const import CONF_ID -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.util import slugify @@ -525,6 +525,9 @@ class StorageCollectionWebsocket[_StorageCollectionT: StorageCollection]: self.create_schema = create_schema self.update_schema = update_schema + self._remove_subscription: CALLBACK_TYPE | None = None + self._subscribers: set[tuple[websocket_api.ActiveConnection, int]] = set() + assert self.api_prefix[-1] != "/", "API prefix should not end in /" @property @@ -564,6 +567,15 @@ class StorageCollectionWebsocket[_StorageCollectionT: StorageCollection]: ), ) + websocket_api.async_register_command( + hass, + f"{self.api_prefix}/subscribe", + self._ws_subscribe, + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + {vol.Required("type"): f"{self.api_prefix}/subscribe"} + ), + ) + websocket_api.async_register_command( hass, f"{self.api_prefix}/update", @@ -619,6 +631,56 @@ class StorageCollectionWebsocket[_StorageCollectionT: StorageCollection]: except ValueError as err: connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err)) + @callback + def _ws_subscribe( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Subscribe to collection updates.""" + + async def async_change_listener( + change_set: Iterable[CollectionChange], + ) -> None: + json_msg = [ + { + "change_type": change.change_type, + self.item_id_key: change.item_id, + "item": change.item, + } + for change in change_set + ] + for connection, msg_id in self._subscribers: + connection.send_message(websocket_api.event_message(msg_id, json_msg)) + + if not self._subscribers: + self._remove_subscription = ( + self.storage_collection.async_add_change_set_listener( + async_change_listener + ) + ) + + self._subscribers.add((connection, msg["id"])) + + @callback + def cancel_subscription() -> None: + self._subscribers.remove((connection, msg["id"])) + if not self._subscribers and self._remove_subscription: + self._remove_subscription() + self._remove_subscription = None + + connection.subscriptions[msg["id"]] = cancel_subscription + + connection.send_message(websocket_api.result_message(msg["id"])) + + json_msg = [ + { + "change_type": CHANGE_ADDED, + self.item_id_key: item_id, + "item": item, + } + for item_id, item in self.storage_collection.data.items() + ] + connection.send_message(websocket_api.event_message(msg["id"], json_msg)) + async def ws_update_item( self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: diff --git a/tests/components/lovelace/test_resources.py b/tests/components/lovelace/test_resources.py index bf6b44f0950..281fb001fc2 100644 --- a/tests/components/lovelace/test_resources.py +++ b/tests/components/lovelace/test_resources.py @@ -2,7 +2,7 @@ import copy from typing import Any -from unittest.mock import patch +from unittest.mock import ANY, patch import uuid import pytest @@ -101,8 +101,43 @@ async def test_storage_resources_import( client = await hass_ws_client(hass) - # Fetch data - await client.send_json({"id": 5, "type": list_cmd}) + # Subscribe + await client.send_json_auto_id({"type": "lovelace/resources/subscribe"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] is None + event_id = response["id"] + + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [] + + # Fetch data - this also loads the resources + await client.send_json_auto_id({"type": list_cmd}) + + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [ + { + "change_type": "added", + "item": { + "id": ANY, + "type": "js", + "url": "/local/bla.js", + }, + "resource_id": ANY, + }, + { + "change_type": "added", + "item": { + "id": ANY, + "type": "css", + "url": "/local/bla.css", + }, + "resource_id": ANY, + }, + ] + response = await client.receive_json() assert response["success"] assert ( @@ -115,18 +150,31 @@ async def test_storage_resources_import( ) # Add a resource - await client.send_json( + await client.send_json_auto_id( { - "id": 6, "type": "lovelace/resources/create", "res_type": "module", "url": "/local/yo.js", } ) + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [ + { + "change_type": "added", + "item": { + "id": ANY, + "type": "module", + "url": "/local/yo.js", + }, + "resource_id": ANY, + } + ] + response = await client.receive_json() assert response["success"] - await client.send_json({"id": 7, "type": list_cmd}) + await client.send_json_auto_id({"type": list_cmd}) response = await client.receive_json() assert response["success"] @@ -137,19 +185,32 @@ async def test_storage_resources_import( # Update a resource first_item = response["result"][0] - await client.send_json( + await client.send_json_auto_id( { - "id": 8, "type": "lovelace/resources/update", "resource_id": first_item["id"], "res_type": "css", "url": "/local/updated.css", } ) + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [ + { + "change_type": "updated", + "item": { + "id": first_item["id"], + "type": "css", + "url": "/local/updated.css", + }, + "resource_id": first_item["id"], + } + ] + response = await client.receive_json() assert response["success"] - await client.send_json({"id": 9, "type": list_cmd}) + await client.send_json_auto_id({"type": list_cmd}) response = await client.receive_json() assert response["success"] @@ -157,18 +218,31 @@ async def test_storage_resources_import( assert first_item["type"] == "css" assert first_item["url"] == "/local/updated.css" - # Delete resources - await client.send_json( + # Delete a resource + await client.send_json_auto_id( { - "id": 10, "type": "lovelace/resources/delete", "resource_id": first_item["id"], } ) + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [ + { + "change_type": "removed", + "item": { + "id": first_item["id"], + "type": "css", + "url": "/local/updated.css", + }, + "resource_id": first_item["id"], + } + ] + response = await client.receive_json() assert response["success"] - await client.send_json({"id": 11, "type": list_cmd}) + await client.send_json_auto_id({"type": list_cmd}) response = await client.receive_json() assert response["success"] diff --git a/tests/helpers/test_collection.py b/tests/helpers/test_collection.py index dc9ac21e246..f4d5b06dae0 100644 --- a/tests/helpers/test_collection.py +++ b/tests/helpers/test_collection.py @@ -563,3 +563,214 @@ async def test_storage_collection_websocket( "name": "Updated name", }, ) + + +async def test_storage_collection_websocket_subscribe( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test exposing a storage collection via websockets.""" + store = storage.Store(hass, 1, "test-data") + coll = MockStorageCollection(store) + changes = track_changes(coll) + collection.DictStorageCollectionWebsocket( + coll, + "test_item/collection", + "test_item", + {vol.Required("name"): str, vol.Required("immutable_string"): str}, + {vol.Optional("name"): str}, + ).async_setup(hass) + + client = await hass_ws_client(hass) + + # Subscribe + await client.send_json_auto_id({"type": "test_item/collection/subscribe"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] is None + assert len(changes) == 0 + event_id = response["id"] + + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [] + + # Create invalid + await client.send_json_auto_id( + { + "type": "test_item/collection/create", + "name": 1, + # Forgot to add immutable_string + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "invalid_format" + assert len(changes) == 0 + + # Create + await client.send_json_auto_id( + { + "type": "test_item/collection/create", + "name": "Initial Name", + "immutable_string": "no-changes", + } + ) + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [ + { + "change_type": "added", + "item": { + "id": "initial_name", + "immutable_string": "no-changes", + "name": "Initial Name", + }, + "test_item_id": "initial_name", + } + ] + response = await client.receive_json() + assert response["success"] + assert response["result"] == { + "id": "initial_name", + "name": "Initial Name", + "immutable_string": "no-changes", + } + assert len(changes) == 1 + assert changes[0] == (collection.CHANGE_ADDED, "initial_name", response["result"]) + + # Subscribe again + await client.send_json_auto_id({"type": "test_item/collection/subscribe"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] is None + event_id_2 = response["id"] + + response = await client.receive_json() + assert response["id"] == event_id_2 + assert response["event"] == [ + { + "change_type": "added", + "item": { + "id": "initial_name", + "immutable_string": "no-changes", + "name": "Initial Name", + }, + "test_item_id": "initial_name", + }, + ] + + await client.send_json_auto_id( + {"type": "unsubscribe_events", "subscription": event_id_2} + ) + response = await client.receive_json() + assert response["success"] + + # List + await client.send_json_auto_id({"type": "test_item/collection/list"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + { + "id": "initial_name", + "name": "Initial Name", + "immutable_string": "no-changes", + } + ] + assert len(changes) == 1 + + # Update invalid data + await client.send_json_auto_id( + { + "type": "test_item/collection/update", + "test_item_id": "initial_name", + "immutable_string": "no-changes", + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "invalid_format" + assert len(changes) == 1 + + # Update invalid item + await client.send_json_auto_id( + { + "type": "test_item/collection/update", + "test_item_id": "non-existing", + "name": "Updated name", + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "not_found" + assert len(changes) == 1 + + # Update + await client.send_json_auto_id( + { + "type": "test_item/collection/update", + "test_item_id": "initial_name", + "name": "Updated name", + } + ) + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [ + { + "change_type": "updated", + "item": { + "id": "initial_name", + "immutable_string": "no-changes", + "name": "Updated name", + }, + "test_item_id": "initial_name", + } + ] + response = await client.receive_json() + assert response["success"] + assert response["result"] == { + "id": "initial_name", + "name": "Updated name", + "immutable_string": "no-changes", + } + assert len(changes) == 2 + assert changes[1] == (collection.CHANGE_UPDATED, "initial_name", response["result"]) + + # Delete invalid ID + await client.send_json_auto_id( + {"type": "test_item/collection/update", "test_item_id": "non-existing"} + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "not_found" + assert len(changes) == 2 + + # Delete + await client.send_json_auto_id( + {"type": "test_item/collection/delete", "test_item_id": "initial_name"} + ) + response = await client.receive_json() + assert response["id"] == event_id + assert response["event"] == [ + { + "change_type": "removed", + "item": { + "id": "initial_name", + "immutable_string": "no-changes", + "name": "Updated name", + }, + "test_item_id": "initial_name", + } + ] + response = await client.receive_json() + assert response["success"] + + assert len(changes) == 3 + assert changes[2] == ( + collection.CHANGE_REMOVED, + "initial_name", + { + "id": "initial_name", + "immutable_string": "no-changes", + "name": "Updated name", + }, + )