From 81458dbf6f5bbeac9299182067a7762cd02d2717 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 5 Jan 2024 08:51:49 -1000 Subject: [PATCH] Add test coverage for ESPHome state subscription (#107045) --- homeassistant/components/esphome/manager.py | 54 +++++++++--------- tests/components/esphome/conftest.py | 23 ++++++++ tests/components/esphome/test_manager.py | 63 +++++++++++++++++++++ 3 files changed, 114 insertions(+), 26 deletions(-) diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index f0263bdc48b..b897ffc9408 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -285,42 +285,44 @@ class ESPHomeManager: await self.cli.send_home_assistant_state(entity_id, attribute, str(send_state)) + async def _send_home_assistant_state_event( + self, + attribute: str | None, + event: EventType[EventStateChangedData], + ) -> None: + """Forward Home Assistant states updates to ESPHome.""" + event_data = event.data + new_state = event_data["new_state"] + old_state = event_data["old_state"] + + if new_state is None or old_state is None: + return + + # Only communicate changes to the state or attribute tracked + if (not attribute and old_state.state == new_state.state) or ( + attribute + and old_state.attributes.get(attribute) + == new_state.attributes.get(attribute) + ): + return + + await self._send_home_assistant_state( + event.data["entity_id"], attribute, new_state + ) + @callback def async_on_state_subscription( self, entity_id: str, attribute: str | None = None ) -> None: """Subscribe and forward states for requested entities.""" hass = self.hass - - async def send_home_assistant_state_event( - event: EventType[EventStateChangedData], - ) -> None: - """Forward Home Assistant states updates to ESPHome.""" - event_data = event.data - new_state = event_data["new_state"] - old_state = event_data["old_state"] - - if new_state is None or old_state is None: - return - - # Only communicate changes to the state or attribute tracked - if (not attribute and old_state.state == new_state.state) or ( - attribute - and old_state.attributes.get(attribute) - == new_state.attributes.get(attribute) - ): - return - - await self._send_home_assistant_state( - event.data["entity_id"], attribute, new_state - ) - self.entry_data.disconnect_callbacks.add( async_track_state_change_event( - hass, [entity_id], send_home_assistant_state_event + hass, + [entity_id], + partial(self._send_home_assistant_state_event, attribute), ) ) - # Send initial state hass.async_create_task( self._send_home_assistant_state( diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index 3acc5112720..0ac940018d7 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -180,6 +180,9 @@ class MockESPHomeDevice: self.service_call_callback: Callable[[HomeassistantServiceCall], None] self.on_disconnect: Callable[[bool], None] self.on_connect: Callable[[bool], None] + self.home_assistant_state_subscription_callback: Callable[ + [str, str | None], None + ] def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None: """Set the state callback.""" @@ -215,6 +218,19 @@ class MockESPHomeDevice: """Mock connecting.""" await self.on_connect() + def set_home_assistant_state_subscription_callback( + self, + on_state_sub: Callable[[str, str | None], None], + ) -> None: + """Set the state call callback.""" + self.home_assistant_state_subscription_callback = on_state_sub + + def mock_home_assistant_state_subscription( + self, entity_id: str, attribute: str | None + ) -> None: + """Mock a state subscription.""" + self.home_assistant_state_subscription_callback(entity_id, attribute) + async def _mock_generic_device_entry( hass: HomeAssistant, @@ -260,6 +276,12 @@ async def _mock_generic_device_entry( """Subscribe to service calls.""" mock_device.set_service_call_callback(callback) + async def _subscribe_home_assistant_states( + on_state_sub: Callable[[str, str | None], None], + ) -> None: + """Subscribe to home assistant states.""" + mock_device.set_home_assistant_state_subscription_callback(on_state_sub) + mock_client.device_info = AsyncMock(return_value=device_info) mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) mock_client.list_entities_services = AsyncMock( @@ -267,6 +289,7 @@ async def _mock_generic_device_entry( ) mock_client.subscribe_states = _subscribe_states mock_client.subscribe_service_calls = _subscribe_service_calls + mock_client.subscribe_home_assistant_states = _subscribe_home_assistant_states try_connect_done = Event() diff --git a/tests/components/esphome/test_manager.py b/tests/components/esphome/test_manager.py index a1ba05d4a94..96a8a341308 100644 --- a/tests/components/esphome/test_manager.py +++ b/tests/components/esphome/test_manager.py @@ -533,6 +533,69 @@ async def test_connection_aborted_wrong_device( assert "Unexpected device found at" not in caplog.text +async def test_state_subscription( + mock_client: APIClient, + hass: HomeAssistant, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test ESPHome subscribes to state changes.""" + device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + ) + await hass.async_block_till_done() + hass.states.async_set("binary_sensor.test", "on", {"bool": True, "float": 3.0}) + device.mock_home_assistant_state_subscription("binary_sensor.test", None) + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [ + call("binary_sensor.test", None, "on") + ] + mock_client.send_home_assistant_state.reset_mock() + hass.states.async_set("binary_sensor.test", "off", {"bool": True, "float": 3.0}) + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [ + call("binary_sensor.test", None, "off") + ] + mock_client.send_home_assistant_state.reset_mock() + device.mock_home_assistant_state_subscription("binary_sensor.test", "bool") + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [ + call("binary_sensor.test", "bool", "on") + ] + mock_client.send_home_assistant_state.reset_mock() + hass.states.async_set("binary_sensor.test", "off", {"bool": False, "float": 3.0}) + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [ + call("binary_sensor.test", "bool", "off") + ] + mock_client.send_home_assistant_state.reset_mock() + device.mock_home_assistant_state_subscription("binary_sensor.test", "float") + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [ + call("binary_sensor.test", "float", "3.0") + ] + mock_client.send_home_assistant_state.reset_mock() + hass.states.async_set("binary_sensor.test", "on", {"bool": True, "float": 4.0}) + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [ + call("binary_sensor.test", None, "on"), + call("binary_sensor.test", "bool", "on"), + call("binary_sensor.test", "float", "4.0"), + ] + mock_client.send_home_assistant_state.reset_mock() + hass.states.async_set("binary_sensor.test", "on", {}) + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [] + hass.states.async_remove("binary_sensor.test") + await hass.async_block_till_done() + assert mock_client.send_home_assistant_state.mock_calls == [] + + async def test_debug_logging( mock_client: APIClient, hass: HomeAssistant,