From f1b5dcdd1bd9d1118b886d60daa66b97057034c7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 9 Mar 2024 20:30:17 -1000 Subject: [PATCH] Refactor handling of device updates in ESPHome (#112864) --- homeassistant/components/esphome/entity.py | 13 +--- .../components/esphome/entry_data.py | 46 +++++++------ homeassistant/components/esphome/manager.py | 4 +- homeassistant/components/esphome/update.py | 10 +-- tests/components/esphome/test_update.py | 65 +++++++------------ 5 files changed, 58 insertions(+), 80 deletions(-) diff --git a/homeassistant/components/esphome/entity.py b/homeassistant/components/esphome/entity.py index aa98ccef70c..b4cc54b0bb7 100644 --- a/homeassistant/components/esphome/entity.py +++ b/homeassistant/components/esphome/entity.py @@ -22,7 +22,6 @@ from homeassistant.helpers import entity_platform import homeassistant.helpers.config_validation as cv import homeassistant.helpers.device_registry as dr from homeassistant.helpers.device_registry import DeviceInfo -from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -205,25 +204,19 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): async def async_added_to_hass(self) -> None: """Register callbacks.""" entry_data = self._entry_data - hass = self.hass - key = self._key - static_info = self._static_info - self.async_on_remove( - async_dispatcher_connect( - hass, - entry_data.signal_device_updated, + entry_data.async_subscribe_device_updated( self._on_device_update, ) ) self.async_on_remove( entry_data.async_subscribe_state_update( - self._state_type, key, self._on_state_update + self._state_type, self._key, self._on_state_update ) ) self.async_on_remove( entry_data.async_register_key_static_info_updated_callback( - static_info, self._on_static_info_update + self._static_info, self._on_static_info_update ) ) self._update_state_from_entry_data() diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 2a34089cbe9..66ac1ac6c05 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -108,18 +108,17 @@ class RuntimeEntryData: device_info: DeviceInfo | None = None bluetooth_device: ESPHomeBluetoothDevice | None = None api_version: APIVersion = field(default_factory=APIVersion) - cleanup_callbacks: list[Callable[[], None]] = field(default_factory=list) - disconnect_callbacks: set[Callable[[], None]] = field(default_factory=set) - state_subscriptions: dict[ - tuple[type[EntityState], int], Callable[[], None] - ] = field(default_factory=dict) + cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list) + disconnect_callbacks: set[CALLBACK_TYPE] = field(default_factory=set) + state_subscriptions: dict[tuple[type[EntityState], int], CALLBACK_TYPE] = field( + default_factory=dict + ) + device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set) loaded_platforms: set[Platform] = field(default_factory=set) platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _storage_contents: StoreData | None = None _pending_storage: Callable[[], StoreData] | None = None - assist_pipeline_update_callbacks: list[Callable[[], None]] = field( - default_factory=list - ) + assist_pipeline_update_callbacks: list[CALLBACK_TYPE] = field(default_factory=list) assist_pipeline_state: bool = False entity_info_callbacks: dict[ type[EntityInfo], list[Callable[[list[EntityInfo]], None]] @@ -143,11 +142,6 @@ class RuntimeEntryData: "_", " " ) - @property - def signal_device_updated(self) -> str: - """Return the signal to listen to for core device state update.""" - return f"esphome_{self.entry_id}_on_device_update" - @property def signal_static_info_updated(self) -> str: """Return the signal to listen to for updates on static info.""" @@ -216,15 +210,15 @@ class RuntimeEntryData: @callback def async_subscribe_assist_pipeline_update( - self, update_callback: Callable[[], None] - ) -> Callable[[], None]: + self, update_callback: CALLBACK_TYPE + ) -> CALLBACK_TYPE: """Subscribe to assist pipeline updates.""" self.assist_pipeline_update_callbacks.append(update_callback) return partial(self._async_unsubscribe_assist_pipeline_update, update_callback) @callback def _async_unsubscribe_assist_pipeline_update( - self, update_callback: Callable[[], None] + self, update_callback: CALLBACK_TYPE ) -> None: """Unsubscribe to assist pipeline updates.""" self.assist_pipeline_update_callbacks.remove(update_callback) @@ -307,13 +301,24 @@ class RuntimeEntryData: # Then send dispatcher event async_dispatcher_send(hass, self.signal_static_info_updated, infos) + @callback + def async_subscribe_device_updated(self, callback_: CALLBACK_TYPE) -> CALLBACK_TYPE: + """Subscribe to state updates.""" + self.device_update_subscriptions.add(callback_) + return partial(self._async_unsubscribe_device_update, callback_) + + @callback + def _async_unsubscribe_device_update(self, callback_: CALLBACK_TYPE) -> None: + """Unsubscribe to device updates.""" + self.device_update_subscriptions.remove(callback_) + @callback def async_subscribe_state_update( self, state_type: type[EntityState], state_key: int, - entity_callback: Callable[[], None], - ) -> Callable[[], None]: + entity_callback: CALLBACK_TYPE, + ) -> CALLBACK_TYPE: """Subscribe to state updates.""" subscription_key = (state_type, state_key) self.state_subscriptions[subscription_key] = entity_callback @@ -359,9 +364,10 @@ class RuntimeEntryData: _LOGGER.exception("Error while calling subscription: %s", ex) @callback - def async_update_device_state(self, hass: HomeAssistant) -> None: + def async_update_device_state(self) -> None: """Distribute an update of a core device state like availability.""" - async_dispatcher_send(hass, self.signal_device_updated) + for callback_ in self.device_update_subscriptions.copy(): + callback_() async def async_load_from_store(self) -> tuple[list[EntityInfo], list[UserService]]: """Load the retained data from store and return de-serialized data.""" diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index 27fb2cece89..3848171f806 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -455,7 +455,7 @@ class ESPHomeManager: self.device_id = _async_setup_device_registry(hass, entry, entry_data) - entry_data.async_update_device_state(hass) + entry_data.async_update_device_state() await entry_data.async_update_static_infos( hass, entry, entity_infos, device_info.mac_address ) @@ -510,7 +510,7 @@ class ESPHomeManager: # since it generates a lot of state changed events and database # writes when we already know we're shutting down and the state # will be cleared anyway. - entry_data.async_update_device_state(hass) + entry_data.async_update_device_state() async def on_connect_error(self, err: Exception) -> None: """Start reauth flow if appropriate connect error type.""" diff --git a/homeassistant/components/esphome/update.py b/homeassistant/components/esphome/update.py index 2219900cced..5a565f9914d 100644 --- a/homeassistant/components/esphome/update.py +++ b/homeassistant/components/esphome/update.py @@ -61,9 +61,7 @@ async def async_setup_entry( return unsubs = [ - async_dispatcher_connect( - hass, entry_data.signal_device_updated, _async_setup_update_entity - ), + entry_data.async_subscribe_device_updated(_async_setup_update_entity), dashboard.async_add_listener(_async_setup_update_entity), ] @@ -159,11 +157,7 @@ class ESPHomeUpdateEntity(CoordinatorEntity[ESPHomeDashboard], UpdateEntity): ) ) self.async_on_remove( - async_dispatcher_connect( - hass, - entry_data.signal_device_updated, - self._handle_device_update, - ) + entry_data.async_subscribe_device_updated(self._handle_device_update) ) async def async_install( diff --git a/tests/components/esphome/test_update.py b/tests/components/esphome/test_update.py index 9d5745e6594..d17f2f4623a 100644 --- a/tests/components/esphome/test_update.py +++ b/tests/components/esphome/test_update.py @@ -208,15 +208,25 @@ async def test_update_static_info( @pytest.mark.parametrize( - "expected_disconnect_state", [(True, STATE_ON), (False, STATE_UNAVAILABLE)] + ("expected_disconnect", "expected_state", "has_deep_sleep"), + [ + (True, STATE_ON, False), + (False, STATE_UNAVAILABLE, False), + (True, STATE_ON, True), + (False, STATE_ON, True), + ], ) async def test_update_device_state_for_availability( hass: HomeAssistant, - stub_reconnect, - expected_disconnect_state: tuple[bool, str], - mock_config_entry, - mock_device_info, + expected_disconnect: bool, + expected_state: str, + has_deep_sleep: bool, mock_dashboard, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], ) -> None: """Test ESPHome update entity changes availability with the device.""" mock_dashboard["configured"] = [ @@ -226,46 +236,21 @@ async def test_update_device_state_for_availability( }, ] await async_get_dashboard(hass).async_refresh() - - signal_device_updated = f"esphome_{mock_config_entry.entry_id}_on_device_update" - runtime_data = Mock( - available=True, - expected_disconnect=False, - device_info=mock_device_info, - signal_device_updated=signal_device_updated, + mock_device = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={"has_deep_sleep": has_deep_sleep}, ) - with patch( - "homeassistant.components.esphome.update.DomainData.get_entry_data", - return_value=runtime_data, - ): - assert await hass.config_entries.async_forward_entry_setup( - mock_config_entry, "update" - ) - - state = hass.states.get("update.none_firmware") + state = hass.states.get("update.test_firmware") assert state is not None - assert state.state == "on" - - expected_disconnect, expected_state = expected_disconnect_state - - runtime_data.available = False - runtime_data.expected_disconnect = expected_disconnect - async_dispatcher_send(hass, signal_device_updated) - - state = hass.states.get("update.none_firmware") + assert state.state == STATE_ON + await mock_device.mock_disconnect(expected_disconnect) + state = hass.states.get("update.test_firmware") assert state.state == expected_state - # Deep sleep devices should still be available - runtime_data.device_info = dataclasses.replace( - runtime_data.device_info, has_deep_sleep=True - ) - - async_dispatcher_send(hass, signal_device_updated) - - state = hass.states.get("update.none_firmware") - assert state.state == "on" - async def test_update_entity_dashboard_not_available_startup( hass: HomeAssistant,