diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index ec1ab267a37..341aebdf9e4 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import lru_cache import logging -from typing import Any, Final +from typing import TYPE_CHECKING, Any, Final, cast import voluptuous as vol @@ -121,14 +121,16 @@ def _state_diff_event(event: Event) -> dict: """ if (event_new_state := event.data["new_state"]) is None: return {ENTITY_EVENT_REMOVE: [event.data["entity_id"]]} - assert isinstance(event_new_state, State) + if TYPE_CHECKING: + event_new_state = cast(State, event_new_state) if (event_old_state := event.data["old_state"]) is None: return { ENTITY_EVENT_ADD: { event_new_state.entity_id: event_new_state.as_compressed_state() } } - assert isinstance(event_old_state, State) + if TYPE_CHECKING: + event_old_state = cast(State, event_old_state) return _state_diff(event_old_state, event_new_state) @@ -136,27 +138,28 @@ def _state_diff( old_state: State, new_state: State ) -> dict[str, dict[str, dict[str, dict[str, str | list[str]]]]]: """Create a diff dict that can be used to overlay changes.""" - diff: dict = {STATE_DIFF_ADDITIONS: {}} - additions = diff[STATE_DIFF_ADDITIONS] + additions: dict[str, Any] = {} + diff: dict[str, dict[str, Any]] = {STATE_DIFF_ADDITIONS: additions} + new_state_context = new_state.context + old_state_context = old_state.context if old_state.state != new_state.state: additions[COMPRESSED_STATE_STATE] = new_state.state if old_state.last_changed != new_state.last_changed: additions[COMPRESSED_STATE_LAST_CHANGED] = new_state.last_changed.timestamp() elif old_state.last_updated != new_state.last_updated: additions[COMPRESSED_STATE_LAST_UPDATED] = new_state.last_updated.timestamp() - if old_state.context.parent_id != new_state.context.parent_id: - additions.setdefault(COMPRESSED_STATE_CONTEXT, {})[ - "parent_id" - ] = new_state.context.parent_id - if old_state.context.user_id != new_state.context.user_id: - additions.setdefault(COMPRESSED_STATE_CONTEXT, {})[ - "user_id" - ] = new_state.context.user_id - if old_state.context.id != new_state.context.id: + if old_state_context.parent_id != new_state_context.parent_id: + additions[COMPRESSED_STATE_CONTEXT] = {"parent_id": new_state_context.parent_id} + if old_state_context.user_id != new_state_context.user_id: if COMPRESSED_STATE_CONTEXT in additions: - additions[COMPRESSED_STATE_CONTEXT]["id"] = new_state.context.id + additions[COMPRESSED_STATE_CONTEXT]["user_id"] = new_state_context.user_id else: - additions[COMPRESSED_STATE_CONTEXT] = new_state.context.id + additions[COMPRESSED_STATE_CONTEXT] = {"user_id": new_state_context.user_id} + if old_state_context.id != new_state_context.id: + if COMPRESSED_STATE_CONTEXT in additions: + additions[COMPRESSED_STATE_CONTEXT]["id"] = new_state_context.id + else: + additions[COMPRESSED_STATE_CONTEXT] = new_state_context.id if (old_attributes := old_state.attributes) != ( new_attributes := new_state.attributes ): diff --git a/tests/components/websocket_api/test_messages.py b/tests/components/websocket_api/test_messages.py index ec7523a02e5..d2102b651b7 100644 --- a/tests/components/websocket_api/test_messages.py +++ b/tests/components/websocket_api/test_messages.py @@ -3,11 +3,14 @@ import pytest from homeassistant.components.websocket_api.messages import ( _cached_event_message as lru_event_cache, + _state_diff_event, cached_event_message, message_to_json, ) from homeassistant.const import EVENT_STATE_CHANGED -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Context, Event, HomeAssistant, State, callback + +from tests.common import async_capture_events async def test_cached_event_message(hass: HomeAssistant) -> None: @@ -79,6 +82,147 @@ async def test_cached_event_message_with_different_idens(hass: HomeAssistant) -> assert cache_info.currsize == 1 +async def test_state_diff_event(hass: HomeAssistant) -> None: + """Test building state_diff_message.""" + state_change_events = async_capture_events(hass, EVENT_STATE_CHANGED) + context = Context(user_id="user-id", parent_id="parent-id", id="id") + hass.states.async_set("light.window", "on", context=context) + hass.states.async_set("light.window", "off", context=context) + await hass.async_block_till_done() + + last_state_event: Event = state_change_events[-1] + new_state: State = last_state_event.data["new_state"] + message = _state_diff_event(last_state_event) + assert message == { + "c": { + "light.window": { + "+": {"lc": new_state.last_changed.timestamp(), "s": "off"} + } + } + } + + hass.states.async_set( + "light.window", + "red", + context=Context(user_id="user-id", parent_id="new-parent-id", id="id"), + ) + await hass.async_block_till_done() + last_state_event: Event = state_change_events[-1] + new_state: State = last_state_event.data["new_state"] + message = _state_diff_event(last_state_event) + + assert message == { + "c": { + "light.window": { + "+": { + "c": {"parent_id": "new-parent-id"}, + "lc": new_state.last_changed.timestamp(), + "s": "red", + } + } + } + } + + hass.states.async_set( + "light.window", + "green", + context=Context( + user_id="new-user-id", parent_id="another-new-parent-id", id="id" + ), + ) + await hass.async_block_till_done() + last_state_event: Event = state_change_events[-1] + new_state: State = last_state_event.data["new_state"] + message = _state_diff_event(last_state_event) + + assert message == { + "c": { + "light.window": { + "+": { + "c": { + "parent_id": "another-new-parent-id", + "user_id": "new-user-id", + }, + "lc": new_state.last_changed.timestamp(), + "s": "green", + } + } + } + } + + hass.states.async_set( + "light.window", + "blue", + context=Context( + user_id="another-new-user-id", parent_id="another-new-parent-id", id="id" + ), + ) + await hass.async_block_till_done() + last_state_event: Event = state_change_events[-1] + new_state: State = last_state_event.data["new_state"] + message = _state_diff_event(last_state_event) + + assert message == { + "c": { + "light.window": { + "+": { + "c": {"user_id": "another-new-user-id"}, + "lc": new_state.last_changed.timestamp(), + "s": "blue", + } + } + } + } + + hass.states.async_set( + "light.window", + "yellow", + context=Context( + user_id="another-new-user-id", + parent_id="another-new-parent-id", + id="id-new", + ), + ) + await hass.async_block_till_done() + last_state_event: Event = state_change_events[-1] + new_state: State = last_state_event.data["new_state"] + message = _state_diff_event(last_state_event) + + assert message == { + "c": { + "light.window": { + "+": { + "c": "id-new", + "lc": new_state.last_changed.timestamp(), + "s": "yellow", + } + } + } + } + + new_context = Context() + hass.states.async_set( + "light.window", "purple", {"new": "attr"}, context=new_context + ) + await hass.async_block_till_done() + last_state_event: Event = state_change_events[-1] + new_state: State = last_state_event.data["new_state"] + message = _state_diff_event(last_state_event) + + assert message == { + "c": { + "light.window": { + "+": { + "a": {"new": "attr"}, + "c": {"id": new_context.id, "parent_id": None, "user_id": None}, + "lc": new_state.last_changed.timestamp(), + "s": "purple", + } + } + } + } + + async def test_message_to_json(caplog: pytest.LogCaptureFixture) -> None: """Test we can serialize websocket messages."""