diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 6044e55978d..e64ba46beb7 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -16,7 +16,7 @@ from homeassistant.const import ( MATCH_ALL, SIGNAL_BOOTSTRAP_INTEGRATONS, ) -from homeassistant.core import Context, Event, HomeAssistant, callback +from homeassistant.core import Context, Event, HomeAssistant, State, callback from homeassistant.exceptions import ( HomeAssistantError, ServiceNotFound, @@ -68,6 +68,7 @@ def async_register_commands( async_reg(hass, handle_test_condition) async_reg(hass, handle_unsubscribe_events) async_reg(hass, handle_validate_config) + async_reg(hass, handle_subscribe_entities) def pong_message(iden: int) -> dict[str, Any]: @@ -213,21 +214,27 @@ async def handle_call_service( connection.send_error(msg["id"], const.ERR_UNKNOWN_ERROR, str(err)) +@callback +def _async_get_allowed_states( + hass: HomeAssistant, connection: ActiveConnection +) -> list[State]: + if connection.user.permissions.access_all_entities("read"): + return hass.states.async_all() + entity_perm = connection.user.permissions.check_entity + return [ + state + for state in hass.states.async_all() + if entity_perm(state.entity_id, "read") + ] + + @callback @decorators.websocket_command({vol.Required("type"): "get_states"}) def handle_get_states( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Handle get states command.""" - if connection.user.permissions.access_all_entities("read"): - states = hass.states.async_all() - else: - entity_perm = connection.user.permissions.check_entity - states = [ - state - for state in hass.states.async_all() - if entity_perm(state.entity_id, "read") - ] + states = _async_get_allowed_states(hass, connection) # JSON serialize here so we can recover if it blows up due to the # state machine containing unserializable data. This command is required @@ -260,6 +267,84 @@ def handle_get_states( connection.send_message(response2) +@callback +@decorators.websocket_command( + { + vol.Required("type"): "subscribe_entities", + vol.Optional("entity_ids"): cv.entity_ids, + } +) +def handle_subscribe_entities( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle subscribe entities command.""" + # Circular dep + # pylint: disable=import-outside-toplevel + from .permissions import SUBSCRIBE_ALLOWLIST + + if "state_changed" not in SUBSCRIBE_ALLOWLIST and not connection.user.is_admin: + raise Unauthorized + + entity_ids = set(msg.get("entity_ids", [])) + + @callback + def forward_entity_changes(event: Event) -> None: + """Forward entity state changed events to websocket.""" + if not connection.user.permissions.check_entity( + event.data["entity_id"], POLICY_READ + ): + return + if entity_ids and event.data["entity_id"] not in entity_ids: + return + + connection.send_message(messages.cached_state_diff_message(msg["id"], event)) + + # We must never await between sending the states and listening for + # state changed events or we will introduce a race condition + # where some states are missed + states = _async_get_allowed_states(hass, connection) + connection.subscriptions[msg["id"]] = hass.bus.async_listen( + "state_changed", forward_entity_changes + ) + connection.send_result(msg["id"]) + data: dict[str, dict[str, dict]] = { + messages.ENTITY_EVENT_ADD: { + state.entity_id: messages.compressed_state_dict_add(state) + for state in states + if not entity_ids or state.entity_id in entity_ids + } + } + + # JSON serialize here so we can recover if it blows up due to the + # state machine containing unserializable data. This command is required + # to succeed for the UI to show. + response = messages.event_message(msg["id"], data) + try: + connection.send_message(const.JSON_DUMP(response)) + return + except (ValueError, TypeError): + connection.logger.error( + "Unable to serialize to JSON. Bad data found at %s", + format_unserializable_data( + find_paths_unserializable_data(response, dump=const.JSON_DUMP) + ), + ) + del response + + add_entities = data[messages.ENTITY_EVENT_ADD] + cannot_serialize: list[str] = [] + for entity_id, state_dict in add_entities.items(): + try: + const.JSON_DUMP(state_dict) + except (ValueError, TypeError): + cannot_serialize.append(entity_id) + + for entity_id in cannot_serialize: + del add_entities[entity_id] + + connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data))) + + @decorators.websocket_command({vol.Required("type"): "get_services"}) @decorators.async_response async def handle_get_services( diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index 8cdda3f8fa3..0695b279361 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -7,7 +7,7 @@ from typing import Any, Final import voluptuous as vol -from homeassistant.core import Event +from homeassistant.core import Event, State from homeassistant.helpers import config_validation as cv from homeassistant.util.json import ( find_paths_unserializable_data, @@ -31,6 +31,19 @@ BASE_COMMAND_MESSAGE_SCHEMA: Final = vol.Schema({vol.Required("id"): cv.positive IDEN_TEMPLATE: Final = "__IDEN__" IDEN_JSON_TEMPLATE: Final = '"__IDEN__"' +COMPRESSED_STATE_STATE = "s" +COMPRESSED_STATE_ATTRIBUTES = "a" +COMPRESSED_STATE_CONTEXT = "c" +COMPRESSED_STATE_LAST_CHANGED = "lc" +COMPRESSED_STATE_LAST_UPDATED = "lu" + +STATE_DIFF_ADDITIONS = "+" +STATE_DIFF_REMOVALS = "-" + +ENTITY_EVENT_ADD = "a" +ENTITY_EVENT_REMOVE = "r" +ENTITY_EVENT_CHANGE = "c" + def result_message(iden: int, result: Any = None) -> dict[str, Any]: """Return a success result message.""" @@ -74,6 +87,110 @@ def _cached_event_message(event: Event) -> str: return message_to_json(event_message(IDEN_TEMPLATE, event)) +def cached_state_diff_message(iden: int, event: Event) -> str: + """Return an event message. + + Serialize to json once per message. + + Since we can have many clients connected that are + all getting many of the same events (mostly state changed) + we can avoid serializing the same data for each connection. + """ + return _cached_state_diff_message(event).replace(IDEN_JSON_TEMPLATE, str(iden), 1) + + +@lru_cache(maxsize=128) +def _cached_state_diff_message(event: Event) -> str: + """Cache and serialize the event to json. + + The IDEN_TEMPLATE is used which will be replaced + with the actual iden in cached_event_message + """ + return message_to_json(event_message(IDEN_TEMPLATE, _state_diff_event(event))) + + +def _state_diff_event(event: Event) -> dict: + """Convert a state_changed event to the minimal version. + + State update example + + { + "a": {entity_id: compressed_state,…} + "c": {entity_id: diff,…} + "r": [entity_id,…] + } + """ + if (event_new_state := event.data["new_state"]) is None: + return {ENTITY_EVENT_REMOVE: [event.data["entity_id"]]} + assert isinstance(event_new_state, State) + if (event_old_state := event.data["old_state"]) is None: + return { + ENTITY_EVENT_ADD: { + event_new_state.entity_id: compressed_state_dict_add(event_new_state) + } + } + assert isinstance(event_old_state, State) + return _state_diff(event_old_state, event_new_state) + + +def _state_diff( + old_state: State, new_state: State +) -> dict[str, dict[str, dict[str, dict[str, str | list[str]]]]]: + """Create a diff dict that can be used to overlay changes.""" + diff: dict = {STATE_DIFF_ADDITIONS: {}} + additions = diff[STATE_DIFF_ADDITIONS] + if old_state.state != new_state.state: + additions[COMPRESSED_STATE_STATE] = new_state.state + if old_state.last_changed != new_state.last_changed: + additions[COMPRESSED_STATE_LAST_CHANGED] = new_state.last_changed.timestamp() + elif old_state.last_updated != new_state.last_updated: + additions[COMPRESSED_STATE_LAST_UPDATED] = new_state.last_updated.timestamp() + if old_state.context.parent_id != new_state.context.parent_id: + additions.setdefault(COMPRESSED_STATE_CONTEXT, {})[ + "parent_id" + ] = new_state.context.parent_id + if old_state.context.user_id != new_state.context.user_id: + additions.setdefault(COMPRESSED_STATE_CONTEXT, {})[ + "user_id" + ] = new_state.context.user_id + if old_state.context.id != new_state.context.id: + if COMPRESSED_STATE_CONTEXT in additions: + additions[COMPRESSED_STATE_CONTEXT]["id"] = new_state.context.id + else: + additions[COMPRESSED_STATE_CONTEXT] = new_state.context.id + old_attributes = old_state.attributes + for key, value in new_state.attributes.items(): + if old_attributes.get(key) != value: + additions.setdefault(COMPRESSED_STATE_ATTRIBUTES, {})[key] = value + if removed := set(old_attributes).difference(new_state.attributes): + diff[STATE_DIFF_REMOVALS] = {COMPRESSED_STATE_ATTRIBUTES: removed} + return {ENTITY_EVENT_CHANGE: {new_state.entity_id: diff}} + + +def compressed_state_dict_add(state: State) -> dict[str, Any]: + """Build a compressed dict of a state for adds. + + Omits the lu (last_updated) if it matches (lc) last_changed. + + Sends c (context) as a string if it only contains an id. + """ + if state.context.parent_id is None and state.context.user_id is None: + context: dict[str, Any] | str = state.context.id # type: ignore[unreachable] + else: + context = state.context.as_dict() + compressed_state: dict[str, Any] = { + COMPRESSED_STATE_STATE: state.state, + COMPRESSED_STATE_ATTRIBUTES: state.attributes, + COMPRESSED_STATE_CONTEXT: context, + } + if state.last_changed == state.last_updated: + compressed_state[COMPRESSED_STATE_LAST_CHANGED] = state.last_changed.timestamp() + else: + compressed_state[COMPRESSED_STATE_LAST_CHANGED] = state.last_changed.timestamp() + compressed_state[COMPRESSED_STATE_LAST_UPDATED] = state.last_updated.timestamp() + return compressed_state + + def message_to_json(message: dict[str, Any]) -> str: """Serialize a websocket message to json.""" try: diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 742d9bddd38..007e130ff6f 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -1,4 +1,5 @@ """Tests for WebSocket API commands.""" +from copy import deepcopy import datetime from unittest.mock import ANY, patch @@ -14,7 +15,7 @@ from homeassistant.components.websocket_api.auth import ( ) from homeassistant.components.websocket_api.const import URL from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS -from homeassistant.core import Context, HomeAssistant, callback +from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity from homeassistant.helpers.dispatcher import async_dispatcher_send @@ -23,6 +24,38 @@ from homeassistant.setup import DATA_SETUP_TIME, async_setup_component from tests.common import MockEntity, MockEntityPlatform, async_mock_service +STATE_KEY_SHORT_NAMES = { + "entity_id": "e", + "state": "s", + "last_changed": "lc", + "last_updated": "lu", + "context": "c", + "attributes": "a", +} +STATE_KEY_LONG_NAMES = {v: k for k, v in STATE_KEY_SHORT_NAMES.items()} + + +def _apply_entities_changes(state_dict: dict, change_dict: dict) -> None: + """Apply a diff set to a dict. + + Port of the client side merging + """ + additions = change_dict.get("+", {}) + if "lc" in additions: + additions["lu"] = additions["lc"] + if attributes := additions.pop("a", None): + state_dict["attributes"].update(attributes) + if context := additions.pop("c", None): + if isinstance(context, str): + state_dict["context"]["id"] = context + else: + state_dict["context"].update(context) + for k, v in additions.items(): + state_dict[STATE_KEY_LONG_NAMES[k]] = v + for key, items in change_dict.get("-", {}).items(): + for item in items: + del state_dict[STATE_KEY_LONG_NAMES[key]][item] + async def test_fire_event(hass, websocket_client): """Test fire event command.""" @@ -666,6 +699,349 @@ async def test_subscribe_unsubscribe_events_state_changed( assert msg["event"]["data"]["entity_id"] == "light.permitted" +async def test_subscribe_entities_with_unserializable_state( + hass, websocket_client, hass_admin_user +): + """Test subscribe entities with an unserializeable state.""" + + class CannotSerializeMe: + """Cannot serialize this.""" + + def __init__(self): + """Init cannot serialize this.""" + + hass.states.async_set("light.permitted", "off", {"color": "red"}) + hass.states.async_set( + "light.cannot_serialize", + "off", + {"color": "red", "cannot_serialize": CannotSerializeMe()}, + ) + original_state = hass.states.get("light.cannot_serialize") + assert isinstance(original_state, State) + state_dict = { + "attributes": dict(original_state.attributes), + "context": dict(original_state.context.as_dict()), + "entity_id": original_state.entity_id, + "last_changed": original_state.last_changed.isoformat(), + "last_updated": original_state.last_updated.isoformat(), + "state": original_state.state, + } + hass_admin_user.groups = [] + hass_admin_user.mock_policy( + { + "entities": { + "entity_ids": {"light.permitted": True, "light.cannot_serialize": True} + } + } + ) + + await websocket_client.send_json({"id": 7, "type": "subscribe_entities"}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "a": { + "light.permitted": { + "a": {"color": "red"}, + "c": ANY, + "lc": ANY, + "s": "off", + } + } + } + hass.states.async_set("light.permitted", "on", {"effect": "help"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": { + "light.permitted": { + "+": { + "a": {"effect": "help"}, + "c": ANY, + "lc": ANY, + "s": "on", + }, + "-": {"a": ["color"]}, + } + } + } + hass.states.async_set("light.cannot_serialize", "on", {"effect": "help"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + # Order does not matter + msg["event"]["c"]["light.cannot_serialize"]["-"]["a"] = set( + msg["event"]["c"]["light.cannot_serialize"]["-"]["a"] + ) + assert msg["event"] == { + "c": { + "light.cannot_serialize": { + "+": {"a": {"effect": "help"}, "c": ANY, "lc": ANY, "s": "on"}, + "-": {"a": {"color", "cannot_serialize"}}, + } + } + } + change_set = msg["event"]["c"]["light.cannot_serialize"] + _apply_entities_changes(state_dict, change_set) + assert state_dict == { + "attributes": {"effect": "help"}, + "context": { + "id": ANY, + "parent_id": None, + "user_id": None, + }, + "entity_id": "light.cannot_serialize", + "last_changed": ANY, + "last_updated": ANY, + "state": "on", + } + hass.states.async_set( + "light.cannot_serialize", + "off", + {"color": "red", "cannot_serialize": CannotSerializeMe()}, + ) + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "result" + assert msg["error"] == { + "code": "unknown_error", + "message": "Invalid JSON in response", + } + + +async def test_subscribe_unsubscribe_entities(hass, websocket_client, hass_admin_user): + """Test subscribe/unsubscribe entities.""" + + hass.states.async_set("light.permitted", "off", {"color": "red"}) + original_state = hass.states.get("light.permitted") + assert isinstance(original_state, State) + state_dict = { + "attributes": dict(original_state.attributes), + "context": dict(original_state.context.as_dict()), + "entity_id": original_state.entity_id, + "last_changed": original_state.last_changed.isoformat(), + "last_updated": original_state.last_updated.isoformat(), + "state": original_state.state, + } + hass_admin_user.groups = [] + hass_admin_user.mock_policy({"entities": {"entity_ids": {"light.permitted": True}}}) + + await websocket_client.send_json({"id": 7, "type": "subscribe_entities"}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert isinstance(msg["event"]["a"]["light.permitted"]["c"], str) + assert msg["event"] == { + "a": { + "light.permitted": { + "a": {"color": "red"}, + "c": ANY, + "lc": ANY, + "s": "off", + } + } + } + hass.states.async_set("light.not_permitted", "on") + hass.states.async_set("light.permitted", "on", {"color": "blue"}) + hass.states.async_set("light.permitted", "on", {"effect": "help"}) + hass.states.async_set( + "light.permitted", "on", {"effect": "help", "color": ["blue", "green"]} + ) + hass.states.async_remove("light.permitted") + hass.states.async_set("light.permitted", "on", {"effect": "help", "color": "blue"}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": { + "light.permitted": { + "+": { + "a": {"color": "blue"}, + "c": ANY, + "lc": ANY, + "s": "on", + } + } + } + } + + change_set = msg["event"]["c"]["light.permitted"] + additions = deepcopy(change_set["+"]) + _apply_entities_changes(state_dict, change_set) + assert state_dict == { + "attributes": {"color": "blue"}, + "context": { + "id": additions["c"], + "parent_id": None, + "user_id": None, + }, + "entity_id": "light.permitted", + "last_changed": additions["lc"], + "last_updated": additions["lc"], + "state": "on", + } + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": { + "light.permitted": { + "+": { + "a": {"effect": "help"}, + "c": ANY, + "lu": ANY, + }, + "-": {"a": ["color"]}, + } + } + } + + change_set = msg["event"]["c"]["light.permitted"] + additions = deepcopy(change_set["+"]) + _apply_entities_changes(state_dict, change_set) + + assert state_dict == { + "attributes": {"effect": "help"}, + "context": { + "id": additions["c"], + "parent_id": None, + "user_id": None, + }, + "entity_id": "light.permitted", + "last_changed": ANY, + "last_updated": additions["lu"], + "state": "on", + } + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": { + "light.permitted": { + "+": { + "a": {"color": ["blue", "green"]}, + "c": ANY, + "lu": ANY, + } + } + } + } + + change_set = msg["event"]["c"]["light.permitted"] + additions = deepcopy(change_set["+"]) + _apply_entities_changes(state_dict, change_set) + + assert state_dict == { + "attributes": {"effect": "help", "color": ["blue", "green"]}, + "context": { + "id": additions["c"], + "parent_id": None, + "user_id": None, + }, + "entity_id": "light.permitted", + "last_changed": ANY, + "last_updated": additions["lu"], + "state": "on", + } + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == {"r": ["light.permitted"]} + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "a": { + "light.permitted": { + "a": {"color": "blue", "effect": "help"}, + "c": ANY, + "lc": ANY, + "s": "on", + } + } + } + + +async def test_subscribe_unsubscribe_entities_specific_entities( + hass, websocket_client, hass_admin_user +): + """Test subscribe/unsubscribe entities with a list of entity ids.""" + + hass.states.async_set("light.permitted", "off", {"color": "red"}) + hass.states.async_set("light.not_intrested", "off", {"color": "blue"}) + original_state = hass.states.get("light.permitted") + assert isinstance(original_state, State) + hass_admin_user.groups = [] + hass_admin_user.mock_policy( + { + "entities": { + "entity_ids": {"light.permitted": True, "light.not_intrested": True} + } + } + ) + + await websocket_client.send_json( + {"id": 7, "type": "subscribe_entities", "entity_ids": ["light.permitted"]} + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert isinstance(msg["event"]["a"]["light.permitted"]["c"], str) + assert msg["event"] == { + "a": { + "light.permitted": { + "a": {"color": "red"}, + "c": ANY, + "lc": ANY, + "s": "off", + } + } + } + hass.states.async_set("light.not_intrested", "on", {"effect": "help"}) + hass.states.async_set("light.not_permitted", "on") + hass.states.async_set("light.permitted", "on", {"color": "blue"}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": { + "light.permitted": { + "+": { + "a": {"color": "blue"}, + "c": ANY, + "lc": ANY, + "s": "on", + } + } + } + } + + async def test_render_template_renders_template(hass, websocket_client): """Test simple template is rendered and updated.""" hass.states.async_set("light.test", "on")