Refactor MQTT discovery (#67966)

* Proof of concept

* remove notify platform

* remove loose test

* Add rework from #67912 (#1)

* Move notify serviceupdater to Mixins

* Move tag discovery handler to Mixins

* fix tests

* Add typing for async_load_platform_helper

* Add add entry unload support for notify platform

* Simplify discovery updates

* Remove not needed extra logic

* Cleanup inrelevant or duplicate code

* reuse update_device and move to mixins

* Remove notify platform

* revert changes to notify platform

* Rename update class

* unify tag entry setup

* Use shared code for device_trigger `update_device`

* PoC shared dispatcher for device_trigger

* Fix bugs

* Improve typing - remove async_update

* Unload config_entry and tests

* Release dispatcher after setup and deduplicate

* closures to methods, revert `in` to `=`, updates

* Re-add update support for tag platform

* Re-add update support for device-trigger platform

* Cleanup rediscovery code revert related changes

* Undo discovery code shift

* Update homeassistant/components/mqtt/mixins.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Update homeassistant/components/mqtt/device_trigger.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Update homeassistant/components/mqtt/mixins.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* revert doc string changes

* move conditions

* typing and check config_entry_id

* Update homeassistant/components/mqtt/mixins.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* cleanup not used attribute

* Remove entry_unload code and tests

* update  comment

* add second comment

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Jan Bouwhuis 2022-04-15 12:35:08 +02:00 committed by GitHub
parent c932407560
commit 3b2aae5045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 366 additions and 288 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable
import logging
from typing import Any
from typing import Any, cast
import attr
import voluptuous as vol
@ -13,6 +13,7 @@ from homeassistant.components.automation import (
AutomationTriggerInfo,
)
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
CONF_DEVICE,
CONF_DEVICE_ID,
@ -23,30 +24,19 @@ from homeassistant.const import (
)
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import ConfigType
from . import debug_info, trigger as mqtt_trigger
from .. import mqtt
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_TOPIC,
CONF_PAYLOAD,
CONF_QOS,
CONF_TOPIC,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash
from .const import ATTR_DISCOVERY_HASH, CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, DOMAIN
from .discovery import MQTT_DISCOVERY_DONE
from .mixins import (
CONF_CONNECTIONS,
CONF_IDENTIFIERS,
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
cleanup_device_registry,
device_info_from_config,
MqttDiscoveryDeviceUpdate,
send_discovery_done,
update_device,
)
_LOGGER = logging.getLogger(__name__)
@ -89,6 +79,8 @@ TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
DEVICE_TRIGGERS = "mqtt_device_triggers"
LOG_NAME = "Device trigger"
@attr.s(slots=True)
class TriggerInstance:
@ -99,7 +91,7 @@ class TriggerInstance:
trigger: Trigger = attr.ib()
remove: CALLBACK_TYPE | None = attr.ib(default=None)
async def async_attach_trigger(self):
async def async_attach_trigger(self) -> None:
"""Attach MQTT trigger."""
mqtt_config = {
mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN,
@ -132,14 +124,15 @@ class Trigger:
hass: HomeAssistant = attr.ib()
payload: str | None = attr.ib()
qos: int | None = attr.ib()
remove_signal: Callable[[], None] | None = attr.ib()
subtype: str = attr.ib()
topic: str | None = attr.ib()
type: str = attr.ib()
value_template: str | None = attr.ib()
trigger_instances: list[TriggerInstance] = attr.ib(factory=list)
async def add_trigger(self, action, automation_info):
async def add_trigger(
self, action: AutomationActionType, automation_info: AutomationTriggerInfo
) -> Callable:
"""Add MQTT trigger."""
instance = TriggerInstance(action, automation_info, self)
self.trigger_instances.append(instance)
@ -160,9 +153,8 @@ class Trigger:
return async_remove
async def update_trigger(self, config, discovery_hash, remove_signal):
async def update_trigger(self, config: ConfigType) -> None:
"""Update MQTT device trigger."""
self.remove_signal = remove_signal
self.type = config[CONF_TYPE]
self.subtype = config[CONF_SUBTYPE]
self.payload = config[CONF_PAYLOAD]
@ -178,7 +170,7 @@ class Trigger:
for trig in self.trigger_instances:
await trig.async_attach_trigger()
def detach_trigger(self):
def detach_trigger(self) -> None:
"""Remove MQTT device trigger."""
# Mark trigger as unknown
self.topic = None
@ -190,110 +182,110 @@ class Trigger:
trig.remove = None
def _update_device(hass, config_entry, config):
"""Update device registry."""
device_registry = dr.async_get(hass)
config_entry_id = config_entry.entry_id
device_info = device_info_from_config(config[CONF_DEVICE])
class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
"""Setup a MQTT device trigger with auto discovery."""
if config_entry_id is not None and device_info is not None:
device_info["config_entry_id"] = config_entry_id
device_registry.async_get_or_create(**device_info)
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
device_id: str,
discovery_data: dict,
config_entry: ConfigEntry,
) -> None:
"""Initialize."""
self._config = config
self._config_entry = config_entry
self.device_id = device_id
self.discovery_data = discovery_data
self.hass = hass
MqttDiscoveryDeviceUpdate.__init__(
self,
hass,
discovery_data,
device_id,
config_entry,
LOG_NAME,
)
async def async_setup(self) -> None:
"""Initialize the device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
if discovery_id not in self.hass.data.setdefault(DEVICE_TRIGGERS, {}):
self.hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=self.hass,
device_id=self.device_id,
discovery_data=self.discovery_data,
type=self._config[CONF_TYPE],
subtype=self._config[CONF_SUBTYPE],
topic=self._config[CONF_TOPIC],
payload=self._config[CONF_PAYLOAD],
qos=self._config[CONF_QOS],
value_template=self._config[CONF_VALUE_TEMPLATE],
)
else:
await self.hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
self._config
)
debug_info.add_trigger_discovery_data(
self.hass, discovery_hash, self.discovery_data, self.device_id
)
async def async_update(self, discovery_data: dict) -> None:
"""Handle MQTT device trigger discovery updates."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
debug_info.update_trigger_discovery_data(
self.hass, discovery_hash, discovery_data
)
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
update_device(self.hass, self._config_entry, config)
device_trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
await device_trigger.update_trigger(config)
async def async_tear_down(self) -> None:
"""Cleanup device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
if discovery_id in self.hass.data[DEVICE_TRIGGERS]:
_LOGGER.info("Removing trigger: %s", discovery_hash)
trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
trigger.detach_trigger()
debug_info.remove_trigger_discovery_data(self.hass, discovery_hash)
async def async_setup_trigger(hass, config, config_entry, discovery_data):
async def async_setup_trigger(
hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict
) -> None:
"""Set up the MQTT device trigger."""
config = TRIGGER_DISCOVERY_SCHEMA(config)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
remove_signal = None
async def discovery_update(payload):
"""Handle discovery update."""
_LOGGER.info(
"Got update for trigger with hash: %s '%s'", discovery_hash, payload
)
if not payload:
# Empty payload: Remove trigger
_LOGGER.info("Removing trigger: %s", discovery_hash)
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
if discovery_id in hass.data[DEVICE_TRIGGERS]:
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
device_trigger.detach_trigger()
clear_discovery_hash(hass, discovery_hash)
remove_signal()
await cleanup_device_registry(hass, device.id, config_entry.entry_id)
else:
# Non-empty payload: Update trigger
_LOGGER.info("Updating trigger: %s", discovery_hash)
debug_info.update_trigger_discovery_data(hass, discovery_hash, payload)
config = TRIGGER_DISCOVERY_SCHEMA(payload)
_update_device(hass, config_entry, config)
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
await device_trigger.update_trigger(config, discovery_hash, remove_signal)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
remove_signal = async_dispatcher_connect(
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update
)
_update_device(hass, config_entry, config)
device_registry = dr.async_get(hass)
device = device_registry.async_get_device(
{(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]},
{tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]},
)
if device is None:
if (device_id := update_device(hass, config_entry, config)) is None:
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
return
if DEVICE_TRIGGERS not in hass.data:
hass.data[DEVICE_TRIGGERS] = {}
if discovery_id not in hass.data[DEVICE_TRIGGERS]:
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=hass,
device_id=device.id,
discovery_data=discovery_data,
type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE],
topic=config[CONF_TOPIC],
payload=config[CONF_PAYLOAD],
qos=config[CONF_QOS],
remove_signal=remove_signal,
value_template=config[CONF_VALUE_TEMPLATE],
mqtt_device_trigger = MqttDeviceTrigger(
hass, config, device_id, discovery_data, config_entry
)
else:
await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
config, discovery_hash, remove_signal
)
debug_info.add_trigger_discovery_data(
hass, discovery_hash, discovery_data, device.id
)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
await mqtt_device_trigger.async_setup()
send_discovery_done(hass, discovery_data)
async def async_removed_from_device(hass: HomeAssistant, device_id: str):
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device."""
triggers = await async_get_triggers(hass, device_id)
for trig in triggers:
device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID])
if device_trigger:
discovery_hash = device_trigger.discovery_data[ATTR_DISCOVERY_HASH]
discovery_topic = device_trigger.discovery_data[ATTR_DISCOVERY_TOPIC]
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
device_trigger.detach_trigger()
clear_discovery_hash(hass, discovery_hash)
device_trigger.remove_signal()
mqtt.publish(
hass,
discovery_topic,
"",
retain=True,
device_trigger: Trigger = hass.data[DEVICE_TRIGGERS].pop(
trig[CONF_DISCOVERY_ID]
)
if device_trigger:
device_trigger.detach_trigger()
discovery_data = cast(dict, device_trigger.discovery_data)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
async def async_get_triggers(
@ -328,8 +320,7 @@ async def async_attach_trigger(
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
if DEVICE_TRIGGERS not in hass.data:
hass.data[DEVICE_TRIGGERS] = {}
hass.data.setdefault(DEVICE_TRIGGERS, {})
device_id = config[CONF_DEVICE_ID]
discovery_id = config[CONF_DISCOVERY_ID]
@ -338,7 +329,6 @@ async def async_attach_trigger(
hass=hass,
device_id=device_id,
discovery_data=None,
remove_signal=None,
type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE],
topic=None,

View File

@ -1,4 +1,6 @@
"""Support for MQTT discovery."""
from __future__ import annotations
import asyncio
from collections import deque
import functools
@ -73,20 +75,22 @@ LAST_DISCOVERY = "mqtt_last_discovery"
TOPIC_BASE = "~"
def clear_discovery_hash(hass, discovery_hash):
class MQTTConfig(dict):
"""Dummy class to allow adding attributes."""
discovery_data: dict
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple) -> None:
"""Clear entry in ALREADY_DISCOVERED list."""
del hass.data[ALREADY_DISCOVERED][discovery_hash]
def set_discovery_hash(hass, discovery_hash):
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple):
"""Clear entry in ALREADY_DISCOVERED list."""
hass.data[ALREADY_DISCOVERED][discovery_hash] = {}
class MQTTConfig(dict):
"""Dummy class to allow adding attributes."""
async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None
) -> None:
@ -181,6 +185,7 @@ async def async_start( # noqa: C901
await async_process_discovery_payload(component, discovery_id, payload)
async def async_process_discovery_payload(component, discovery_id, payload):
"""Process the payload of a new discovery."""
_LOGGER.debug("Process discovery payload %s", payload)
discovery_hash = (component, discovery_id)

View File

@ -5,7 +5,7 @@ from abc import abstractmethod
from collections.abc import Callable
import json
import logging
from typing import Any, Protocol, final
from typing import Any, Protocol, cast, final
import voluptuous as vol
@ -32,6 +32,7 @@ from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
@ -45,7 +46,7 @@ from homeassistant.helpers.entity import (
)
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import (
DATA_MQTT,
@ -496,8 +497,10 @@ class MqttAvailability(Entity):
return self._available_latest
async def cleanup_device_registry(hass, device_id, config_entry_id):
"""Remove device registry entry if there are no remaining entities or triggers."""
async def cleanup_device_registry(
hass: HomeAssistant, device_id: str | None, config_entry_id: str | None
) -> None:
"""Remove MQTT from the device registry entry if there are no remaining entities, triggers or tags."""
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from . import device_trigger, tag
@ -506,6 +509,7 @@ async def cleanup_device_registry(hass, device_id, config_entry_id):
entity_registry = er.async_get(hass)
if (
device_id
and config_entry_id
and not er.async_entries_for_device(
entity_registry, device_id, include_disabled_entities=False
)
@ -517,14 +521,163 @@ async def cleanup_device_registry(hass, device_id, config_entry_id):
)
def get_discovery_hash(discovery_data: dict) -> tuple:
"""Get the discovery hash from the discovery data."""
return discovery_data[ATTR_DISCOVERY_HASH]
def send_discovery_done(hass: HomeAssistant, discovery_data: dict) -> None:
"""Acknowledge a discovery message has been handled."""
discovery_hash = get_discovery_hash(discovery_data)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
def stop_discovery_updates(
hass: HomeAssistant,
discovery_data: dict,
remove_discovery_updated: Callable[[], None] | None = None,
) -> None:
"""Stop discovery updates of being sent."""
if remove_discovery_updated:
remove_discovery_updated()
remove_discovery_updated = None
discovery_hash = get_discovery_hash(discovery_data)
clear_discovery_hash(hass, discovery_hash)
async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: dict):
"""Clear retained discovery topic in broker to avoid rediscovery after a restart of HA."""
discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC]
await async_publish(hass, discovery_topic, "", retain=True)
class MqttDiscoveryDeviceUpdate:
"""Add support for auto discovery for platforms without an entity."""
def __init__(
self,
hass: HomeAssistant,
discovery_data: dict,
device_id: str | None,
config_entry: ConfigEntry,
log_name: str,
) -> None:
"""Initialize the update service."""
self.hass = hass
self.log_name = log_name
self._discovery_data = discovery_data
self._device_id = device_id
self._config_entry = config_entry
self._config_entry_id = config_entry.entry_id
self._skip_device_removal: bool = False
discovery_hash = get_discovery_hash(discovery_data)
self._remove_discovery_updated = async_dispatcher_connect(
hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
self.async_discovery_update,
)
if device_id is not None:
self._remove_device_updated = hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self._async_device_removed
)
_LOGGER.info(
"%s %s has been initialized",
self.log_name,
discovery_hash,
)
async def async_discovery_update(
self,
discovery_payload: DiscoveryInfoType | None,
) -> None:
"""Handle discovery update."""
discovery_hash = get_discovery_hash(self._discovery_data)
_LOGGER.info(
"Got update for %s with hash: %s '%s'",
self.log_name,
discovery_hash,
discovery_payload,
)
if (
discovery_payload
and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
):
_LOGGER.info(
"%s %s updating",
self.log_name,
discovery_hash,
)
await self.async_update(discovery_payload)
if not discovery_payload:
# Unregister and clean up the current discovery instance
stop_discovery_updates(
self.hass, self._discovery_data, self._remove_discovery_updated
)
await self._async_tear_down()
send_discovery_done(self.hass, self._discovery_data)
_LOGGER.info(
"%s %s has been removed",
self.log_name,
discovery_hash,
)
else:
# Normal update without change
send_discovery_done(self.hass, self._discovery_data)
_LOGGER.info(
"%s %s no changes",
self.log_name,
discovery_hash,
)
return
async def _async_device_removed(self, event: Event) -> None:
"""Handle the manual removal of a device."""
if self._skip_device_removal or not async_removed_from_device(
self.hass, event, cast(str, self._device_id), self._config_entry_id
):
return
# Prevent a second cleanup round after the device is removed
self._remove_device_updated()
self._skip_device_removal = True
# Unregister and clean up and publish an empty payload
# so the service is not rediscovered after a restart
stop_discovery_updates(
self.hass, self._discovery_data, self._remove_discovery_updated
)
await self._async_tear_down()
await async_remove_discovery_payload(self.hass, self._discovery_data)
async def _async_tear_down(self) -> None:
"""Handle the cleanup of the discovery service."""
# Cleanup platform resources
await self.async_tear_down()
# remove the service for auto discovery updates and clean up the device registry
if not self._skip_device_removal:
# Prevent a second cleanup round after the device is removed
self._skip_device_removal = True
await cleanup_device_registry(
self.hass, self._device_id, self._config_entry_id
)
async def async_update(self, discovery_data: dict) -> None:
"""Handle the update of platform specific parts, extend to the platform."""
@abstractmethod
async def async_tear_down(self) -> None:
"""Handle the cleanup of platform specific parts, extend to the platform."""
class MqttDiscoveryUpdate(Entity):
"""Mixin used to handle updated discovery message."""
"""Mixin used to handle updated discovery message for entity based platforms."""
def __init__(self, discovery_data, discovery_update=None) -> None:
"""Initialize the discovery update mixin."""
self._discovery_data = discovery_data
self._discovery_update = discovery_update
self._remove_signal: Callable | None = None
self._remove_discovery_updated: Callable | None = None
self._removed_from_hass = False
async def async_added_to_hass(self) -> None:
@ -572,7 +725,7 @@ class MqttDiscoveryUpdate(Entity):
else:
# Non-empty, unchanged payload: Ignore to avoid changing states
_LOGGER.info("Ignoring unchanged update for: %s", self.entity_id)
self.async_send_discovery_done()
send_discovery_done(self.hass, self._discovery_data)
if discovery_hash:
debug_info.add_entity_discovery_data(
@ -580,24 +733,12 @@ class MqttDiscoveryUpdate(Entity):
)
# Set in case the entity has been removed and is re-added, for example when changing entity_id
set_discovery_hash(self.hass, discovery_hash)
self._remove_signal = async_dispatcher_connect(
self._remove_discovery_updated = async_dispatcher_connect(
self.hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
discovery_callback,
)
@callback
def async_send_discovery_done(self) -> None:
"""Acknowledge a discovery message has been handled."""
discovery_hash = (
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
)
if not discovery_hash:
return
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
async def async_removed_from_registry(self) -> None:
"""Clear retained discovery topic in broker."""
if not self._removed_from_hass:
@ -606,18 +747,14 @@ class MqttDiscoveryUpdate(Entity):
self._cleanup_discovery_on_remove()
# Clear the discovery topic so the entity is not rediscovered after a restart
discovery_topic = self._discovery_data[ATTR_DISCOVERY_TOPIC]
await async_publish(self.hass, discovery_topic, "", retain=True)
await async_remove_discovery_payload(self.hass, self._discovery_data)
@callback
def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform."""
if self._discovery_data:
discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH]
clear_discovery_hash(self.hass, discovery_hash)
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
stop_discovery_updates(self.hass, self._discovery_data)
send_discovery_done(self.hass, self._discovery_data)
super().add_to_platform_abort()
async def async_will_remove_from_hass(self) -> None:
@ -627,13 +764,11 @@ class MqttDiscoveryUpdate(Entity):
def _cleanup_discovery_on_remove(self) -> None:
"""Stop listening to signal and cleanup discovery data."""
if self._discovery_data and not self._removed_from_hass:
clear_discovery_hash(self.hass, self._discovery_data[ATTR_DISCOVERY_HASH])
stop_discovery_updates(
self.hass, self._discovery_data, self._remove_discovery_updated
)
self._removed_from_hass = True
if self._remove_signal:
self._remove_signal()
self._remove_signal = None
def device_info_from_config(config) -> DeviceInfo | None:
"""Return a device description for device registry."""
@ -737,7 +872,8 @@ class MqttEntity(
self._prepare_subscribe_topics()
await self._subscribe_topics()
await self.mqtt_async_added_to_hass()
self.async_send_discovery_done()
if self._discovery_data is not None:
send_discovery_done(self.hass, self._discovery_data)
async def mqtt_async_added_to_hass(self):
"""Call before the discovery message is acknowledged.
@ -839,6 +975,28 @@ class MqttEntity(
return self._unique_id
def update_device(
hass: HomeAssistant,
config_entry: ConfigEntry,
config: ConfigType,
) -> str | None:
"""Update device registry."""
if CONF_DEVICE not in config:
return None
device = None
device_registry = dr.async_get(hass)
config_entry_id = config_entry.entry_id
device_info = device_info_from_config(config[CONF_DEVICE])
if config_entry_id is not None and device_info is not None:
update_device_info = cast(dict, device_info)
update_device_info["config_entry_id"] = config_entry_id
device = device_registry.async_get_or_create(**update_device_info)
return device.id if device else None
@callback
def async_removed_from_device(
hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str

View File

@ -1,4 +1,4 @@
"""Modesl used by multiple MQTT modules."""
"""Models used by multiple MQTT modules."""
from __future__ import annotations
from collections.abc import Awaitable, Callable

View File

@ -1,40 +1,31 @@
"""Provides tag scanning for MQTT."""
from __future__ import annotations
import functools
import logging
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM, CONF_VALUE_TEMPLATE
from homeassistant.helpers import device_registry as dr
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.typing import ConfigType
from . import MqttValueTemplate, subscription
from .. import mqtt
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_TOPIC,
CONF_QOS,
CONF_TOPIC,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash
from .const import ATTR_DISCOVERY_HASH, CONF_QOS, CONF_TOPIC
from .mixins import (
CONF_CONNECTIONS,
CONF_IDENTIFIERS,
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
async_removed_from_device,
MqttDiscoveryDeviceUpdate,
async_setup_entry_helper,
cleanup_device_registry,
device_info_from_config,
send_discovery_done,
update_device,
)
from .models import ReceiveMessage
from .subscription import EntitySubscription
from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
LOG_NAME = "Tag"
TAG = "tag"
TAGS = "mqtt_tags"
@ -50,33 +41,25 @@ PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
)
async def async_setup_entry(hass, config_entry):
"""Set up MQTT tag scan dynamically through MQTT discovery."""
setup = functools.partial(async_setup_tag, hass, config_entry=config_entry)
await async_setup_entry_helper(hass, "tag", setup, PLATFORM_SCHEMA)
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Set up MQTT device automation dynamically through MQTT discovery."""
setup = functools.partial(_async_setup_tag, hass, config_entry=config_entry)
await async_setup_entry_helper(hass, TAG, setup, PLATFORM_SCHEMA)
async def async_setup_tag(hass, config, config_entry, discovery_data):
async def _async_setup_tag(
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: dict,
) -> None:
"""Set up the MQTT tag scanner."""
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
device_id = None
if CONF_DEVICE in config:
_update_device(hass, config_entry, config)
device_registry = dr.async_get(hass)
device = device_registry.async_get_device(
{(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]},
{tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]},
)
if device is None:
return
device_id = device.id
if TAGS not in hass.data:
hass.data[TAGS] = {}
device_id = update_device(hass, config_entry, config)
hass.data.setdefault(TAGS, {})
if device_id not in hass.data[TAGS]:
hass.data[TAGS][device_id] = {}
@ -88,91 +71,65 @@ async def async_setup_tag(hass, config, config_entry, discovery_data):
config_entry,
)
await tag_scanner.setup()
await tag_scanner.subscribe_topics()
if device_id:
hass.data[TAGS][device_id][discovery_id] = tag_scanner
send_discovery_done(hass, discovery_data)
def async_has_tags(hass, device_id):
def async_has_tags(hass: HomeAssistant, device_id: str) -> bool:
"""Device has tag scanners."""
if TAGS not in hass.data or device_id not in hass.data[TAGS]:
return False
return hass.data[TAGS][device_id] != {}
class MQTTTagScanner:
class MQTTTagScanner(MqttDiscoveryDeviceUpdate):
"""MQTT Tag scanner."""
def __init__(self, hass, config, device_id, discovery_data, config_entry):
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
device_id: str | None,
discovery_data: dict,
config_entry: ConfigEntry,
) -> None:
"""Initialize."""
self._config = config
self._config_entry = config_entry
self.device_id = device_id
self.discovery_data = discovery_data
self.hass = hass
self._remove_discovery = None
self._remove_device_updated = None
self._sub_state = None
self._value_template = None
self._setup_from_config(config)
async def discovery_update(self, payload):
"""Handle discovery update."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
_LOGGER.info(
"Got update for tag scanner with hash: %s '%s'", discovery_hash, payload
)
if not payload:
# Empty payload: Remove tag scanner
_LOGGER.info("Removing tag scanner: %s", discovery_hash)
self.tear_down()
if self.device_id:
await cleanup_device_registry(
self.hass, self.device_id, self._config_entry.entry_id
)
else:
# Non-empty payload: Update tag scanner
_LOGGER.info("Updating tag scanner: %s", discovery_hash)
config = PLATFORM_SCHEMA(payload)
self._config = config
if self.device_id:
_update_device(self.hass, self._config_entry, config)
self._setup_from_config(config)
await self.subscribe_topics()
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
def _setup_from_config(self, config):
self._sub_state: dict[str, EntitySubscription] | None = None
self._value_template = MqttValueTemplate(
config.get(CONF_VALUE_TEMPLATE),
hass=self.hass,
).async_render_with_possible_json_value
async def setup(self):
"""Set up the MQTT tag scanner."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
await self.subscribe_topics()
if self.device_id:
self._remove_device_updated = self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.device_updated
)
self._remove_discovery = async_dispatcher_connect(
self.hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
self.discovery_update,
)
async_dispatcher_send(
self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
MqttDiscoveryDeviceUpdate.__init__(
self, hass, discovery_data, device_id, config_entry, LOG_NAME
)
async def subscribe_topics(self):
async def async_update(self, discovery_data: dict) -> None:
"""Handle MQTT tag discovery updates."""
# Update tag scanner
config = PLATFORM_SCHEMA(discovery_data)
self._config = config
self._value_template = MqttValueTemplate(
config.get(CONF_VALUE_TEMPLATE),
hass=self.hass,
).async_render_with_possible_json_value
update_device(self.hass, self._config_entry, config)
await self.subscribe_topics()
async def subscribe_topics(self) -> None:
"""Subscribe to MQTT topics."""
async def tag_scanned(msg):
@callback
async def tag_scanned(msg: ReceiveMessage) -> None:
tag_id = self._value_template(msg.payload, "").strip()
if not tag_id: # No output from template, ignore
return
@ -195,44 +152,12 @@ class MQTTTagScanner:
)
await subscription.async_subscribe_topics(self.hass, self._sub_state)
async def device_updated(self, event):
"""Handle the update or removal of a device."""
if not async_removed_from_device(
self.hass, event, self.device_id, self._config_entry.entry_id
):
return
# Stop subscribing to discovery updates to not trigger when we clear the
# discovery topic
self.tear_down()
# Clear the discovery topic so the entity is not rediscovered after a restart
discovery_topic = self.discovery_data[ATTR_DISCOVERY_TOPIC]
mqtt.publish(self.hass, discovery_topic, "", retain=True)
def tear_down(self):
async def async_tear_down(self) -> None:
"""Cleanup tag scanner."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
clear_discovery_hash(self.hass, discovery_hash)
if self.device_id:
self._remove_device_updated()
self._remove_discovery()
self._sub_state = subscription.async_unsubscribe_topics(
self.hass, self._sub_state
)
if self.device_id:
self.hass.data[TAGS][self.device_id].pop(discovery_id)
def _update_device(hass, config_entry, config):
"""Update device registry."""
device_registry = dr.async_get(hass)
config_entry_id = config_entry.entry_id
device_info = device_info_from_config(config[CONF_DEVICE])
if config_entry_id is not None and device_info is not None:
device_info["config_entry_id"] = config_entry_id
device_registry.async_get_or_create(**device_info)