diff --git a/homeassistant/components/tasmota/__init__.py b/homeassistant/components/tasmota/__init__.py index 2d664bb46ee..44dd2489177 100644 --- a/homeassistant/components/tasmota/__init__.py +++ b/homeassistant/components/tasmota/__init__.py @@ -26,6 +26,7 @@ from homeassistant.components.mqtt.subscription import ( from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.config_entries import ConfigEntry from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.helpers import device_registry as dr from homeassistant.helpers.device_registry import ( CONNECTION_NETWORK_MAC, EVENT_DEVICE_REGISTRY_UPDATED, @@ -72,7 +73,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: tasmota_mqtt = TasmotaMQTTClient(_publish, _subscribe_topics, _unsubscribe_topics) - device_registry = await hass.helpers.device_registry.async_get_registry() + device_registry = dr.async_get(hass) async def async_discover_device(config: TasmotaDeviceConfig, mac: str) -> None: """Discover and add a Tasmota device.""" @@ -80,25 +81,40 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass, mac, config, entry, tasmota_mqtt, device_registry ) - async def async_device_removed(event: Event) -> None: + async def async_device_updated(event: Event) -> None: """Handle the removal of a device.""" - device_registry = await hass.helpers.device_registry.async_get_registry() - if event.data["action"] != "remove": + device_registry = dr.async_get(hass) + device_id = event.data["device_id"] + if event.data["action"] not in ("remove", "update"): return - device = device_registry.deleted_devices[event.data["device_id"]] + connections: set[tuple[str, str]] + if event.data["action"] == "update": + if "config_entries" not in event.data["changes"]: + return - if entry.entry_id not in device.config_entries: - return + device = device_registry.async_get(device_id) + if not device: + # The device is already removed, do cleanup when we get "remove" event + return + if entry.entry_id in device.config_entries: + # Not removed from device + return + connections = device.connections + else: + deleted_device = device_registry.deleted_devices[event.data["device_id"]] + connections = deleted_device.connections + if entry.entry_id not in deleted_device.config_entries: + return - macs = [c[1] for c in device.connections if c[0] == CONNECTION_NETWORK_MAC] + macs = [c[1] for c in connections if c[0] == CONNECTION_NETWORK_MAC] for mac in macs: await clear_discovery_topic( mac, entry.data[CONF_DISCOVERY_PREFIX], tasmota_mqtt ) hass.data[DATA_UNSUB].append( - hass.bus.async_listen(EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed) + hass.bus.async_listen(EVENT_DEVICE_REGISTRY_UPDATED, async_device_updated) ) async def start_platforms() -> None: @@ -138,7 +154,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data.pop(DATA_REMOVE_DISCOVER_COMPONENT.format(platform))() # deattach device triggers - device_registry = await hass.helpers.device_registry.async_get_registry() + device_registry = dr.async_get(hass) devices = async_entries_for_config_entry(device_registry, entry.entry_id) for device in devices: await device_automation.async_remove_automations(hass, device.id) @@ -156,11 +172,13 @@ async def _remove_device( """Remove device from device registry.""" device = device_registry.async_get_device(set(), {(CONNECTION_NETWORK_MAC, mac)}) - if device is None: + if device is None or config_entry.entry_id not in device.config_entries: return - _LOGGER.debug("Removing tasmota device %s", mac) - device_registry.async_remove_device(device.id) + _LOGGER.debug("Removing tasmota from device %s", mac) + device_registry.async_update_device( + device.id, remove_config_entry_id=config_entry.entry_id + ) await clear_discovery_topic( mac, config_entry.data[CONF_DISCOVERY_PREFIX], tasmota_mqtt ) @@ -203,13 +221,13 @@ async def async_setup_device( @websocket_api.websocket_command( {vol.Required("type"): "tasmota/device/remove", vol.Required("device_id"): str} ) -@websocket_api.async_response -async def websocket_remove_device( +@callback +def websocket_remove_device( hass: HomeAssistant, connection: ActiveConnection, msg: dict ) -> None: """Delete device.""" device_id = msg["device_id"] - dev_registry = await hass.helpers.device_registry.async_get_registry() + dev_registry = dr.async_get(hass) if not (device := dev_registry.async_get(device_id)): connection.send_error( @@ -217,8 +235,9 @@ async def websocket_remove_device( ) return - for config_entry in device.config_entries: - config_entry = hass.config_entries.async_get_entry(config_entry) + for config_entry_id in device.config_entries: + config_entry = hass.config_entries.async_get_entry(config_entry_id) + assert config_entry # Only delete the device if it belongs to a Tasmota device entry if config_entry.domain == DOMAIN: dev_registry.async_remove_device(device_id) @@ -228,3 +247,11 @@ async def websocket_remove_device( connection.send_error( msg["id"], websocket_api.const.ERR_NOT_FOUND, "Non Tasmota device" ) + + +async def async_remove_config_entry_device( + hass: HomeAssistant, config_entry: ConfigEntry, device_entry: dr.DeviceEntry +) -> bool: + """Remove Tasmota config entry from a device.""" + # Just return True, cleanup is done on when handling device registry events + return True diff --git a/homeassistant/components/tasmota/device_trigger.py b/homeassistant/components/tasmota/device_trigger.py index 61efbb76e23..aca5a2848e3 100644 --- a/homeassistant/components/tasmota/device_trigger.py +++ b/homeassistant/components/tasmota/device_trigger.py @@ -20,7 +20,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.typing import ConfigType @@ -220,7 +220,7 @@ async def async_setup_trigger( hass, TASMOTA_DISCOVERY_ENTITY_UPDATED.format(*discovery_hash), discovery_update ) - device_registry = await hass.helpers.device_registry.async_get_registry() + device_registry = dr.async_get(hass) device = device_registry.async_get_device( set(), {(CONNECTION_NETWORK_MAC, tasmota_trigger.cfg.mac)}, diff --git a/homeassistant/components/tasmota/discovery.py b/homeassistant/components/tasmota/discovery.py index 67aea199fe4..da9e809bd8b 100644 --- a/homeassistant/components/tasmota/discovery.py +++ b/homeassistant/components/tasmota/discovery.py @@ -21,7 +21,7 @@ from hatasmota.sensor import TasmotaBaseSensorConfig from homeassistant.components import sensor from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from homeassistant.helpers import device_registry as dev_reg +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity_registry import async_entries_for_device @@ -61,7 +61,7 @@ async def async_start( ) -> None: """Start Tasmota device discovery.""" - async def _discover_entity( + def _discover_entity( tasmota_entity_config: TasmotaEntityConfig | None, discovery_hash: DiscoveryHashType, platform: str, @@ -69,7 +69,7 @@ async def async_start( """Handle adding or updating a discovered entity.""" if not tasmota_entity_config: # Entity disabled, clean up entity registry - entity_registry = await hass.helpers.entity_registry.async_get_registry() + entity_registry = er.async_get(hass) unique_id = unique_id_from_hash(discovery_hash) entity_id = entity_registry.async_get_entity_id(platform, DOMAIN, unique_id) if entity_id: @@ -158,7 +158,7 @@ async def async_start( for platform in PLATFORMS: tasmota_entities = tasmota_get_entities_for_platform(payload, platform) for (tasmota_entity_config, discovery_hash) in tasmota_entities: - await _discover_entity(tasmota_entity_config, discovery_hash, platform) + _discover_entity(tasmota_entity_config, discovery_hash, platform) async def async_sensors_discovered( sensors: list[tuple[TasmotaBaseSensorConfig, DiscoveryHashType]], mac: str @@ -166,10 +166,10 @@ async def async_start( """Handle discovery of (additional) sensors.""" platform = sensor.DOMAIN - device_registry = await hass.helpers.device_registry.async_get_registry() - entity_registry = await hass.helpers.entity_registry.async_get_registry() + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) device = device_registry.async_get_device( - set(), {(dev_reg.CONNECTION_NETWORK_MAC, mac)} + set(), {(dr.CONNECTION_NETWORK_MAC, mac)} ) if device is None: @@ -186,7 +186,7 @@ async def async_start( for (tasmota_sensor_config, discovery_hash) in sensors: if tasmota_sensor_config: orphaned_entities.discard(tasmota_sensor_config.unique_id) - await _discover_entity(tasmota_sensor_config, discovery_hash, platform) + _discover_entity(tasmota_sensor_config, discovery_hash, platform) for unique_id in orphaned_entities: entity_id = entity_registry.async_get_entity_id(platform, DOMAIN, unique_id) if entity_id: diff --git a/tests/components/tasmota/test_discovery.py b/tests/components/tasmota/test_discovery.py index 713d0f5ae67..90ca5d918fd 100644 --- a/tests/components/tasmota/test_discovery.py +++ b/tests/components/tasmota/test_discovery.py @@ -10,7 +10,7 @@ from homeassistant.helpers import device_registry as dr from .conftest import setup_tasmota_helper from .test_common import DEFAULT_CONFIG, DEFAULT_CONFIG_9_0_0_3 -from tests.common import async_fire_mqtt_message +from tests.common import MockConfigEntry, async_fire_mqtt_message async def test_subscribing_config_topic(hass, mqtt_mock, setup_tasmota): @@ -261,6 +261,111 @@ async def test_device_remove( assert device_entry is None +async def test_device_remove_multiple_config_entries_1( + hass, mqtt_mock, caplog, device_reg, entity_reg, setup_tasmota +): + """Test removing a discovered device.""" + config = copy.deepcopy(DEFAULT_CONFIG) + mac = config["mac"] + + mock_entry = MockConfigEntry(domain="test") + mock_entry.add_to_hass(hass) + + device_reg.async_get_or_create( + config_entry_id=mock_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, mac)}, + ) + + tasmota_entry = hass.config_entries.async_entries("tasmota")[0] + + async_fire_mqtt_message( + hass, + f"{DEFAULT_PREFIX}/{mac}/config", + json.dumps(config), + ) + await hass.async_block_till_done() + + # Verify device entry is created + device_entry = device_reg.async_get_device( + set(), {(dr.CONNECTION_NETWORK_MAC, mac)} + ) + assert device_entry is not None + assert device_entry.config_entries == {tasmota_entry.entry_id, mock_entry.entry_id} + + async_fire_mqtt_message( + hass, + f"{DEFAULT_PREFIX}/{mac}/config", + "", + ) + await hass.async_block_till_done() + + # Verify device entry is not removed + device_entry = device_reg.async_get_device( + set(), {(dr.CONNECTION_NETWORK_MAC, mac)} + ) + assert device_entry is not None + assert device_entry.config_entries == {mock_entry.entry_id} + + +async def test_device_remove_multiple_config_entries_2( + hass, mqtt_mock, caplog, device_reg, entity_reg, setup_tasmota +): + """Test removing a discovered device.""" + config = copy.deepcopy(DEFAULT_CONFIG) + mac = config["mac"] + + mock_entry = MockConfigEntry(domain="test") + mock_entry.add_to_hass(hass) + + device_reg.async_get_or_create( + config_entry_id=mock_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, mac)}, + ) + + other_device_entry = device_reg.async_get_or_create( + config_entry_id=mock_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "other_device")}, + ) + + tasmota_entry = hass.config_entries.async_entries("tasmota")[0] + + async_fire_mqtt_message( + hass, + f"{DEFAULT_PREFIX}/{mac}/config", + json.dumps(config), + ) + await hass.async_block_till_done() + + # Verify device entry is created + device_entry = device_reg.async_get_device( + set(), {(dr.CONNECTION_NETWORK_MAC, mac)} + ) + assert device_entry is not None + assert device_entry.config_entries == {tasmota_entry.entry_id, mock_entry.entry_id} + assert other_device_entry.id != device_entry.id + + # Remove other config entry from the device + device_reg.async_update_device( + device_entry.id, remove_config_entry_id=mock_entry.entry_id + ) + await hass.async_block_till_done() + + # Verify device entry is not removed + device_entry = device_reg.async_get_device( + set(), {(dr.CONNECTION_NETWORK_MAC, mac)} + ) + assert device_entry is not None + assert device_entry.config_entries == {tasmota_entry.entry_id} + mqtt_mock.async_publish.assert_not_called() + + # Remove other config entry from the other device - Tasmota should not do any cleanup + device_reg.async_update_device( + other_device_entry.id, remove_config_entry_id=mock_entry.entry_id + ) + await hass.async_block_till_done() + mqtt_mock.async_publish.assert_not_called() + + async def test_device_remove_stale(hass, mqtt_mock, caplog, device_reg, setup_tasmota): """Test removing a stale (undiscovered) device does not throw.""" mac = "00000049A3BC"