From 33cdcce19151afbd26b7c7bfc747958e1b1b6427 Mon Sep 17 00:00:00 2001 From: Luke Lashley Date: Mon, 12 Feb 2024 03:37:37 -0500 Subject: [PATCH] Better teardown and setup of Roborock connections (#106092) Co-authored-by: Robert Resch --- homeassistant/components/roborock/__init__.py | 16 +++++----- .../components/roborock/coordinator.py | 3 +- homeassistant/components/roborock/device.py | 29 +++++++++++++++++-- homeassistant/components/roborock/select.py | 4 +-- homeassistant/components/roborock/sensor.py | 4 +-- homeassistant/components/roborock/vacuum.py | 16 +++++----- tests/components/roborock/test_init.py | 2 +- 7 files changed, 47 insertions(+), 27 deletions(-) diff --git a/homeassistant/components/roborock/__init__.py b/homeassistant/components/roborock/__init__.py index f8ceb121fe4..f4293213c00 100644 --- a/homeassistant/components/roborock/__init__.py +++ b/homeassistant/components/roborock/__init__.py @@ -115,6 +115,7 @@ async def setup_device( device.name, ) _LOGGER.debug(err) + await mqtt_client.async_release() raise err coordinator = RoborockDataUpdateCoordinator( hass, device, networking, product_info, mqtt_client @@ -130,6 +131,7 @@ async def setup_device( try: await coordinator.async_config_entry_first_refresh() except ConfigEntryNotReady as ex: + await coordinator.release() if isinstance(coordinator.api, RoborockMqttClient): _LOGGER.warning( "Not setting up %s because the we failed to get data for the first time using the online client. " @@ -158,14 +160,10 @@ async def setup_device( async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Handle removal of an entry.""" - unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - if unload_ok: - await asyncio.gather( - *( - coordinator.release() - for coordinator in hass.data[DOMAIN][entry.entry_id].values() - ) - ) + if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): + release_tasks = set() + for coordinator in hass.data[DOMAIN][entry.entry_id].values(): + release_tasks.add(coordinator.release()) hass.data[DOMAIN].pop(entry.entry_id) - + await asyncio.gather(*release_tasks) return unload_ok diff --git a/homeassistant/components/roborock/coordinator.py b/homeassistant/components/roborock/coordinator.py index 3864a90b16d..7154a36f7b8 100644 --- a/homeassistant/components/roborock/coordinator.py +++ b/homeassistant/components/roborock/coordinator.py @@ -79,7 +79,8 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]): async def release(self) -> None: """Disconnect from API.""" - await self.api.async_disconnect() + await self.api.async_release() + await self.cloud_api.async_release() async def _update_device_prop(self) -> None: """Update device properties.""" diff --git a/homeassistant/components/roborock/device.py b/homeassistant/components/roborock/device.py index 17531f6c627..2921a372e00 100644 --- a/homeassistant/components/roborock/device.py +++ b/homeassistant/components/roborock/device.py @@ -1,5 +1,4 @@ """Support for Roborock device base class.""" - from typing import Any from roborock.api import AttributeCache, RoborockClient @@ -7,6 +6,7 @@ from roborock.cloud_api import RoborockMqttClient from roborock.command_cache import CacheableAttribute from roborock.containers import Consumable, Status from roborock.exceptions import RoborockException +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand from homeassistant.exceptions import HomeAssistantError @@ -24,7 +24,10 @@ class RoborockEntity(Entity): _attr_has_entity_name = True def __init__( - self, unique_id: str, device_info: DeviceInfo, api: RoborockClient + self, + unique_id: str, + device_info: DeviceInfo, + api: RoborockClient, ) -> None: """Initialize the coordinated Roborock Device.""" self._attr_unique_id = unique_id @@ -75,6 +78,9 @@ class RoborockCoordinatedEntity( self, unique_id: str, coordinator: RoborockDataUpdateCoordinator, + listener_request: list[RoborockDataProtocol] + | RoborockDataProtocol + | None = None, ) -> None: """Initialize the coordinated Roborock Device.""" RoborockEntity.__init__( @@ -85,6 +91,23 @@ class RoborockCoordinatedEntity( ) CoordinatorEntity.__init__(self, coordinator=coordinator) self._attr_unique_id = unique_id + if isinstance(listener_request, RoborockDataProtocol): + listener_request = [listener_request] + self.listener_requests = listener_request or [] + + async def async_added_to_hass(self) -> None: + """Add listeners when the device is added to hass.""" + await super().async_added_to_hass() + for listener_request in self.listener_requests: + self.api.add_listener( + listener_request, self._update_from_listener, cache=self.api.cache + ) + + async def async_will_remove_from_hass(self) -> None: + """Remove listeners when the device is removed from hass.""" + for listener_request in self.listener_requests: + self.api.remove_listener(listener_request, self._update_from_listener) + await super().async_will_remove_from_hass() @property def _device_status(self) -> Status: @@ -107,7 +130,7 @@ class RoborockCoordinatedEntity( await self.coordinator.async_refresh() return res - def _update_from_listener(self, value: Status | Consumable): + def _update_from_listener(self, value: Status | Consumable) -> None: """Update the status or consumable data from a listener and then write the new entity state.""" if isinstance(value, Status): self.coordinator.roborock_device_info.props.status = value diff --git a/homeassistant/components/roborock/select.py b/homeassistant/components/roborock/select.py index ae5dd12689d..3fdd10c97d5 100644 --- a/homeassistant/components/roborock/select.py +++ b/homeassistant/components/roborock/select.py @@ -107,10 +107,8 @@ class RoborockSelectEntity(RoborockCoordinatedEntity, SelectEntity): ) -> None: """Create a select entity.""" self.entity_description = entity_description - super().__init__(unique_id, coordinator) + super().__init__(unique_id, coordinator, entity_description.protocol_listener) self._attr_options = options - if (protocol := self.entity_description.protocol_listener) is not None: - self.api.add_listener(protocol, self._update_from_listener, self.api.cache) async def async_select_option(self, option: str) -> None: """Set the option.""" diff --git a/homeassistant/components/roborock/sensor.py b/homeassistant/components/roborock/sensor.py index d5258879acb..8d723ec57cd 100644 --- a/homeassistant/components/roborock/sensor.py +++ b/homeassistant/components/roborock/sensor.py @@ -232,10 +232,8 @@ class RoborockSensorEntity(RoborockCoordinatedEntity, SensorEntity): description: RoborockSensorDescription, ) -> None: """Initialize the entity.""" - super().__init__(unique_id, coordinator) self.entity_description = description - if (protocol := self.entity_description.protocol_listener) is not None: - self.api.add_listener(protocol, self._update_from_listener, self.api.cache) + super().__init__(unique_id, coordinator, description.protocol_listener) @property def native_value(self) -> StateType | datetime.datetime: diff --git a/homeassistant/components/roborock/vacuum.py b/homeassistant/components/roborock/vacuum.py index 3b8f0e756b7..dafbb731bd2 100644 --- a/homeassistant/components/roborock/vacuum.py +++ b/homeassistant/components/roborock/vacuum.py @@ -92,14 +92,16 @@ class RoborockVacuum(RoborockCoordinatedEntity, StateVacuumEntity): ) -> None: """Initialize a vacuum.""" StateVacuumEntity.__init__(self) - RoborockCoordinatedEntity.__init__(self, unique_id, coordinator) + RoborockCoordinatedEntity.__init__( + self, + unique_id, + coordinator, + listener_request=[ + RoborockDataProtocol.FAN_POWER, + RoborockDataProtocol.STATE, + ], + ) self._attr_fan_speed_list = self._device_status.fan_power_options - self.api.add_listener( - RoborockDataProtocol.FAN_POWER, self._update_from_listener, self.api.cache - ) - self.api.add_listener( - RoborockDataProtocol.STATE, self._update_from_listener, self.api.cache - ) @property def state(self) -> str | None: diff --git a/tests/components/roborock/test_init.py b/tests/components/roborock/test_init.py index 608263a3496..e037d81786e 100644 --- a/tests/components/roborock/test_init.py +++ b/tests/components/roborock/test_init.py @@ -18,7 +18,7 @@ async def test_unload_entry( assert len(hass.config_entries.async_entries(DOMAIN)) == 1 assert setup_entry.state is ConfigEntryState.LOADED with patch( - "homeassistant.components.roborock.coordinator.RoborockLocalClient.async_disconnect" + "homeassistant.components.roborock.coordinator.RoborockLocalClient.async_release" ) as mock_disconnect: assert await hass.config_entries.async_unload(setup_entry.entry_id) await hass.async_block_till_done()