Refactor MQTT discovery (#67966)

* Proof of concept

* remove notify platform

* remove loose test

* Add rework from #67912 (#1)

* Move notify serviceupdater to Mixins

* Move tag discovery handler to Mixins

* fix tests

* Add typing for async_load_platform_helper

* Add add entry unload support for notify platform

* Simplify discovery updates

* Remove not needed extra logic

* Cleanup inrelevant or duplicate code

* reuse update_device and move to mixins

* Remove notify platform

* revert changes to notify platform

* Rename update class

* unify tag entry setup

* Use shared code for device_trigger `update_device`

* PoC shared dispatcher for device_trigger

* Fix bugs

* Improve typing - remove async_update

* Unload config_entry and tests

* Release dispatcher after setup and deduplicate

* closures to methods, revert `in` to `=`, updates

* Re-add update support for tag platform

* Re-add update support for device-trigger platform

* Cleanup rediscovery code revert related changes

* Undo discovery code shift

* Update homeassistant/components/mqtt/mixins.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Update homeassistant/components/mqtt/device_trigger.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Update homeassistant/components/mqtt/mixins.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* revert doc string changes

* move conditions

* typing and check config_entry_id

* Update homeassistant/components/mqtt/mixins.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* cleanup not used attribute

* Remove entry_unload code and tests

* update  comment

* add second comment

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Jan Bouwhuis 2022-04-15 12:35:08 +02:00 committed by GitHub
parent c932407560
commit 3b2aae5045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 366 additions and 288 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import logging import logging
from typing import Any from typing import Any, cast
import attr import attr
import voluptuous as vol import voluptuous as vol
@ -13,6 +13,7 @@ from homeassistant.components.automation import (
AutomationTriggerInfo, AutomationTriggerInfo,
) )
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE, CONF_DEVICE,
CONF_DEVICE_ID, CONF_DEVICE_ID,
@ -23,30 +24,19 @@ from homeassistant.const import (
) )
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import async_dispatcher_send
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import debug_info, trigger as mqtt_trigger from . import debug_info, trigger as mqtt_trigger
from .. import mqtt from .. import mqtt
from .const import ( from .const import ATTR_DISCOVERY_HASH, CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, DOMAIN
ATTR_DISCOVERY_HASH, from .discovery import MQTT_DISCOVERY_DONE
ATTR_DISCOVERY_TOPIC,
CONF_PAYLOAD,
CONF_QOS,
CONF_TOPIC,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash
from .mixins import ( from .mixins import (
CONF_CONNECTIONS,
CONF_IDENTIFIERS,
MQTT_ENTITY_DEVICE_INFO_SCHEMA, MQTT_ENTITY_DEVICE_INFO_SCHEMA,
cleanup_device_registry, MqttDiscoveryDeviceUpdate,
device_info_from_config, send_discovery_done,
update_device,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -89,6 +79,8 @@ TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
DEVICE_TRIGGERS = "mqtt_device_triggers" DEVICE_TRIGGERS = "mqtt_device_triggers"
LOG_NAME = "Device trigger"
@attr.s(slots=True) @attr.s(slots=True)
class TriggerInstance: class TriggerInstance:
@ -99,7 +91,7 @@ class TriggerInstance:
trigger: Trigger = attr.ib() trigger: Trigger = attr.ib()
remove: CALLBACK_TYPE | None = attr.ib(default=None) remove: CALLBACK_TYPE | None = attr.ib(default=None)
async def async_attach_trigger(self): async def async_attach_trigger(self) -> None:
"""Attach MQTT trigger.""" """Attach MQTT trigger."""
mqtt_config = { mqtt_config = {
mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN, mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN,
@ -132,14 +124,15 @@ class Trigger:
hass: HomeAssistant = attr.ib() hass: HomeAssistant = attr.ib()
payload: str | None = attr.ib() payload: str | None = attr.ib()
qos: int | None = attr.ib() qos: int | None = attr.ib()
remove_signal: Callable[[], None] | None = attr.ib()
subtype: str = attr.ib() subtype: str = attr.ib()
topic: str | None = attr.ib() topic: str | None = attr.ib()
type: str = attr.ib() type: str = attr.ib()
value_template: str | None = attr.ib() value_template: str | None = attr.ib()
trigger_instances: list[TriggerInstance] = attr.ib(factory=list) trigger_instances: list[TriggerInstance] = attr.ib(factory=list)
async def add_trigger(self, action, automation_info): async def add_trigger(
self, action: AutomationActionType, automation_info: AutomationTriggerInfo
) -> Callable:
"""Add MQTT trigger.""" """Add MQTT trigger."""
instance = TriggerInstance(action, automation_info, self) instance = TriggerInstance(action, automation_info, self)
self.trigger_instances.append(instance) self.trigger_instances.append(instance)
@ -160,9 +153,8 @@ class Trigger:
return async_remove return async_remove
async def update_trigger(self, config, discovery_hash, remove_signal): async def update_trigger(self, config: ConfigType) -> None:
"""Update MQTT device trigger.""" """Update MQTT device trigger."""
self.remove_signal = remove_signal
self.type = config[CONF_TYPE] self.type = config[CONF_TYPE]
self.subtype = config[CONF_SUBTYPE] self.subtype = config[CONF_SUBTYPE]
self.payload = config[CONF_PAYLOAD] self.payload = config[CONF_PAYLOAD]
@ -178,7 +170,7 @@ class Trigger:
for trig in self.trigger_instances: for trig in self.trigger_instances:
await trig.async_attach_trigger() await trig.async_attach_trigger()
def detach_trigger(self): def detach_trigger(self) -> None:
"""Remove MQTT device trigger.""" """Remove MQTT device trigger."""
# Mark trigger as unknown # Mark trigger as unknown
self.topic = None self.topic = None
@ -190,110 +182,110 @@ class Trigger:
trig.remove = None trig.remove = None
def _update_device(hass, config_entry, config): class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
"""Update device registry.""" """Setup a MQTT device trigger with auto discovery."""
device_registry = dr.async_get(hass)
config_entry_id = config_entry.entry_id
device_info = device_info_from_config(config[CONF_DEVICE])
if config_entry_id is not None and device_info is not None: def __init__(
device_info["config_entry_id"] = config_entry_id self,
device_registry.async_get_or_create(**device_info) hass: HomeAssistant,
config: ConfigType,
device_id: str,
discovery_data: dict,
config_entry: ConfigEntry,
) -> None:
"""Initialize."""
self._config = config
self._config_entry = config_entry
self.device_id = device_id
self.discovery_data = discovery_data
self.hass = hass
MqttDiscoveryDeviceUpdate.__init__(
self,
hass,
discovery_data,
device_id,
config_entry,
LOG_NAME,
)
async def async_setup(self) -> None:
"""Initialize the device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
if discovery_id not in self.hass.data.setdefault(DEVICE_TRIGGERS, {}):
self.hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=self.hass,
device_id=self.device_id,
discovery_data=self.discovery_data,
type=self._config[CONF_TYPE],
subtype=self._config[CONF_SUBTYPE],
topic=self._config[CONF_TOPIC],
payload=self._config[CONF_PAYLOAD],
qos=self._config[CONF_QOS],
value_template=self._config[CONF_VALUE_TEMPLATE],
)
else:
await self.hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
self._config
)
debug_info.add_trigger_discovery_data(
self.hass, discovery_hash, self.discovery_data, self.device_id
)
async def async_update(self, discovery_data: dict) -> None:
"""Handle MQTT device trigger discovery updates."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
debug_info.update_trigger_discovery_data(
self.hass, discovery_hash, discovery_data
)
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
update_device(self.hass, self._config_entry, config)
device_trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
await device_trigger.update_trigger(config)
async def async_tear_down(self) -> None:
"""Cleanup device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
if discovery_id in self.hass.data[DEVICE_TRIGGERS]:
_LOGGER.info("Removing trigger: %s", discovery_hash)
trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
trigger.detach_trigger()
debug_info.remove_trigger_discovery_data(self.hass, discovery_hash)
async def async_setup_trigger(hass, config, config_entry, discovery_data): async def async_setup_trigger(
hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict
) -> None:
"""Set up the MQTT device trigger.""" """Set up the MQTT device trigger."""
config = TRIGGER_DISCOVERY_SCHEMA(config) config = TRIGGER_DISCOVERY_SCHEMA(config)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
remove_signal = None
async def discovery_update(payload): if (device_id := update_device(hass, config_entry, config)) is None:
"""Handle discovery update."""
_LOGGER.info(
"Got update for trigger with hash: %s '%s'", discovery_hash, payload
)
if not payload:
# Empty payload: Remove trigger
_LOGGER.info("Removing trigger: %s", discovery_hash)
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
if discovery_id in hass.data[DEVICE_TRIGGERS]:
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
device_trigger.detach_trigger()
clear_discovery_hash(hass, discovery_hash)
remove_signal()
await cleanup_device_registry(hass, device.id, config_entry.entry_id)
else:
# Non-empty payload: Update trigger
_LOGGER.info("Updating trigger: %s", discovery_hash)
debug_info.update_trigger_discovery_data(hass, discovery_hash, payload)
config = TRIGGER_DISCOVERY_SCHEMA(payload)
_update_device(hass, config_entry, config)
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
await device_trigger.update_trigger(config, discovery_hash, remove_signal)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
remove_signal = async_dispatcher_connect(
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update
)
_update_device(hass, config_entry, config)
device_registry = dr.async_get(hass)
device = device_registry.async_get_device(
{(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]},
{tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]},
)
if device is None:
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
return return
if DEVICE_TRIGGERS not in hass.data: mqtt_device_trigger = MqttDeviceTrigger(
hass.data[DEVICE_TRIGGERS] = {} hass, config, device_id, discovery_data, config_entry
if discovery_id not in hass.data[DEVICE_TRIGGERS]:
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=hass,
device_id=device.id,
discovery_data=discovery_data,
type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE],
topic=config[CONF_TOPIC],
payload=config[CONF_PAYLOAD],
qos=config[CONF_QOS],
remove_signal=remove_signal,
value_template=config[CONF_VALUE_TEMPLATE],
) )
else: await mqtt_device_trigger.async_setup()
await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( send_discovery_done(hass, discovery_data)
config, discovery_hash, remove_signal
)
debug_info.add_trigger_discovery_data(
hass, discovery_hash, discovery_data, device.id
)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
async def async_removed_from_device(hass: HomeAssistant, device_id: str): async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device.""" """Handle Mqtt removed from a device."""
triggers = await async_get_triggers(hass, device_id) triggers = await async_get_triggers(hass, device_id)
for trig in triggers: for trig in triggers:
device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID]) device_trigger: Trigger = hass.data[DEVICE_TRIGGERS].pop(
if device_trigger: trig[CONF_DISCOVERY_ID]
discovery_hash = device_trigger.discovery_data[ATTR_DISCOVERY_HASH]
discovery_topic = device_trigger.discovery_data[ATTR_DISCOVERY_TOPIC]
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
device_trigger.detach_trigger()
clear_discovery_hash(hass, discovery_hash)
device_trigger.remove_signal()
mqtt.publish(
hass,
discovery_topic,
"",
retain=True,
) )
if device_trigger:
device_trigger.detach_trigger()
discovery_data = cast(dict, device_trigger.discovery_data)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
async def async_get_triggers( async def async_get_triggers(
@ -328,8 +320,7 @@ async def async_attach_trigger(
automation_info: AutomationTriggerInfo, automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
if DEVICE_TRIGGERS not in hass.data: hass.data.setdefault(DEVICE_TRIGGERS, {})
hass.data[DEVICE_TRIGGERS] = {}
device_id = config[CONF_DEVICE_ID] device_id = config[CONF_DEVICE_ID]
discovery_id = config[CONF_DISCOVERY_ID] discovery_id = config[CONF_DISCOVERY_ID]
@ -338,7 +329,6 @@ async def async_attach_trigger(
hass=hass, hass=hass,
device_id=device_id, device_id=device_id,
discovery_data=None, discovery_data=None,
remove_signal=None,
type=config[CONF_TYPE], type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE], subtype=config[CONF_SUBTYPE],
topic=None, topic=None,

View File

@ -1,4 +1,6 @@
"""Support for MQTT discovery.""" """Support for MQTT discovery."""
from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
import functools import functools
@ -73,20 +75,22 @@ LAST_DISCOVERY = "mqtt_last_discovery"
TOPIC_BASE = "~" TOPIC_BASE = "~"
def clear_discovery_hash(hass, discovery_hash): class MQTTConfig(dict):
"""Dummy class to allow adding attributes."""
discovery_data: dict
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple) -> None:
"""Clear entry in ALREADY_DISCOVERED list.""" """Clear entry in ALREADY_DISCOVERED list."""
del hass.data[ALREADY_DISCOVERED][discovery_hash] del hass.data[ALREADY_DISCOVERED][discovery_hash]
def set_discovery_hash(hass, discovery_hash): def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple):
"""Clear entry in ALREADY_DISCOVERED list.""" """Clear entry in ALREADY_DISCOVERED list."""
hass.data[ALREADY_DISCOVERED][discovery_hash] = {} hass.data[ALREADY_DISCOVERED][discovery_hash] = {}
class MQTTConfig(dict):
"""Dummy class to allow adding attributes."""
async def async_start( # noqa: C901 async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None hass: HomeAssistant, discovery_topic, config_entry=None
) -> None: ) -> None:
@ -181,6 +185,7 @@ async def async_start( # noqa: C901
await async_process_discovery_payload(component, discovery_id, payload) await async_process_discovery_payload(component, discovery_id, payload)
async def async_process_discovery_payload(component, discovery_id, payload): async def async_process_discovery_payload(component, discovery_id, payload):
"""Process the payload of a new discovery."""
_LOGGER.debug("Process discovery payload %s", payload) _LOGGER.debug("Process discovery payload %s", payload)
discovery_hash = (component, discovery_id) discovery_hash = (component, discovery_id)

View File

@ -5,7 +5,7 @@ from abc import abstractmethod
from collections.abc import Callable from collections.abc import Callable
import json import json
import logging import logging
from typing import Any, Protocol, final from typing import Any, Protocol, cast, final
import voluptuous as vol import voluptuous as vol
@ -32,6 +32,7 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
) )
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
@ -45,7 +46,7 @@ from homeassistant.helpers.entity import (
) )
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import ( from . import (
DATA_MQTT, DATA_MQTT,
@ -496,8 +497,10 @@ class MqttAvailability(Entity):
return self._available_latest return self._available_latest
async def cleanup_device_registry(hass, device_id, config_entry_id): async def cleanup_device_registry(
"""Remove device registry entry if there are no remaining entities or triggers.""" hass: HomeAssistant, device_id: str | None, config_entry_id: str | None
) -> None:
"""Remove MQTT from the device registry entry if there are no remaining entities, triggers or tags."""
# Local import to avoid circular dependencies # Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
from . import device_trigger, tag from . import device_trigger, tag
@ -506,6 +509,7 @@ async def cleanup_device_registry(hass, device_id, config_entry_id):
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
if ( if (
device_id device_id
and config_entry_id
and not er.async_entries_for_device( and not er.async_entries_for_device(
entity_registry, device_id, include_disabled_entities=False entity_registry, device_id, include_disabled_entities=False
) )
@ -517,14 +521,163 @@ async def cleanup_device_registry(hass, device_id, config_entry_id):
) )
def get_discovery_hash(discovery_data: dict) -> tuple:
"""Get the discovery hash from the discovery data."""
return discovery_data[ATTR_DISCOVERY_HASH]
def send_discovery_done(hass: HomeAssistant, discovery_data: dict) -> None:
"""Acknowledge a discovery message has been handled."""
discovery_hash = get_discovery_hash(discovery_data)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
def stop_discovery_updates(
hass: HomeAssistant,
discovery_data: dict,
remove_discovery_updated: Callable[[], None] | None = None,
) -> None:
"""Stop discovery updates of being sent."""
if remove_discovery_updated:
remove_discovery_updated()
remove_discovery_updated = None
discovery_hash = get_discovery_hash(discovery_data)
clear_discovery_hash(hass, discovery_hash)
async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: dict):
"""Clear retained discovery topic in broker to avoid rediscovery after a restart of HA."""
discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC]
await async_publish(hass, discovery_topic, "", retain=True)
class MqttDiscoveryDeviceUpdate:
"""Add support for auto discovery for platforms without an entity."""
def __init__(
self,
hass: HomeAssistant,
discovery_data: dict,
device_id: str | None,
config_entry: ConfigEntry,
log_name: str,
) -> None:
"""Initialize the update service."""
self.hass = hass
self.log_name = log_name
self._discovery_data = discovery_data
self._device_id = device_id
self._config_entry = config_entry
self._config_entry_id = config_entry.entry_id
self._skip_device_removal: bool = False
discovery_hash = get_discovery_hash(discovery_data)
self._remove_discovery_updated = async_dispatcher_connect(
hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
self.async_discovery_update,
)
if device_id is not None:
self._remove_device_updated = hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self._async_device_removed
)
_LOGGER.info(
"%s %s has been initialized",
self.log_name,
discovery_hash,
)
async def async_discovery_update(
self,
discovery_payload: DiscoveryInfoType | None,
) -> None:
"""Handle discovery update."""
discovery_hash = get_discovery_hash(self._discovery_data)
_LOGGER.info(
"Got update for %s with hash: %s '%s'",
self.log_name,
discovery_hash,
discovery_payload,
)
if (
discovery_payload
and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
):
_LOGGER.info(
"%s %s updating",
self.log_name,
discovery_hash,
)
await self.async_update(discovery_payload)
if not discovery_payload:
# Unregister and clean up the current discovery instance
stop_discovery_updates(
self.hass, self._discovery_data, self._remove_discovery_updated
)
await self._async_tear_down()
send_discovery_done(self.hass, self._discovery_data)
_LOGGER.info(
"%s %s has been removed",
self.log_name,
discovery_hash,
)
else:
# Normal update without change
send_discovery_done(self.hass, self._discovery_data)
_LOGGER.info(
"%s %s no changes",
self.log_name,
discovery_hash,
)
return
async def _async_device_removed(self, event: Event) -> None:
"""Handle the manual removal of a device."""
if self._skip_device_removal or not async_removed_from_device(
self.hass, event, cast(str, self._device_id), self._config_entry_id
):
return
# Prevent a second cleanup round after the device is removed
self._remove_device_updated()
self._skip_device_removal = True
# Unregister and clean up and publish an empty payload
# so the service is not rediscovered after a restart
stop_discovery_updates(
self.hass, self._discovery_data, self._remove_discovery_updated
)
await self._async_tear_down()
await async_remove_discovery_payload(self.hass, self._discovery_data)
async def _async_tear_down(self) -> None:
"""Handle the cleanup of the discovery service."""
# Cleanup platform resources
await self.async_tear_down()
# remove the service for auto discovery updates and clean up the device registry
if not self._skip_device_removal:
# Prevent a second cleanup round after the device is removed
self._skip_device_removal = True
await cleanup_device_registry(
self.hass, self._device_id, self._config_entry_id
)
async def async_update(self, discovery_data: dict) -> None:
"""Handle the update of platform specific parts, extend to the platform."""
@abstractmethod
async def async_tear_down(self) -> None:
"""Handle the cleanup of platform specific parts, extend to the platform."""
class MqttDiscoveryUpdate(Entity): class MqttDiscoveryUpdate(Entity):
"""Mixin used to handle updated discovery message.""" """Mixin used to handle updated discovery message for entity based platforms."""
def __init__(self, discovery_data, discovery_update=None) -> None: def __init__(self, discovery_data, discovery_update=None) -> None:
"""Initialize the discovery update mixin.""" """Initialize the discovery update mixin."""
self._discovery_data = discovery_data self._discovery_data = discovery_data
self._discovery_update = discovery_update self._discovery_update = discovery_update
self._remove_signal: Callable | None = None self._remove_discovery_updated: Callable | None = None
self._removed_from_hass = False self._removed_from_hass = False
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
@ -572,7 +725,7 @@ class MqttDiscoveryUpdate(Entity):
else: else:
# Non-empty, unchanged payload: Ignore to avoid changing states # Non-empty, unchanged payload: Ignore to avoid changing states
_LOGGER.info("Ignoring unchanged update for: %s", self.entity_id) _LOGGER.info("Ignoring unchanged update for: %s", self.entity_id)
self.async_send_discovery_done() send_discovery_done(self.hass, self._discovery_data)
if discovery_hash: if discovery_hash:
debug_info.add_entity_discovery_data( debug_info.add_entity_discovery_data(
@ -580,24 +733,12 @@ class MqttDiscoveryUpdate(Entity):
) )
# Set in case the entity has been removed and is re-added, for example when changing entity_id # Set in case the entity has been removed and is re-added, for example when changing entity_id
set_discovery_hash(self.hass, discovery_hash) set_discovery_hash(self.hass, discovery_hash)
self._remove_signal = async_dispatcher_connect( self._remove_discovery_updated = async_dispatcher_connect(
self.hass, self.hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash), MQTT_DISCOVERY_UPDATED.format(discovery_hash),
discovery_callback, discovery_callback,
) )
@callback
def async_send_discovery_done(self) -> None:
"""Acknowledge a discovery message has been handled."""
discovery_hash = (
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
)
if not discovery_hash:
return
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
async def async_removed_from_registry(self) -> None: async def async_removed_from_registry(self) -> None:
"""Clear retained discovery topic in broker.""" """Clear retained discovery topic in broker."""
if not self._removed_from_hass: if not self._removed_from_hass:
@ -606,18 +747,14 @@ class MqttDiscoveryUpdate(Entity):
self._cleanup_discovery_on_remove() self._cleanup_discovery_on_remove()
# Clear the discovery topic so the entity is not rediscovered after a restart # Clear the discovery topic so the entity is not rediscovered after a restart
discovery_topic = self._discovery_data[ATTR_DISCOVERY_TOPIC] await async_remove_discovery_payload(self.hass, self._discovery_data)
await async_publish(self.hass, discovery_topic, "", retain=True)
@callback @callback
def add_to_platform_abort(self) -> None: def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform.""" """Abort adding an entity to a platform."""
if self._discovery_data: if self._discovery_data:
discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH] stop_discovery_updates(self.hass, self._discovery_data)
clear_discovery_hash(self.hass, discovery_hash) send_discovery_done(self.hass, self._discovery_data)
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
super().add_to_platform_abort() super().add_to_platform_abort()
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
@ -627,13 +764,11 @@ class MqttDiscoveryUpdate(Entity):
def _cleanup_discovery_on_remove(self) -> None: def _cleanup_discovery_on_remove(self) -> None:
"""Stop listening to signal and cleanup discovery data.""" """Stop listening to signal and cleanup discovery data."""
if self._discovery_data and not self._removed_from_hass: if self._discovery_data and not self._removed_from_hass:
clear_discovery_hash(self.hass, self._discovery_data[ATTR_DISCOVERY_HASH]) stop_discovery_updates(
self.hass, self._discovery_data, self._remove_discovery_updated
)
self._removed_from_hass = True self._removed_from_hass = True
if self._remove_signal:
self._remove_signal()
self._remove_signal = None
def device_info_from_config(config) -> DeviceInfo | None: def device_info_from_config(config) -> DeviceInfo | None:
"""Return a device description for device registry.""" """Return a device description for device registry."""
@ -737,7 +872,8 @@ class MqttEntity(
self._prepare_subscribe_topics() self._prepare_subscribe_topics()
await self._subscribe_topics() await self._subscribe_topics()
await self.mqtt_async_added_to_hass() await self.mqtt_async_added_to_hass()
self.async_send_discovery_done() if self._discovery_data is not None:
send_discovery_done(self.hass, self._discovery_data)
async def mqtt_async_added_to_hass(self): async def mqtt_async_added_to_hass(self):
"""Call before the discovery message is acknowledged. """Call before the discovery message is acknowledged.
@ -839,6 +975,28 @@ class MqttEntity(
return self._unique_id return self._unique_id
def update_device(
hass: HomeAssistant,
config_entry: ConfigEntry,
config: ConfigType,
) -> str | None:
"""Update device registry."""
if CONF_DEVICE not in config:
return None
device = None
device_registry = dr.async_get(hass)
config_entry_id = config_entry.entry_id
device_info = device_info_from_config(config[CONF_DEVICE])
if config_entry_id is not None and device_info is not None:
update_device_info = cast(dict, device_info)
update_device_info["config_entry_id"] = config_entry_id
device = device_registry.async_get_or_create(**update_device_info)
return device.id if device else None
@callback @callback
def async_removed_from_device( def async_removed_from_device(
hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str

View File

@ -1,4 +1,4 @@
"""Modesl used by multiple MQTT modules.""" """Models used by multiple MQTT modules."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable

