diff --git a/homeassistant/components/frontend/storage.py b/homeassistant/components/frontend/storage.py index c75fa360db7..11d155dbcb4 100644 --- a/homeassistant/components/frontend/storage.py +++ b/homeassistant/components/frontend/storage.py @@ -10,7 +10,7 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.components.websocket_api import ActiveConnection -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.storage import Store from homeassistant.util.hass_dict import HassKey @@ -22,6 +22,7 @@ async def async_setup_frontend_storage(hass: HomeAssistant) -> None: """Set up frontend storage.""" websocket_api.async_register_command(hass, websocket_set_user_data) websocket_api.async_register_command(hass, websocket_get_user_data) + websocket_api.async_register_command(hass, websocket_subscribe_user_data) async def async_user_store(hass: HomeAssistant, user_id: str) -> UserStore: @@ -41,6 +42,7 @@ class UserStore: """Initialize the user store.""" self._store = _UserStore(hass, user_id) self.data: dict[str, Any] = {} + self.subscriptions: dict[str | None, list[Callable[[], None]]] = {} async def async_load(self) -> None: """Load the data from the store.""" @@ -50,6 +52,23 @@ class UserStore: """Set an item item and save the store.""" self.data[key] = value await self._store.async_save(self.data) + for cb in self.subscriptions.get(None, []): + cb() + for cb in self.subscriptions.get(key, []): + cb() + + @callback + def async_subscribe( + self, key: str | None, on_update_callback: Callable[[], None] + ) -> Callable[[], None]: + """Save the data to the store.""" + self.subscriptions.setdefault(key, []).append(on_update_callback) + + def unsubscribe() -> None: + """Unsubscribe from the store.""" + self.subscriptions[key].remove(on_update_callback) + + return unsubscribe class _UserStore(Store[dict[str, Any]]): @@ -124,3 +143,29 @@ async def websocket_get_user_data( connection.send_result( msg["id"], {"value": data.get(msg["key"]) if "key" in msg else data} ) + + +@websocket_api.websocket_command( + {vol.Required("type"): "frontend/subscribe_user_data", vol.Optional("key"): str} +) +@websocket_api.async_response +@with_user_store +async def websocket_subscribe_user_data( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict[str, Any], + store: UserStore, +) -> None: + """Handle subscribe to user data command.""" + key: str | None = msg.get("key") + + def on_data_update() -> None: + """Handle user data update.""" + data = store.data + connection.send_event( + msg["id"], {"value": data.get(key) if key is not None else data} + ) + + connection.subscriptions[msg["id"]] = store.async_subscribe(key, on_data_update) + on_data_update() + connection.send_result(msg["id"]) diff --git a/tests/components/frontend/test_storage.py b/tests/components/frontend/test_storage.py index 360ca151551..f4a61b743c5 100644 --- a/tests/components/frontend/test_storage.py +++ b/tests/components/frontend/test_storage.py @@ -79,12 +79,46 @@ async def test_get_user_data( assert res["result"]["value"]["test-complex"][0]["foo"] == "bar" +@pytest.mark.parametrize( + ("subscriptions", "events"), + [ + ([], []), + ([(1, {}, {})], [(1, {"test-key": "test-value"})]), + ([(1, {"key": "test-key"}, None)], [(1, "test-value")]), + ([(1, {"key": "other-key"}, None)], []), + ], +) async def test_set_user_data_empty( - hass: HomeAssistant, hass_ws_client: WebSocketGenerator + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + subscriptions: list[tuple[int, dict[str, str], Any]], + events: list[tuple[int, Any]], ) -> None: - """Test set_user_data command.""" + """Test set_user_data command. + + Also test subscribing. + """ client = await hass_ws_client(hass) + for msg_id, key, event_data in subscriptions: + await client.send_json( + { + "id": msg_id, + "type": "frontend/subscribe_user_data", + } + | key + ) + + event = await client.receive_json() + assert event == { + "id": msg_id, + "type": "event", + "event": {"value": event_data}, + } + + res = await client.receive_json() + assert res["success"], res + # test creating await client.send_json( @@ -104,6 +138,10 @@ async def test_set_user_data_empty( } ) + for msg_id, event_data in events: + event = await client.receive_json() + assert event == {"id": msg_id, "type": "event", "event": {"value": event_data}} + res = await client.receive_json() assert res["success"], res @@ -116,11 +154,63 @@ async def test_set_user_data_empty( assert res["result"]["value"] == "test-value" +@pytest.mark.parametrize( + ("subscriptions", "events"), + [ + ( + [], + [[], []], + ), + ( + [(1, {}, {"test-key": "test-value", "test-complex": "string"})], + [ + [ + ( + 1, + { + "test-complex": "string", + "test-key": "test-value", + "test-non-existent-key": "test-value-new", + }, + ) + ], + [ + ( + 1, + { + "test-complex": [{"foo": "bar"}], + "test-key": "test-value", + "test-non-existent-key": "test-value-new", + }, + ) + ], + ], + ), + ( + [(1, {"key": "test-key"}, "test-value")], + [[], []], + ), + ( + [(1, {"key": "test-non-existent-key"}, None)], + [[(1, "test-value-new")], []], + ), + ( + [(1, {"key": "test-complex"}, "string")], + [[], [(1, [{"foo": "bar"}])]], + ), + ( + [(1, {"key": "other-key"}, None)], + [[], []], + ), + ], +) async def test_set_user_data( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, hass_storage: dict[str, Any], hass_admin_user: MockUser, + subscriptions: list[tuple[int, dict[str, str], Any]], + events: list[list[tuple[int, Any]]], ) -> None: """Test set_user_data command with initial data.""" storage_key = f"{DOMAIN}.user_data_{hass_admin_user.id}" @@ -131,6 +221,25 @@ async def test_set_user_data( client = await hass_ws_client(hass) + for msg_id, key, event_data in subscriptions: + await client.send_json( + { + "id": msg_id, + "type": "frontend/subscribe_user_data", + } + | key + ) + + event = await client.receive_json() + assert event == { + "id": msg_id, + "type": "event", + "event": {"value": event_data}, + } + + res = await client.receive_json() + assert res["success"], res + # test creating await client.send_json( @@ -142,6 +251,10 @@ async def test_set_user_data( } ) + for msg_id, event_data in events[0]: + event = await client.receive_json() + assert event == {"id": msg_id, "type": "event", "event": {"value": event_data}} + res = await client.receive_json() assert res["success"], res @@ -164,6 +277,10 @@ async def test_set_user_data( } ) + for msg_id, event_data in events[1]: + event = await client.receive_json() + assert event == {"id": msg_id, "type": "event", "event": {"value": event_data}} + res = await client.receive_json() assert res["success"], res