From 3b2aae5045f9f08dc8f174c5d975852588e1a132 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Fri, 15 Apr 2022 12:35:08 +0200 Subject: [PATCH] Refactor MQTT discovery (#67966) * Proof of concept * remove notify platform * remove loose test * Add rework from #67912 (#1) * Move notify serviceupdater to Mixins * Move tag discovery handler to Mixins * fix tests * Add typing for async_load_platform_helper * Add add entry unload support for notify platform * Simplify discovery updates * Remove not needed extra logic * Cleanup inrelevant or duplicate code * reuse update_device and move to mixins * Remove notify platform * revert changes to notify platform * Rename update class * unify tag entry setup * Use shared code for device_trigger `update_device` * PoC shared dispatcher for device_trigger * Fix bugs * Improve typing - remove async_update * Unload config_entry and tests * Release dispatcher after setup and deduplicate * closures to methods, revert `in` to `=`, updates * Re-add update support for tag platform * Re-add update support for device-trigger platform * Cleanup rediscovery code revert related changes * Undo discovery code shift * Update homeassistant/components/mqtt/mixins.py Co-authored-by: Erik Montnemery * Update homeassistant/components/mqtt/device_trigger.py Co-authored-by: Erik Montnemery * Update homeassistant/components/mqtt/mixins.py Co-authored-by: Erik Montnemery * revert doc string changes * move conditions * typing and check config_entry_id * Update homeassistant/components/mqtt/mixins.py Co-authored-by: Erik Montnemery * cleanup not used attribute * Remove entry_unload code and tests * update comment * add second comment Co-authored-by: Erik Montnemery --- .../components/mqtt/device_trigger.py | 218 ++++++++--------- homeassistant/components/mqtt/discovery.py | 17 +- homeassistant/components/mqtt/mixins.py | 224 +++++++++++++++--- homeassistant/components/mqtt/models.py | 2 +- homeassistant/components/mqtt/tag.py | 193 +++++---------- 5 files changed, 366 insertions(+), 288 deletions(-) diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index 71c0a9f9364..56cfc3efc6b 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Callable import logging -from typing import Any +from typing import Any, cast import attr import voluptuous as vol @@ -13,6 +13,7 @@ from homeassistant.components.automation import ( AutomationTriggerInfo, ) from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_DEVICE, CONF_DEVICE_ID, @@ -23,30 +24,19 @@ from homeassistant.const import ( ) from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import config_validation as cv, device_registry as dr -from homeassistant.helpers.dispatcher import ( - async_dispatcher_connect, - async_dispatcher_send, -) +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.typing import ConfigType from . import debug_info, trigger as mqtt_trigger from .. import mqtt -from .const import ( - ATTR_DISCOVERY_HASH, - ATTR_DISCOVERY_TOPIC, - CONF_PAYLOAD, - CONF_QOS, - CONF_TOPIC, - DOMAIN, -) -from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash +from .const import ATTR_DISCOVERY_HASH, CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, DOMAIN +from .discovery import MQTT_DISCOVERY_DONE from .mixins import ( - CONF_CONNECTIONS, - CONF_IDENTIFIERS, MQTT_ENTITY_DEVICE_INFO_SCHEMA, - cleanup_device_registry, - device_info_from_config, + MqttDiscoveryDeviceUpdate, + send_discovery_done, + update_device, ) _LOGGER = logging.getLogger(__name__) @@ -89,6 +79,8 @@ TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( DEVICE_TRIGGERS = "mqtt_device_triggers" +LOG_NAME = "Device trigger" + @attr.s(slots=True) class TriggerInstance: @@ -99,7 +91,7 @@ class TriggerInstance: trigger: Trigger = attr.ib() remove: CALLBACK_TYPE | None = attr.ib(default=None) - async def async_attach_trigger(self): + async def async_attach_trigger(self) -> None: """Attach MQTT trigger.""" mqtt_config = { mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN, @@ -132,14 +124,15 @@ class Trigger: hass: HomeAssistant = attr.ib() payload: str | None = attr.ib() qos: int | None = attr.ib() - remove_signal: Callable[[], None] | None = attr.ib() subtype: str = attr.ib() topic: str | None = attr.ib() type: str = attr.ib() value_template: str | None = attr.ib() trigger_instances: list[TriggerInstance] = attr.ib(factory=list) - async def add_trigger(self, action, automation_info): + async def add_trigger( + self, action: AutomationActionType, automation_info: AutomationTriggerInfo + ) -> Callable: """Add MQTT trigger.""" instance = TriggerInstance(action, automation_info, self) self.trigger_instances.append(instance) @@ -160,9 +153,8 @@ class Trigger: return async_remove - async def update_trigger(self, config, discovery_hash, remove_signal): + async def update_trigger(self, config: ConfigType) -> None: """Update MQTT device trigger.""" - self.remove_signal = remove_signal self.type = config[CONF_TYPE] self.subtype = config[CONF_SUBTYPE] self.payload = config[CONF_PAYLOAD] @@ -178,7 +170,7 @@ class Trigger: for trig in self.trigger_instances: await trig.async_attach_trigger() - def detach_trigger(self): + def detach_trigger(self) -> None: """Remove MQTT device trigger.""" # Mark trigger as unknown self.topic = None @@ -190,110 +182,110 @@ class Trigger: trig.remove = None -def _update_device(hass, config_entry, config): - """Update device registry.""" - device_registry = dr.async_get(hass) - config_entry_id = config_entry.entry_id - device_info = device_info_from_config(config[CONF_DEVICE]) +class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): + """Setup a MQTT device trigger with auto discovery.""" - if config_entry_id is not None and device_info is not None: - device_info["config_entry_id"] = config_entry_id - device_registry.async_get_or_create(**device_info) + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + device_id: str, + discovery_data: dict, + config_entry: ConfigEntry, + ) -> None: + """Initialize.""" + self._config = config + self._config_entry = config_entry + self.device_id = device_id + self.discovery_data = discovery_data + self.hass = hass + + MqttDiscoveryDeviceUpdate.__init__( + self, + hass, + discovery_data, + device_id, + config_entry, + LOG_NAME, + ) + + async def async_setup(self) -> None: + """Initialize the device trigger.""" + discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] + discovery_id = discovery_hash[1] + if discovery_id not in self.hass.data.setdefault(DEVICE_TRIGGERS, {}): + self.hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( + hass=self.hass, + device_id=self.device_id, + discovery_data=self.discovery_data, + type=self._config[CONF_TYPE], + subtype=self._config[CONF_SUBTYPE], + topic=self._config[CONF_TOPIC], + payload=self._config[CONF_PAYLOAD], + qos=self._config[CONF_QOS], + value_template=self._config[CONF_VALUE_TEMPLATE], + ) + else: + await self.hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( + self._config + ) + debug_info.add_trigger_discovery_data( + self.hass, discovery_hash, self.discovery_data, self.device_id + ) + + async def async_update(self, discovery_data: dict) -> None: + """Handle MQTT device trigger discovery updates.""" + discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] + discovery_id = discovery_hash[1] + debug_info.update_trigger_discovery_data( + self.hass, discovery_hash, discovery_data + ) + config = TRIGGER_DISCOVERY_SCHEMA(discovery_data) + update_device(self.hass, self._config_entry, config) + device_trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id] + await device_trigger.update_trigger(config) + + async def async_tear_down(self) -> None: + """Cleanup device trigger.""" + discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] + discovery_id = discovery_hash[1] + if discovery_id in self.hass.data[DEVICE_TRIGGERS]: + _LOGGER.info("Removing trigger: %s", discovery_hash) + trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id] + trigger.detach_trigger() + debug_info.remove_trigger_discovery_data(self.hass, discovery_hash) -async def async_setup_trigger(hass, config, config_entry, discovery_data): +async def async_setup_trigger( + hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict +) -> None: """Set up the MQTT device trigger.""" config = TRIGGER_DISCOVERY_SCHEMA(config) discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] - discovery_id = discovery_hash[1] - remove_signal = None - async def discovery_update(payload): - """Handle discovery update.""" - _LOGGER.info( - "Got update for trigger with hash: %s '%s'", discovery_hash, payload - ) - if not payload: - # Empty payload: Remove trigger - _LOGGER.info("Removing trigger: %s", discovery_hash) - debug_info.remove_trigger_discovery_data(hass, discovery_hash) - if discovery_id in hass.data[DEVICE_TRIGGERS]: - device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] - device_trigger.detach_trigger() - clear_discovery_hash(hass, discovery_hash) - remove_signal() - await cleanup_device_registry(hass, device.id, config_entry.entry_id) - else: - # Non-empty payload: Update trigger - _LOGGER.info("Updating trigger: %s", discovery_hash) - debug_info.update_trigger_discovery_data(hass, discovery_hash, payload) - config = TRIGGER_DISCOVERY_SCHEMA(payload) - _update_device(hass, config_entry, config) - device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] - await device_trigger.update_trigger(config, discovery_hash, remove_signal) - async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) - - remove_signal = async_dispatcher_connect( - hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update - ) - - _update_device(hass, config_entry, config) - - device_registry = dr.async_get(hass) - device = device_registry.async_get_device( - {(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]}, - {tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]}, - ) - - if device is None: + if (device_id := update_device(hass, config_entry, config)) is None: async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) return - if DEVICE_TRIGGERS not in hass.data: - hass.data[DEVICE_TRIGGERS] = {} - if discovery_id not in hass.data[DEVICE_TRIGGERS]: - hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( - hass=hass, - device_id=device.id, - discovery_data=discovery_data, - type=config[CONF_TYPE], - subtype=config[CONF_SUBTYPE], - topic=config[CONF_TOPIC], - payload=config[CONF_PAYLOAD], - qos=config[CONF_QOS], - remove_signal=remove_signal, - value_template=config[CONF_VALUE_TEMPLATE], - ) - else: - await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( - config, discovery_hash, remove_signal - ) - debug_info.add_trigger_discovery_data( - hass, discovery_hash, discovery_data, device.id + mqtt_device_trigger = MqttDeviceTrigger( + hass, config, device_id, discovery_data, config_entry ) - - async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) + await mqtt_device_trigger.async_setup() + send_discovery_done(hass, discovery_data) -async def async_removed_from_device(hass: HomeAssistant, device_id: str): +async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: """Handle Mqtt removed from a device.""" triggers = await async_get_triggers(hass, device_id) for trig in triggers: - device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID]) + device_trigger: Trigger = hass.data[DEVICE_TRIGGERS].pop( + trig[CONF_DISCOVERY_ID] + ) if device_trigger: - discovery_hash = device_trigger.discovery_data[ATTR_DISCOVERY_HASH] - discovery_topic = device_trigger.discovery_data[ATTR_DISCOVERY_TOPIC] - - debug_info.remove_trigger_discovery_data(hass, discovery_hash) device_trigger.detach_trigger() - clear_discovery_hash(hass, discovery_hash) - device_trigger.remove_signal() - mqtt.publish( - hass, - discovery_topic, - "", - retain=True, - ) + discovery_data = cast(dict, device_trigger.discovery_data) + discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] + debug_info.remove_trigger_discovery_data(hass, discovery_hash) async def async_get_triggers( @@ -328,8 +320,7 @@ async def async_attach_trigger( automation_info: AutomationTriggerInfo, ) -> CALLBACK_TYPE: """Attach a trigger.""" - if DEVICE_TRIGGERS not in hass.data: - hass.data[DEVICE_TRIGGERS] = {} + hass.data.setdefault(DEVICE_TRIGGERS, {}) device_id = config[CONF_DEVICE_ID] discovery_id = config[CONF_DISCOVERY_ID] @@ -338,7 +329,6 @@ async def async_attach_trigger( hass=hass, device_id=device_id, discovery_data=None, - remove_signal=None, type=config[CONF_TYPE], subtype=config[CONF_SUBTYPE], topic=None, diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 11bc0f6839a..fae443dc411 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -1,4 +1,6 @@ """Support for MQTT discovery.""" +from __future__ import annotations + import asyncio from collections import deque import functools @@ -73,20 +75,22 @@ LAST_DISCOVERY = "mqtt_last_discovery" TOPIC_BASE = "~" -def clear_discovery_hash(hass, discovery_hash): +class MQTTConfig(dict): + """Dummy class to allow adding attributes.""" + + discovery_data: dict + + +def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple) -> None: """Clear entry in ALREADY_DISCOVERED list.""" del hass.data[ALREADY_DISCOVERED][discovery_hash] -def set_discovery_hash(hass, discovery_hash): +def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple): """Clear entry in ALREADY_DISCOVERED list.""" hass.data[ALREADY_DISCOVERED][discovery_hash] = {} -class MQTTConfig(dict): - """Dummy class to allow adding attributes.""" - - async def async_start( # noqa: C901 hass: HomeAssistant, discovery_topic, config_entry=None ) -> None: @@ -181,6 +185,7 @@ async def async_start( # noqa: C901 await async_process_discovery_payload(component, discovery_id, payload) async def async_process_discovery_payload(component, discovery_id, payload): + """Process the payload of a new discovery.""" _LOGGER.debug("Process discovery payload %s", payload) discovery_hash = (component, discovery_id) diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index bf3431c5324..43f75f08459 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -5,7 +5,7 @@ from abc import abstractmethod from collections.abc import Callable import json import logging -from typing import Any, Protocol, final +from typing import Any, Protocol, cast, final import voluptuous as vol @@ -32,6 +32,7 @@ from homeassistant.helpers import ( device_registry as dr, entity_registry as er, ) +from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, @@ -45,7 +46,7 @@ from homeassistant.helpers.entity import ( ) from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.reload import async_setup_reload_service -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import ( DATA_MQTT, @@ -496,8 +497,10 @@ class MqttAvailability(Entity): return self._available_latest -async def cleanup_device_registry(hass, device_id, config_entry_id): - """Remove device registry entry if there are no remaining entities or triggers.""" +async def cleanup_device_registry( + hass: HomeAssistant, device_id: str | None, config_entry_id: str | None +) -> None: + """Remove MQTT from the device registry entry if there are no remaining entities, triggers or tags.""" # Local import to avoid circular dependencies # pylint: disable-next=import-outside-toplevel from . import device_trigger, tag @@ -506,6 +509,7 @@ async def cleanup_device_registry(hass, device_id, config_entry_id): entity_registry = er.async_get(hass) if ( device_id + and config_entry_id and not er.async_entries_for_device( entity_registry, device_id, include_disabled_entities=False ) @@ -517,14 +521,163 @@ async def cleanup_device_registry(hass, device_id, config_entry_id): ) +def get_discovery_hash(discovery_data: dict) -> tuple: + """Get the discovery hash from the discovery data.""" + return discovery_data[ATTR_DISCOVERY_HASH] + + +def send_discovery_done(hass: HomeAssistant, discovery_data: dict) -> None: + """Acknowledge a discovery message has been handled.""" + discovery_hash = get_discovery_hash(discovery_data) + async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) + + +def stop_discovery_updates( + hass: HomeAssistant, + discovery_data: dict, + remove_discovery_updated: Callable[[], None] | None = None, +) -> None: + """Stop discovery updates of being sent.""" + if remove_discovery_updated: + remove_discovery_updated() + remove_discovery_updated = None + discovery_hash = get_discovery_hash(discovery_data) + clear_discovery_hash(hass, discovery_hash) + + +async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: dict): + """Clear retained discovery topic in broker to avoid rediscovery after a restart of HA.""" + discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC] + await async_publish(hass, discovery_topic, "", retain=True) + + +class MqttDiscoveryDeviceUpdate: + """Add support for auto discovery for platforms without an entity.""" + + def __init__( + self, + hass: HomeAssistant, + discovery_data: dict, + device_id: str | None, + config_entry: ConfigEntry, + log_name: str, + ) -> None: + """Initialize the update service.""" + + self.hass = hass + self.log_name = log_name + + self._discovery_data = discovery_data + self._device_id = device_id + self._config_entry = config_entry + self._config_entry_id = config_entry.entry_id + self._skip_device_removal: bool = False + + discovery_hash = get_discovery_hash(discovery_data) + self._remove_discovery_updated = async_dispatcher_connect( + hass, + MQTT_DISCOVERY_UPDATED.format(discovery_hash), + self.async_discovery_update, + ) + if device_id is not None: + self._remove_device_updated = hass.bus.async_listen( + EVENT_DEVICE_REGISTRY_UPDATED, self._async_device_removed + ) + _LOGGER.info( + "%s %s has been initialized", + self.log_name, + discovery_hash, + ) + + async def async_discovery_update( + self, + discovery_payload: DiscoveryInfoType | None, + ) -> None: + """Handle discovery update.""" + discovery_hash = get_discovery_hash(self._discovery_data) + _LOGGER.info( + "Got update for %s with hash: %s '%s'", + self.log_name, + discovery_hash, + discovery_payload, + ) + if ( + discovery_payload + and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD] + ): + _LOGGER.info( + "%s %s updating", + self.log_name, + discovery_hash, + ) + await self.async_update(discovery_payload) + if not discovery_payload: + # Unregister and clean up the current discovery instance + stop_discovery_updates( + self.hass, self._discovery_data, self._remove_discovery_updated + ) + await self._async_tear_down() + send_discovery_done(self.hass, self._discovery_data) + _LOGGER.info( + "%s %s has been removed", + self.log_name, + discovery_hash, + ) + else: + # Normal update without change + send_discovery_done(self.hass, self._discovery_data) + _LOGGER.info( + "%s %s no changes", + self.log_name, + discovery_hash, + ) + return + + async def _async_device_removed(self, event: Event) -> None: + """Handle the manual removal of a device.""" + if self._skip_device_removal or not async_removed_from_device( + self.hass, event, cast(str, self._device_id), self._config_entry_id + ): + return + # Prevent a second cleanup round after the device is removed + self._remove_device_updated() + self._skip_device_removal = True + # Unregister and clean up and publish an empty payload + # so the service is not rediscovered after a restart + stop_discovery_updates( + self.hass, self._discovery_data, self._remove_discovery_updated + ) + await self._async_tear_down() + await async_remove_discovery_payload(self.hass, self._discovery_data) + + async def _async_tear_down(self) -> None: + """Handle the cleanup of the discovery service.""" + # Cleanup platform resources + await self.async_tear_down() + # remove the service for auto discovery updates and clean up the device registry + if not self._skip_device_removal: + # Prevent a second cleanup round after the device is removed + self._skip_device_removal = True + await cleanup_device_registry( + self.hass, self._device_id, self._config_entry_id + ) + + async def async_update(self, discovery_data: dict) -> None: + """Handle the update of platform specific parts, extend to the platform.""" + + @abstractmethod + async def async_tear_down(self) -> None: + """Handle the cleanup of platform specific parts, extend to the platform.""" + + class MqttDiscoveryUpdate(Entity): - """Mixin used to handle updated discovery message.""" + """Mixin used to handle updated discovery message for entity based platforms.""" def __init__(self, discovery_data, discovery_update=None) -> None: """Initialize the discovery update mixin.""" self._discovery_data = discovery_data self._discovery_update = discovery_update - self._remove_signal: Callable | None = None + self._remove_discovery_updated: Callable | None = None self._removed_from_hass = False async def async_added_to_hass(self) -> None: @@ -572,7 +725,7 @@ class MqttDiscoveryUpdate(Entity): else: # Non-empty, unchanged payload: Ignore to avoid changing states _LOGGER.info("Ignoring unchanged update for: %s", self.entity_id) - self.async_send_discovery_done() + send_discovery_done(self.hass, self._discovery_data) if discovery_hash: debug_info.add_entity_discovery_data( @@ -580,24 +733,12 @@ class MqttDiscoveryUpdate(Entity): ) # Set in case the entity has been removed and is re-added, for example when changing entity_id set_discovery_hash(self.hass, discovery_hash) - self._remove_signal = async_dispatcher_connect( + self._remove_discovery_updated = async_dispatcher_connect( self.hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_callback, ) - @callback - def async_send_discovery_done(self) -> None: - """Acknowledge a discovery message has been handled.""" - discovery_hash = ( - self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None - ) - if not discovery_hash: - return - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None - ) - async def async_removed_from_registry(self) -> None: """Clear retained discovery topic in broker.""" if not self._removed_from_hass: @@ -606,18 +747,14 @@ class MqttDiscoveryUpdate(Entity): self._cleanup_discovery_on_remove() # Clear the discovery topic so the entity is not rediscovered after a restart - discovery_topic = self._discovery_data[ATTR_DISCOVERY_TOPIC] - await async_publish(self.hass, discovery_topic, "", retain=True) + await async_remove_discovery_payload(self.hass, self._discovery_data) @callback def add_to_platform_abort(self) -> None: """Abort adding an entity to a platform.""" if self._discovery_data: - discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH] - clear_discovery_hash(self.hass, discovery_hash) - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None - ) + stop_discovery_updates(self.hass, self._discovery_data) + send_discovery_done(self.hass, self._discovery_data) super().add_to_platform_abort() async def async_will_remove_from_hass(self) -> None: @@ -627,13 +764,11 @@ class MqttDiscoveryUpdate(Entity): def _cleanup_discovery_on_remove(self) -> None: """Stop listening to signal and cleanup discovery data.""" if self._discovery_data and not self._removed_from_hass: - clear_discovery_hash(self.hass, self._discovery_data[ATTR_DISCOVERY_HASH]) + stop_discovery_updates( + self.hass, self._discovery_data, self._remove_discovery_updated + ) self._removed_from_hass = True - if self._remove_signal: - self._remove_signal() - self._remove_signal = None - def device_info_from_config(config) -> DeviceInfo | None: """Return a device description for device registry.""" @@ -737,7 +872,8 @@ class MqttEntity( self._prepare_subscribe_topics() await self._subscribe_topics() await self.mqtt_async_added_to_hass() - self.async_send_discovery_done() + if self._discovery_data is not None: + send_discovery_done(self.hass, self._discovery_data) async def mqtt_async_added_to_hass(self): """Call before the discovery message is acknowledged. @@ -839,6 +975,28 @@ class MqttEntity( return self._unique_id +def update_device( + hass: HomeAssistant, + config_entry: ConfigEntry, + config: ConfigType, +) -> str | None: + """Update device registry.""" + if CONF_DEVICE not in config: + return None + + device = None + device_registry = dr.async_get(hass) + config_entry_id = config_entry.entry_id + device_info = device_info_from_config(config[CONF_DEVICE]) + + if config_entry_id is not None and device_info is not None: + update_device_info = cast(dict, device_info) + update_device_info["config_entry_id"] = config_entry_id + device = device_registry.async_get_or_create(**update_device_info) + + return device.id if device else None + + @callback def async_removed_from_device( hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index f5a0270481e..9cec65d7254 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -1,4 +1,4 @@ -"""Modesl used by multiple MQTT modules.""" +"""Models used by multiple MQTT modules.""" from __future__ import annotations from collections.abc import Awaitable, Callable diff --git a/homeassistant/components/mqtt/tag.py b/homeassistant/components/mqtt/tag.py index a2541c064c0..5bfbbd73bce 100644 --- a/homeassistant/components/mqtt/tag.py +++ b/homeassistant/components/mqtt/tag.py @@ -1,40 +1,31 @@ """Provides tag scanning for MQTT.""" +from __future__ import annotations + import functools -import logging import voluptuous as vol +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_DEVICE, CONF_PLATFORM, CONF_VALUE_TEMPLATE -from homeassistant.helpers import device_registry as dr +from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED -from homeassistant.helpers.dispatcher import ( - async_dispatcher_connect, - async_dispatcher_send, -) +from homeassistant.helpers.typing import ConfigType from . import MqttValueTemplate, subscription from .. import mqtt -from .const import ( - ATTR_DISCOVERY_HASH, - ATTR_DISCOVERY_TOPIC, - CONF_QOS, - CONF_TOPIC, - DOMAIN, -) -from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash +from .const import ATTR_DISCOVERY_HASH, CONF_QOS, CONF_TOPIC from .mixins import ( - CONF_CONNECTIONS, - CONF_IDENTIFIERS, MQTT_ENTITY_DEVICE_INFO_SCHEMA, - async_removed_from_device, + MqttDiscoveryDeviceUpdate, async_setup_entry_helper, - cleanup_device_registry, - device_info_from_config, + send_discovery_done, + update_device, ) +from .models import ReceiveMessage +from .subscription import EntitySubscription from .util import valid_subscribe_topic -_LOGGER = logging.getLogger(__name__) +LOG_NAME = "Tag" TAG = "tag" TAGS = "mqtt_tags" @@ -50,35 +41,27 @@ PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( ) -async def async_setup_entry(hass, config_entry): - """Set up MQTT tag scan dynamically through MQTT discovery.""" - setup = functools.partial(async_setup_tag, hass, config_entry=config_entry) - await async_setup_entry_helper(hass, "tag", setup, PLATFORM_SCHEMA) +async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None: + """Set up MQTT device automation dynamically through MQTT discovery.""" + + setup = functools.partial(_async_setup_tag, hass, config_entry=config_entry) + await async_setup_entry_helper(hass, TAG, setup, PLATFORM_SCHEMA) -async def async_setup_tag(hass, config, config_entry, discovery_data): +async def _async_setup_tag( + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: dict, +) -> None: """Set up the MQTT tag scanner.""" discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] - device_id = None - if CONF_DEVICE in config: - _update_device(hass, config_entry, config) - - device_registry = dr.async_get(hass) - device = device_registry.async_get_device( - {(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]}, - {tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]}, - ) - - if device is None: - return - device_id = device.id - - if TAGS not in hass.data: - hass.data[TAGS] = {} - if device_id not in hass.data[TAGS]: - hass.data[TAGS][device_id] = {} + device_id = update_device(hass, config_entry, config) + hass.data.setdefault(TAGS, {}) + if device_id not in hass.data[TAGS]: + hass.data[TAGS][device_id] = {} tag_scanner = MQTTTagScanner( hass, @@ -88,91 +71,65 @@ async def async_setup_tag(hass, config, config_entry, discovery_data): config_entry, ) - await tag_scanner.setup() + await tag_scanner.subscribe_topics() if device_id: hass.data[TAGS][device_id][discovery_id] = tag_scanner + send_discovery_done(hass, discovery_data) -def async_has_tags(hass, device_id): + +def async_has_tags(hass: HomeAssistant, device_id: str) -> bool: """Device has tag scanners.""" if TAGS not in hass.data or device_id not in hass.data[TAGS]: return False return hass.data[TAGS][device_id] != {} -class MQTTTagScanner: +class MQTTTagScanner(MqttDiscoveryDeviceUpdate): """MQTT Tag scanner.""" - def __init__(self, hass, config, device_id, discovery_data, config_entry): + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + device_id: str | None, + discovery_data: dict, + config_entry: ConfigEntry, + ) -> None: """Initialize.""" self._config = config self._config_entry = config_entry self.device_id = device_id self.discovery_data = discovery_data self.hass = hass - self._remove_discovery = None - self._remove_device_updated = None - self._sub_state = None - self._value_template = None - - self._setup_from_config(config) - - async def discovery_update(self, payload): - """Handle discovery update.""" - discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] - _LOGGER.info( - "Got update for tag scanner with hash: %s '%s'", discovery_hash, payload - ) - if not payload: - # Empty payload: Remove tag scanner - _LOGGER.info("Removing tag scanner: %s", discovery_hash) - self.tear_down() - if self.device_id: - await cleanup_device_registry( - self.hass, self.device_id, self._config_entry.entry_id - ) - else: - # Non-empty payload: Update tag scanner - _LOGGER.info("Updating tag scanner: %s", discovery_hash) - config = PLATFORM_SCHEMA(payload) - self._config = config - if self.device_id: - _update_device(self.hass, self._config_entry, config) - self._setup_from_config(config) - await self.subscribe_topics() - - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None - ) - - def _setup_from_config(self, config): + self._sub_state: dict[str, EntitySubscription] | None = None self._value_template = MqttValueTemplate( config.get(CONF_VALUE_TEMPLATE), hass=self.hass, ).async_render_with_possible_json_value - async def setup(self): - """Set up the MQTT tag scanner.""" - discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] - await self.subscribe_topics() - if self.device_id: - self._remove_device_updated = self.hass.bus.async_listen( - EVENT_DEVICE_REGISTRY_UPDATED, self.device_updated - ) - self._remove_discovery = async_dispatcher_connect( - self.hass, - MQTT_DISCOVERY_UPDATED.format(discovery_hash), - self.discovery_update, - ) - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None + MqttDiscoveryDeviceUpdate.__init__( + self, hass, discovery_data, device_id, config_entry, LOG_NAME ) - async def subscribe_topics(self): + async def async_update(self, discovery_data: dict) -> None: + """Handle MQTT tag discovery updates.""" + # Update tag scanner + config = PLATFORM_SCHEMA(discovery_data) + self._config = config + self._value_template = MqttValueTemplate( + config.get(CONF_VALUE_TEMPLATE), + hass=self.hass, + ).async_render_with_possible_json_value + update_device(self.hass, self._config_entry, config) + await self.subscribe_topics() + + async def subscribe_topics(self) -> None: """Subscribe to MQTT topics.""" - async def tag_scanned(msg): + @callback + async def tag_scanned(msg: ReceiveMessage) -> None: tag_id = self._value_template(msg.payload, "").strip() if not tag_id: # No output from template, ignore return @@ -195,44 +152,12 @@ class MQTTTagScanner: ) await subscription.async_subscribe_topics(self.hass, self._sub_state) - async def device_updated(self, event): - """Handle the update or removal of a device.""" - if not async_removed_from_device( - self.hass, event, self.device_id, self._config_entry.entry_id - ): - return - - # Stop subscribing to discovery updates to not trigger when we clear the - # discovery topic - self.tear_down() - - # Clear the discovery topic so the entity is not rediscovered after a restart - discovery_topic = self.discovery_data[ATTR_DISCOVERY_TOPIC] - mqtt.publish(self.hass, discovery_topic, "", retain=True) - - def tear_down(self): + async def async_tear_down(self) -> None: """Cleanup tag scanner.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] - - clear_discovery_hash(self.hass, discovery_hash) - if self.device_id: - self._remove_device_updated() - self._remove_discovery() - self._sub_state = subscription.async_unsubscribe_topics( self.hass, self._sub_state ) if self.device_id: self.hass.data[TAGS][self.device_id].pop(discovery_id) - - -def _update_device(hass, config_entry, config): - """Update device registry.""" - device_registry = dr.async_get(hass) - config_entry_id = config_entry.entry_id - device_info = device_info_from_config(config[CONF_DEVICE]) - - if config_entry_id is not None and device_info is not None: - device_info["config_entry_id"] = config_entry_id - device_registry.async_get_or_create(**device_info)