View File

@ -1,40 +1,31 @@
"""Provides tag scanning for MQTT.""" """Provides tag scanning for MQTT."""
from __future__ import annotations
import functools import functools
import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM, CONF_VALUE_TEMPLATE from homeassistant.const import CONF_DEVICE, CONF_PLATFORM, CONF_VALUE_TEMPLATE
from homeassistant.helpers import device_registry as dr from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from . import MqttValueTemplate, subscription from . import MqttValueTemplate, subscription
from .. import mqtt from .. import mqtt
from .const import ( from .const import ATTR_DISCOVERY_HASH, CONF_QOS, CONF_TOPIC
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_TOPIC,
CONF_QOS,
CONF_TOPIC,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash
from .mixins import ( from .mixins import (
CONF_CONNECTIONS,
CONF_IDENTIFIERS,
MQTT_ENTITY_DEVICE_INFO_SCHEMA, MQTT_ENTITY_DEVICE_INFO_SCHEMA,
async_removed_from_device, MqttDiscoveryDeviceUpdate,
async_setup_entry_helper, async_setup_entry_helper,
cleanup_device_registry, send_discovery_done,
device_info_from_config, update_device,
) )
from .models import ReceiveMessage
from .subscription import EntitySubscription
from .util import valid_subscribe_topic from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) LOG_NAME = "Tag"
TAG = "tag" TAG = "tag"
TAGS = "mqtt_tags" TAGS = "mqtt_tags"
@ -50,33 +41,25 @@ PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
) )
async def async_setup_entry(hass, config_entry): async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Set up MQTT tag scan dynamically through MQTT discovery.""" """Set up MQTT device automation dynamically through MQTT discovery."""
setup = functools.partial(async_setup_tag, hass, config_entry=config_entry)
await async_setup_entry_helper(hass, "tag", setup, PLATFORM_SCHEMA) setup = functools.partial(_async_setup_tag, hass, config_entry=config_entry)
await async_setup_entry_helper(hass, TAG, setup, PLATFORM_SCHEMA)
async def async_setup_tag(hass, config, config_entry, discovery_data): async def _async_setup_tag(
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: dict,
) -> None:
"""Set up the MQTT tag scanner.""" """Set up the MQTT tag scanner."""
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
device_id = None device_id = update_device(hass, config_entry, config)
if CONF_DEVICE in config: hass.data.setdefault(TAGS, {})
_update_device(hass, config_entry, config)
device_registry = dr.async_get(hass)
device = device_registry.async_get_device(
{(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]},
{tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]},
)
if device is None:
return
device_id = device.id
if TAGS not in hass.data:
hass.data[TAGS] = {}
if device_id not in hass.data[TAGS]: if device_id not in hass.data[TAGS]:
hass.data[TAGS][device_id] = {} hass.data[TAGS][device_id] = {}
@ -88,91 +71,65 @@ async def async_setup_tag(hass, config, config_entry, discovery_data):
config_entry, config_entry,
) )
await tag_scanner.setup() await tag_scanner.subscribe_topics()
if device_id: if device_id:
hass.data[TAGS][device_id][discovery_id] = tag_scanner hass.data[TAGS][device_id][discovery_id] = tag_scanner
send_discovery_done(hass, discovery_data)
def async_has_tags(hass, device_id):
def async_has_tags(hass: HomeAssistant, device_id: str) -> bool:
"""Device has tag scanners.""" """Device has tag scanners."""
if TAGS not in hass.data or device_id not in hass.data[TAGS]: if TAGS not in hass.data or device_id not in hass.data[TAGS]:
return False return False
return hass.data[TAGS][device_id] != {} return hass.data[TAGS][device_id] != {}
class MQTTTagScanner: class MQTTTagScanner(MqttDiscoveryDeviceUpdate):
"""MQTT Tag scanner.""" """MQTT Tag scanner."""
def __init__(self, hass, config, device_id, discovery_data, config_entry): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
device_id: str | None,
discovery_data: dict,
config_entry: ConfigEntry,
) -> None:
"""Initialize.""" """Initialize."""
self._config = config self._config = config
self._config_entry = config_entry self._config_entry = config_entry
self.device_id = device_id self.device_id = device_id
self.discovery_data = discovery_data self.discovery_data = discovery_data
self.hass = hass self.hass = hass
self._remove_discovery = None self._sub_state: dict[str, EntitySubscription] | None = None
self._remove_device_updated = None
self._sub_state = None
self._value_template = None
self._setup_from_config(config)
async def discovery_update(self, payload):
"""Handle discovery update."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
_LOGGER.info(
"Got update for tag scanner with hash: %s '%s'", discovery_hash, payload
)
if not payload:
# Empty payload: Remove tag scanner
_LOGGER.info("Removing tag scanner: %s", discovery_hash)
self.tear_down()
if self.device_id:
await cleanup_device_registry(
self.hass, self.device_id, self._config_entry.entry_id
)
else:
# Non-empty payload: Update tag scanner
_LOGGER.info("Updating tag scanner: %s", discovery_hash)
config = PLATFORM_SCHEMA(payload)
self._config = config
if self.device_id:
_update_device(self.hass, self._config_entry, config)
self._setup_from_config(config)
await self.subscribe_topics()
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
def _setup_from_config(self, config):
self._value_template = MqttValueTemplate( self._value_template = MqttValueTemplate(
config.get(CONF_VALUE_TEMPLATE), config.get(CONF_VALUE_TEMPLATE),
hass=self.hass, hass=self.hass,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def setup(self): MqttDiscoveryDeviceUpdate.__init__(
"""Set up the MQTT tag scanner.""" self, hass, discovery_data, device_id, config_entry, LOG_NAME
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
await self.subscribe_topics()
if self.device_id:
self._remove_device_updated = self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.device_updated
)
self._remove_discovery = async_dispatcher_connect(
self.hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
self.discovery_update,
)
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
) )
async def subscribe_topics(self): async def async_update(self, discovery_data: dict) -> None:
"""Handle MQTT tag discovery updates."""
# Update tag scanner
config = PLATFORM_SCHEMA(discovery_data)
self._config = config
self._value_template = MqttValueTemplate(
config.get(CONF_VALUE_TEMPLATE),
hass=self.hass,
).async_render_with_possible_json_value
update_device(self.hass, self._config_entry, config)
await self.subscribe_topics()
async def subscribe_topics(self) -> None:
"""Subscribe to MQTT topics.""" """Subscribe to MQTT topics."""
async def tag_scanned(msg): @callback
async def tag_scanned(msg: ReceiveMessage) -> None:
tag_id = self._value_template(msg.payload, "").strip() tag_id = self._value_template(msg.payload, "").strip()
if not tag_id: # No output from template, ignore if not tag_id: # No output from template, ignore
return return
@ -195,44 +152,12 @@ class MQTTTagScanner:
) )
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
async def device_updated(self, event): async def async_tear_down(self) -> None:
"""Handle the update or removal of a device."""
if not async_removed_from_device(
self.hass, event, self.device_id, self._config_entry.entry_id
):
return
# Stop subscribing to discovery updates to not trigger when we clear the
# discovery topic
self.tear_down()
# Clear the discovery topic so the entity is not rediscovered after a restart
discovery_topic = self.discovery_data[ATTR_DISCOVERY_TOPIC]
mqtt.publish(self.hass, discovery_topic, "", retain=True)
def tear_down(self):
"""Cleanup tag scanner.""" """Cleanup tag scanner."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
clear_discovery_hash(self.hass, discovery_hash)
if self.device_id:
self._remove_device_updated()
self._remove_discovery()
self._sub_state = subscription.async_unsubscribe_topics( self._sub_state = subscription.async_unsubscribe_topics(
self.hass, self._sub_state self.hass, self._sub_state
) )
if self.device_id: if self.device_id:
self.hass.data[TAGS][self.device_id].pop(discovery_id) self.hass.data[TAGS][self.device_id].pop(discovery_id)
def _update_device(hass, config_entry, config):
"""Update device registry."""
device_registry = dr.async_get(hass)
config_entry_id = config_entry.entry_id
device_info = device_info_from_config(config[CONF_DEVICE])
if config_entry_id is not None and device_info is not None:
device_info["config_entry_id"] = config_entry_id
device_registry.async_get_or_create(**device_info)