diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 86eeca2017c..907b1a1dd11 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -76,8 +76,8 @@ from .const import ( # noqa: F401 DEFAULT_QOS, DEFAULT_RETAIN, DOMAIN, + ENTITY_PLATFORMS, MQTT_CONNECTION_STATE, - RELOADABLE_PLATFORMS, TEMPLATE_ERRORS, ) from .models import ( # noqa: F401 @@ -438,7 +438,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: for entity in list(mqtt_platform.entities.values()) if getattr(entity, "_discovery_data", None) is None and mqtt_platform.config_entry - and mqtt_platform.domain in RELOADABLE_PLATFORMS + and mqtt_platform.domain in ENTITY_PLATFORMS ] await asyncio.gather(*tasks) diff --git a/homeassistant/components/mqtt/abbreviations.py b/homeassistant/components/mqtt/abbreviations.py index 3c1d0abdb66..215585f465a 100644 --- a/homeassistant/components/mqtt/abbreviations.py +++ b/homeassistant/components/mqtt/abbreviations.py @@ -30,6 +30,7 @@ ABBREVIATIONS = { "cmd_on_tpl": "command_on_template", "cmd_t": "command_topic", "cmd_tpl": "command_template", + "cmps": "components", "cod_arm_req": "code_arm_required", "cod_dis_req": "code_disarm_required", "cod_form": "code_format", @@ -92,6 +93,7 @@ ABBREVIATIONS = { "min_mirs": "min_mireds", "max_temp": "max_temp", "min_temp": "min_temp", + "migr_discvry": "migrate_discovery", "mode": "mode", "mode_cmd_tpl": "mode_command_template", "mode_cmd_t": "mode_command_topic", @@ -109,6 +111,7 @@ ABBREVIATIONS = { "osc_cmd_tpl": "oscillation_command_template", "osc_stat_t": "oscillation_state_topic", "osc_val_tpl": "oscillation_value_template", + "p": "platform", "pause_cmd_t": "pause_command_topic", "pause_mw_cmd_tpl": "pause_command_template", "pct_cmd_t": "percentage_command_topic", diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 4fa8b7db02a..a626e0e5b28 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -376,7 +376,9 @@ class MQTT: self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict( set ) - self._wildcard_subscriptions: set[Subscription] = set() + # To ensure the wildcard subscriptions order is preserved, we use a dict + # with `None` values instead of a set. + self._wildcard_subscriptions: dict[Subscription, None] = {} # _retained_topics prevents a Subscription from receiving a # retained message more than once per topic. This prevents flooding # already active subscribers when new subscribers subscribe to a topic @@ -754,7 +756,7 @@ class MQTT: if subscription.is_simple_match: self._simple_subscriptions[subscription.topic].add(subscription) else: - self._wildcard_subscriptions.add(subscription) + self._wildcard_subscriptions[subscription] = None @callback def _async_untrack_subscription(self, subscription: Subscription) -> None: @@ -772,7 +774,7 @@ class MQTT: if not simple_subscriptions[topic]: del simple_subscriptions[topic] else: - self._wildcard_subscriptions.remove(subscription) + del self._wildcard_subscriptions[subscription] except (KeyError, ValueError) as exc: raise HomeAssistantError("Can't remove subscription twice") from exc diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index e672e2bac39..9f1c55a54e0 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -90,6 +90,7 @@ CONF_TEMP_MIN = "min_temp" CONF_CERTIFICATE = "certificate" CONF_CLIENT_KEY = "client_key" CONF_CLIENT_CERT = "client_cert" +CONF_COMPONENTS = "components" CONF_TLS_INSECURE = "tls_insecure" # Device and integration info options @@ -159,7 +160,7 @@ MQTT_CONNECTION_STATE = "mqtt_connection_state" PAYLOAD_EMPTY_JSON = "{}" PAYLOAD_NONE = "None" -RELOADABLE_PLATFORMS = [ +ENTITY_PLATFORMS = [ Platform.ALARM_CONTROL_PANEL, Platform.BINARY_SENSOR, Platform.BUTTON, @@ -190,7 +191,7 @@ RELOADABLE_PLATFORMS = [ TEMPLATE_ERRORS = (jinja2.TemplateError, TemplateError, TypeError, ValueError) -SUPPORTED_COMPONENTS = { +SUPPORTED_COMPONENTS = ( "alarm_control_panel", "binary_sensor", "button", @@ -219,4 +220,4 @@ SUPPORTED_COMPONENTS = { "vacuum", "valve", "water_heater", -} +) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index bdaf71f8740..a5ddb3ef4e6 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -12,6 +12,8 @@ import re import time from typing import TYPE_CHECKING, Any +import voluptuous as vol + from homeassistant.config_entries import ( SOURCE_MQTT, ConfigEntry, @@ -25,7 +27,7 @@ from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, ) -from homeassistant.helpers.service_info.mqtt import MqttServiceInfo +from homeassistant.helpers.service_info.mqtt import MqttServiceInfo, ReceivePayloadType from homeassistant.helpers.typing import DiscoveryInfoType from homeassistant.loader import async_get_mqtt from homeassistant.util.json import json_loads_object @@ -38,13 +40,14 @@ from .const import ( ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC, CONF_AVAILABILITY, + CONF_COMPONENTS, CONF_ORIGIN, CONF_TOPIC, DOMAIN, SUPPORTED_COMPONENTS, ) -from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage -from .schemas import MQTT_ORIGIN_INFO_SCHEMA +from .models import DATA_MQTT, MqttComponentConfig, MqttOriginInfo, ReceiveMessage +from .schemas import DEVICE_DISCOVERY_SCHEMA, MQTT_ORIGIN_INFO_SCHEMA, SHARED_OPTIONS from .util import async_forward_entry_setup_and_setup_discovery ABBREVIATIONS_SET = set(ABBREVIATIONS) @@ -70,10 +73,18 @@ MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat( TOPIC_BASE = "~" +CONF_MIGRATE_DISCOVERY = "migrate_discovery" + +MIGRATE_DISCOVERY_SCHEMA = vol.Schema( + {vol.Optional(CONF_MIGRATE_DISCOVERY): True}, +) + class MQTTDiscoveryPayload(dict[str, Any]): """Class to hold and MQTT discovery payload and discovery data.""" + device_discovery: bool = False + migrate_discovery: bool = False discovery_data: DiscoveryInfoType @@ -85,6 +96,24 @@ class MQTTIntegrationDiscoveryConfig: msg: ReceiveMessage +@callback +def _async_process_discovery_migration(payload: MQTTDiscoveryPayload) -> bool: + """Process a discovery migration request in the discovery payload.""" + # Allow abbreviation + if migr_discvry := (payload.pop("migr_discvry", None)): + payload[CONF_MIGRATE_DISCOVERY] = migr_discvry + if CONF_MIGRATE_DISCOVERY in payload: + try: + MIGRATE_DISCOVERY_SCHEMA(payload) + except vol.Invalid as exc: + _LOGGER.warning(exc) + return False + payload.migrate_discovery = True + payload.clear() + return True + return False + + def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None: """Clear entry from already discovered list.""" hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash) @@ -96,36 +125,51 @@ def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> @callback -def async_log_discovery_origin_info( - message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO -) -> None: - """Log information about the discovery and origin.""" - if not _LOGGER.isEnabledFor(level): - # bail early if logging is disabled - return +def get_origin_log_string( + discovery_payload: MQTTDiscoveryPayload, *, include_url: bool +) -> str: + """Get the origin information from a discovery payload for logging.""" if CONF_ORIGIN not in discovery_payload: - _LOGGER.log(level, message) - return + return "" origin_info: MqttOriginInfo = discovery_payload[CONF_ORIGIN] sw_version_log = "" if sw_version := origin_info.get("sw_version"): sw_version_log = f", version: {sw_version}" support_url_log = "" - if support_url := origin_info.get("support_url"): + if include_url and (support_url := get_origin_support_url(discovery_payload)): support_url_log = f", support URL: {support_url}" + return f" from external application {origin_info["name"]}{sw_version_log}{support_url_log}" + + +@callback +def get_origin_support_url(discovery_payload: MQTTDiscoveryPayload) -> str | None: + """Get the origin information support URL from a discovery payload.""" + if CONF_ORIGIN not in discovery_payload: + return "" + origin_info: MqttOriginInfo = discovery_payload[CONF_ORIGIN] + return origin_info.get("support_url") + + +@callback +def async_log_discovery_origin_info( + message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO +) -> None: + """Log information about the discovery and origin.""" + # We only log origin info once per device discovery + if not _LOGGER.isEnabledFor(level): + # bail out early if logging is disabled + return _LOGGER.log( level, - "%s from external application %s%s%s", + "%s%s", message, - origin_info["name"], - sw_version_log, - support_url_log, + get_origin_log_string(discovery_payload, include_url=True), ) @callback def _replace_abbreviations( - payload: Any | dict[str, Any], + payload: dict[str, Any] | str, abbreviations: dict[str, str], abbreviations_set: set[str], ) -> None: @@ -137,11 +181,20 @@ def _replace_abbreviations( @callback -def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None: +def _replace_all_abbreviations( + discovery_payload: dict[str, Any], component_only: bool = False +) -> None: """Replace all abbreviations in an MQTT discovery payload.""" _replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET) + if CONF_AVAILABILITY in discovery_payload: + for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]): + _replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET) + + if component_only: + return + if CONF_ORIGIN in discovery_payload: _replace_abbreviations( discovery_payload[CONF_ORIGIN], @@ -156,13 +209,15 @@ def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None: DEVICE_ABBREVIATIONS_SET, ) - if CONF_AVAILABILITY in discovery_payload: - for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]): - _replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET) + if CONF_COMPONENTS in discovery_payload: + if not isinstance(discovery_payload[CONF_COMPONENTS], dict): + return + for comp_conf in discovery_payload[CONF_COMPONENTS].values(): + _replace_all_abbreviations(comp_conf, component_only=True) @callback -def _replace_topic_base(discovery_payload: dict[str, Any]) -> None: +def _replace_topic_base(discovery_payload: MQTTDiscoveryPayload) -> None: """Replace topic base in MQTT discovery data.""" base = discovery_payload.pop(TOPIC_BASE) for key, value in discovery_payload.items(): @@ -182,6 +237,79 @@ def _replace_topic_base(discovery_payload: dict[str, Any]) -> None: availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}" +@callback +def _generate_device_config( + hass: HomeAssistant, + object_id: str, + node_id: str | None, + migrate_discovery: bool = False, +) -> MQTTDiscoveryPayload: + """Generate a cleanup or discovery migration message on device cleanup. + + If an empty payload, or a migrate discovery request is received for a device, + we forward an empty payload for all previously discovered components. + """ + mqtt_data = hass.data[DATA_MQTT] + device_node_id: str = f"{node_id} {object_id}" if node_id else object_id + config = MQTTDiscoveryPayload({CONF_DEVICE: {}, CONF_COMPONENTS: {}}) + config.migrate_discovery = migrate_discovery + comp_config = config[CONF_COMPONENTS] + for platform, discover_id in mqtt_data.discovery_already_discovered: + ids = discover_id.split(" ") + component_node_id = ids.pop(0) + component_object_id = " ".join(ids) + if not ids: + continue + if device_node_id == component_node_id: + comp_config[component_object_id] = {CONF_PLATFORM: platform} + + return config if comp_config else MQTTDiscoveryPayload({}) + + +@callback +def _parse_device_payload( + hass: HomeAssistant, + payload: ReceivePayloadType, + object_id: str, + node_id: str | None, +) -> MQTTDiscoveryPayload: + """Parse a device discovery payload. + + The device discovery payload is translated info the config payloads for every single + component inside the device based configuration. + An empty payload is translated in a cleanup, which forwards an empty payload to all + removed components. + """ + device_payload = MQTTDiscoveryPayload() + if payload == "": + if not (device_payload := _generate_device_config(hass, object_id, node_id)): + _LOGGER.warning( + "No device components to cleanup for %s, node_id '%s'", + object_id, + node_id, + ) + return device_payload + try: + device_payload = MQTTDiscoveryPayload(json_loads_object(payload)) + except ValueError: + _LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload) + return device_payload + if _async_process_discovery_migration(device_payload): + return _generate_device_config(hass, object_id, node_id, migrate_discovery=True) + _replace_all_abbreviations(device_payload) + try: + DEVICE_DISCOVERY_SCHEMA(device_payload) + except vol.Invalid as exc: + _LOGGER.warning( + "Invalid MQTT device discovery payload for %s, %s: '%s'", + object_id, + exc, + payload, + ) + return MQTTDiscoveryPayload({}) + return device_payload + + @callback def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool: """Parse and validate origin info from a single component discovery payload.""" @@ -199,6 +327,30 @@ def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool: return True +@callback +def _merge_common_device_options( + component_config: MQTTDiscoveryPayload, device_config: dict[str, Any] +) -> None: + """Merge common device options with the component config options. + + Common options are: + CONF_AVAILABILITY, + CONF_AVAILABILITY_MODE, + CONF_AVAILABILITY_TEMPLATE, + CONF_AVAILABILITY_TOPIC, + CONF_COMMAND_TOPIC, + CONF_PAYLOAD_AVAILABLE, + CONF_PAYLOAD_NOT_AVAILABLE, + CONF_STATE_TOPIC, + Common options in the body of the device based config are inherited into + the component. Unless the option is explicitly specified at component level, + in that case the option at component level will override the common option. + """ + for option in SHARED_OPTIONS: + if option in device_config and option not in component_config: + component_config[option] = device_config.get(option) + + async def async_start( # noqa: C901 hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry ) -> None: @@ -243,8 +395,7 @@ async def async_start( # noqa: C901 _LOGGER.warning( ( "Received message on illegal discovery topic '%s'. The topic" - " contains " - "not allowed characters. For more information see " + " contains non allowed characters. For more information see " "https://www.home-assistant.io/integrations/mqtt/#discovery-topic" ), topic, @@ -253,51 +404,118 @@ async def async_start( # noqa: C901 component, node_id, object_id = match.groups() - if payload: + discovered_components: list[MqttComponentConfig] = [] + if component == CONF_DEVICE: + # Process device based discovery message and regenerate + # cleanup config for the all the components that are being removed. + # This is done when a component in the device config is omitted and detected + # as being removed, or when the device config update payload is empty. + # In that case this will regenerate a cleanup message for all every already + # discovered components that were linked to the initial device discovery. + device_discovery_payload = _parse_device_payload( + hass, payload, object_id, node_id + ) + if not device_discovery_payload: + return + device_config: dict[str, Any] + origin_config: dict[str, Any] | None + component_configs: dict[str, dict[str, Any]] + device_config = device_discovery_payload[CONF_DEVICE] + origin_config = device_discovery_payload.get(CONF_ORIGIN) + component_configs = device_discovery_payload[CONF_COMPONENTS] + for component_id, config in component_configs.items(): + component = config.pop(CONF_PLATFORM) + # The object_id in the device discovery topic is the unique identifier. + # It is used as node_id for the components it contains. + component_node_id = object_id + # The component_id in the discovery playload is used as object_id + # If we have an additional node_id in the discovery topic, + # we extend the component_id with it. + component_object_id = ( + f"{node_id} {component_id}" if node_id else component_id + ) + # We add wrapper to the discovery payload with the discovery data. + # If the dict is empty after removing the platform, the payload is + # assumed to remove the existing config and we do not want to add + # device or orig or shared availability attributes. + if discovery_payload := MQTTDiscoveryPayload(config): + discovery_payload[CONF_DEVICE] = device_config + discovery_payload[CONF_ORIGIN] = origin_config + # Only assign shared config options + # when they are not set at entity level + _merge_common_device_options( + discovery_payload, device_discovery_payload + ) + discovery_payload.device_discovery = True + discovery_payload.migrate_discovery = ( + device_discovery_payload.migrate_discovery + ) + discovered_components.append( + MqttComponentConfig( + component, + component_object_id, + component_node_id, + discovery_payload, + ) + ) + _LOGGER.debug( + "Process device discovery payload %s", device_discovery_payload + ) + device_discovery_id = f"{node_id} {object_id}" if node_id else object_id + message = f"Processing device discovery for '{device_discovery_id}'" + async_log_discovery_origin_info( + message, MQTTDiscoveryPayload(device_discovery_payload) + ) + + else: + # Process component based discovery message try: - discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload)) + discovery_payload = MQTTDiscoveryPayload( + json_loads_object(payload) if payload else {} + ) except ValueError: _LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload) return - _replace_all_abbreviations(discovery_payload) - if not _valid_origin_info(discovery_payload): - return + if not _async_process_discovery_migration(discovery_payload): + _replace_all_abbreviations(discovery_payload) + if not _valid_origin_info(discovery_payload): + return + discovered_components.append( + MqttComponentConfig(component, object_id, node_id, discovery_payload) + ) + + discovery_pending_discovered = mqtt_data.discovery_pending_discovered + for component_config in discovered_components: + component = component_config.component + node_id = component_config.node_id + object_id = component_config.object_id + discovery_payload = component_config.discovery_payload + if TOPIC_BASE in discovery_payload: _replace_topic_base(discovery_payload) - else: - discovery_payload = MQTTDiscoveryPayload({}) - # If present, the node_id will be included in the discovered object id - discovery_id = f"{node_id} {object_id}" if node_id else object_id - discovery_hash = (component, discovery_id) + # If present, the node_id will be included in the discovery_id. + discovery_id = f"{node_id} {object_id}" if node_id else object_id + discovery_hash = (component, discovery_id) - if discovery_payload: # Attach MQTT topic to the payload, used for debug prints - setattr( - discovery_payload, - "__configuration_source__", - f"MQTT (topic: '{topic}')", - ) - discovery_data = { + discovery_payload.discovery_data = { ATTR_DISCOVERY_HASH: discovery_hash, ATTR_DISCOVERY_PAYLOAD: discovery_payload, ATTR_DISCOVERY_TOPIC: topic, } - setattr(discovery_payload, "discovery_data", discovery_data) - discovery_payload[CONF_PLATFORM] = "mqtt" + if discovery_hash in discovery_pending_discovered: + pending = discovery_pending_discovered[discovery_hash]["pending"] + pending.appendleft(discovery_payload) + _LOGGER.debug( + "Component has already been discovered: %s %s, queuing update", + component, + discovery_id, + ) + return - if discovery_hash in mqtt_data.discovery_pending_discovered: - pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"] - pending.appendleft(discovery_payload) - _LOGGER.debug( - "Component has already been discovered: %s %s, queuing update", - component, - discovery_id, - ) - return - - async_process_discovery_payload(component, discovery_id, discovery_payload) + async_process_discovery_payload(component, discovery_id, discovery_payload) @callback def async_process_discovery_payload( @@ -305,7 +523,7 @@ async def async_start( # noqa: C901 ) -> None: """Process the payload of a new discovery.""" - _LOGGER.debug("Process discovery payload %s", payload) + _LOGGER.debug("Process component discovery payload %s", payload) discovery_hash = (component, discovery_id) already_discovered = discovery_hash in mqtt_data.discovery_already_discovered @@ -362,6 +580,8 @@ async def async_start( # noqa: C901 0, job_type=HassJobType.Callback, ) + # Subscribe first for platform discovery wildcard topics first, + # and then subscribe device discovery wildcard topics. for topic in chain( ( f"{discovery_topic}/{component}/+/config" @@ -371,6 +591,10 @@ async def async_start( # noqa: C901 f"{discovery_topic}/{component}/+/+/config" for component in SUPPORTED_COMPONENTS ), + ( + f"{discovery_topic}/device/+/config", + f"{discovery_topic}/device/+/+/config", + ), ) ] diff --git a/homeassistant/components/mqtt/entity.py b/homeassistant/components/mqtt/entity.py index c25ecb068ec..46b2c9e1d42 100644 --- a/homeassistant/components/mqtt/entity.py +++ b/homeassistant/components/mqtt/entity.py @@ -104,6 +104,8 @@ from .discovery import ( MQTT_DISCOVERY_UPDATED, MQTTDiscoveryPayload, clear_discovery_hash, + get_origin_log_string, + get_origin_support_url, set_discovery_hash, ) from .models import ( @@ -591,6 +593,7 @@ async def cleanup_device_registry( entity_registry = er.async_get(hass) if ( device_id + and device_id not in device_registry.deleted_devices and config_entry_id and not er.async_entries_for_device( entity_registry, device_id, include_disabled_entities=False @@ -672,6 +675,7 @@ class MqttDiscoveryDeviceUpdateMixin(ABC): self._config_entry = config_entry self._config_entry_id = config_entry.entry_id self._skip_device_removal: bool = False + self._migrate_discovery: str | None = None discovery_hash = get_discovery_hash(discovery_data) self._remove_discovery_updated = async_dispatcher_connect( @@ -704,12 +708,95 @@ class MqttDiscoveryDeviceUpdateMixin(ABC): ) -> None: """Handle discovery update.""" discovery_hash = get_discovery_hash(self._discovery_data) + # Start discovery migration or rollback if migrate_discovery flag is set + # and the discovery topic is valid and not yet migrating + if ( + discovery_payload.migrate_discovery + and self._migrate_discovery is None + and self._discovery_data[ATTR_DISCOVERY_TOPIC] + == discovery_payload.discovery_data[ATTR_DISCOVERY_TOPIC] + ): + self._migrate_discovery = self._discovery_data[ATTR_DISCOVERY_TOPIC] + discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH] + origin_info = get_origin_log_string( + self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False + ) + action = "Rollback" if discovery_payload.device_discovery else "Migration" + schema_type = "platform" if discovery_payload.device_discovery else "device" + _LOGGER.info( + "%s to MQTT %s discovery schema started for %s '%s'" + "%s on topic %s. To complete %s, publish a %s discovery " + "message with %s '%s'. After completed %s, " + "publish an empty (retained) payload to %s", + action, + schema_type, + discovery_hash[0], + discovery_hash[1], + origin_info, + self._migrate_discovery, + action.lower(), + schema_type, + discovery_hash[0], + discovery_hash[1], + action.lower(), + self._migrate_discovery, + ) + + # Cleanup platform resources + await self.async_tear_down() + # Unregister and clean discovery + stop_discovery_updates( + self.hass, self._discovery_data, self._remove_discovery_updated + ) + send_discovery_done(self.hass, self._discovery_data) + return + _LOGGER.debug( "Got update for %s with hash: %s '%s'", self.log_name, discovery_hash, discovery_payload, ) + new_discovery_topic = discovery_payload.discovery_data[ATTR_DISCOVERY_TOPIC] + + # Abort early if an update is not received via the registered discovery topic. + # This can happen if a device and single component discovery payload + # share the same discovery ID. + if self._discovery_data[ATTR_DISCOVERY_TOPIC] != new_discovery_topic: + # Prevent illegal updates + old_origin_info = get_origin_log_string( + self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False + ) + new_origin_info = get_origin_log_string( + discovery_payload.discovery_data[ATTR_DISCOVERY_PAYLOAD], + include_url=False, + ) + new_origin_support_url = get_origin_support_url( + discovery_payload.discovery_data[ATTR_DISCOVERY_PAYLOAD] + ) + if new_origin_support_url: + get_support = f"for support visit {new_origin_support_url}" + else: + get_support = ( + "for documentation on migration to device schema or rollback to " + "discovery schema, visit https://www.home-assistant.io/integrations/" + "mqtt/#migration-from-single-component-to-device-based-discovery" + ) + _LOGGER.warning( + "Received a conflicting MQTT discovery message for %s '%s' which was " + "previously discovered on topic %s%s; the conflicting discovery " + "message was received on topic %s%s; %s", + discovery_hash[0], + discovery_hash[1], + self._discovery_data[ATTR_DISCOVERY_TOPIC], + old_origin_info, + new_discovery_topic, + new_origin_info, + get_support, + ) + send_discovery_done(self.hass, self._discovery_data) + return + if ( discovery_payload and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD] @@ -806,6 +893,7 @@ class MqttDiscoveryUpdateMixin(Entity): mqtt_data = hass.data[DATA_MQTT] self._registry_hooks = mqtt_data.discovery_registry_hooks discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] + self._migrate_discovery: str | None = None if discovery_hash in self._registry_hooks: self._registry_hooks.pop(discovery_hash)() @@ -863,7 +951,12 @@ class MqttDiscoveryUpdateMixin(Entity): if TYPE_CHECKING: assert self._discovery_data self._cleanup_discovery_on_remove() - await self._async_remove_state_and_registry_entry() + if self._migrate_discovery is None: + # Unload and cleanup registry + await self._async_remove_state_and_registry_entry() + else: + # Only unload the entity + await self.async_remove(force_remove=True) send_discovery_done(self.hass, self._discovery_data) @callback @@ -878,18 +971,102 @@ class MqttDiscoveryUpdateMixin(Entity): """ if TYPE_CHECKING: assert self._discovery_data - discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH] + discovery_hash = get_discovery_hash(self._discovery_data) + # Start discovery migration or rollback if migrate_discovery flag is set + # and the discovery topic is valid and not yet migrating + if ( + payload.migrate_discovery + and self._migrate_discovery is None + and self._discovery_data[ATTR_DISCOVERY_TOPIC] + == payload.discovery_data[ATTR_DISCOVERY_TOPIC] + ): + if self.unique_id is None or self.device_info is None: + _LOGGER.error( + "Discovery migration is not possible for " + "for entity %s on topic %s. A unique_id " + "and device context is required, got unique_id: %s, device: %s", + self.entity_id, + self._discovery_data[ATTR_DISCOVERY_TOPIC], + self.unique_id, + self.device_info, + ) + send_discovery_done(self.hass, self._discovery_data) + return + + self._migrate_discovery = self._discovery_data[ATTR_DISCOVERY_TOPIC] + discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH] + origin_info = get_origin_log_string( + self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False + ) + action = "Rollback" if payload.device_discovery else "Migration" + schema_type = "platform" if payload.device_discovery else "device" + _LOGGER.info( + "%s to MQTT %s discovery schema started for entity %s" + "%s on topic %s. To complete %s, publish a %s discovery " + "message with %s entity '%s'. After completed %s, " + "publish an empty (retained) payload to %s", + action, + schema_type, + self.entity_id, + origin_info, + self._migrate_discovery, + action.lower(), + schema_type, + discovery_hash[0], + discovery_hash[1], + action.lower(), + self._migrate_discovery, + ) + old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD] _LOGGER.debug( "Got update for entity with hash: %s '%s'", discovery_hash, payload, ) - old_payload: DiscoveryInfoType - old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD] + new_discovery_topic = payload.discovery_data[ATTR_DISCOVERY_TOPIC] + # Abort early if an update is not received via the registered discovery topic. + # This can happen if a device and single component discovery payload + # share the same discovery ID. + if self._discovery_data[ATTR_DISCOVERY_TOPIC] != new_discovery_topic: + # Prevent illegal updates + old_origin_info = get_origin_log_string( + self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False + ) + new_origin_info = get_origin_log_string( + payload.discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False + ) + new_origin_support_url = get_origin_support_url( + payload.discovery_data[ATTR_DISCOVERY_PAYLOAD] + ) + if new_origin_support_url: + get_support = f"for support visit {new_origin_support_url}" + else: + get_support = ( + "for documentation on migration to device schema or rollback to " + "discovery schema, visit https://www.home-assistant.io/integrations/" + "mqtt/#migration-from-single-component-to-device-based-discovery" + ) + _LOGGER.warning( + "Received a conflicting MQTT discovery message for entity %s; the " + "entity was previously discovered on topic %s%s; the conflicting " + "discovery message was received on topic %s%s; %s", + self.entity_id, + self._discovery_data[ATTR_DISCOVERY_TOPIC], + old_origin_info, + new_discovery_topic, + new_origin_info, + get_support, + ) + send_discovery_done(self.hass, self._discovery_data) + return + debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id) if not payload: # Empty payload: Remove component - _LOGGER.info("Removing component: %s", self.entity_id) + if self._migrate_discovery is None: + _LOGGER.info("Removing component: %s", self.entity_id) + else: + _LOGGER.info("Unloading component: %s", self.entity_id) self.hass.async_create_task( self._async_process_discovery_update_and_remove() ) diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index f7abbc29464..34c1f304944 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -410,5 +410,15 @@ class MqttData: tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict) +@dataclass(slots=True) +class MqttComponentConfig: + """(component, object_id, node_id, discovery_payload).""" + + component: str + object_id: str + node_id: str | None + discovery_payload: MQTTDiscoveryPayload + + DATA_MQTT: HassKey[MqttData] = HassKey("mqtt") DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available") diff --git a/homeassistant/components/mqtt/schemas.py b/homeassistant/components/mqtt/schemas.py index 0badd325dab..5e942c24738 100644 --- a/homeassistant/components/mqtt/schemas.py +++ b/homeassistant/components/mqtt/schemas.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + import voluptuous as vol from homeassistant.const import ( @@ -11,6 +13,7 @@ from homeassistant.const import ( CONF_MODEL, CONF_MODEL_ID, CONF_NAME, + CONF_PLATFORM, CONF_UNIQUE_ID, CONF_VALUE_TEMPLATE, ) @@ -25,10 +28,13 @@ from .const import ( CONF_AVAILABILITY_MODE, CONF_AVAILABILITY_TEMPLATE, CONF_AVAILABILITY_TOPIC, + CONF_COMMAND_TOPIC, + CONF_COMPONENTS, CONF_CONFIGURATION_URL, CONF_CONNECTIONS, CONF_DEPRECATED_VIA_HUB, CONF_ENABLED_BY_DEFAULT, + CONF_ENCODING, CONF_ENTITY_PICTURE, CONF_HW_VERSION, CONF_IDENTIFIERS, @@ -39,7 +45,9 @@ from .const import ( CONF_ORIGIN, CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, + CONF_QOS, CONF_SERIAL_NUMBER, + CONF_STATE_TOPIC, CONF_SUGGESTED_AREA, CONF_SUPPORT_URL, CONF_SW_VERSION, @@ -47,10 +55,34 @@ from .const import ( CONF_VIA_DEVICE, DEFAULT_PAYLOAD_AVAILABLE, DEFAULT_PAYLOAD_NOT_AVAILABLE, + ENTITY_PLATFORMS, + SUPPORTED_COMPONENTS, ) -from .util import valid_subscribe_topic +from .util import valid_publish_topic, valid_qos_schema, valid_subscribe_topic -MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema( +# Device discovery options that are also available at entity component level +SHARED_OPTIONS = [ + CONF_AVAILABILITY, + CONF_AVAILABILITY_MODE, + CONF_AVAILABILITY_TEMPLATE, + CONF_AVAILABILITY_TOPIC, + CONF_COMMAND_TOPIC, + CONF_PAYLOAD_AVAILABLE, + CONF_PAYLOAD_NOT_AVAILABLE, + CONF_STATE_TOPIC, +] + +MQTT_ORIGIN_INFO_SCHEMA = vol.All( + vol.Schema( + { + vol.Required(CONF_NAME): cv.string, + vol.Optional(CONF_SW_VERSION): cv.string, + vol.Optional(CONF_SUPPORT_URL): cv.configuration_url, + } + ), +) + +_MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema( { vol.Exclusive(CONF_AVAILABILITY_TOPIC, "availability"): valid_subscribe_topic, vol.Optional(CONF_AVAILABILITY_TEMPLATE): cv.template, @@ -63,7 +95,7 @@ MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema( } ) -MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema( +_MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema( { vol.Optional(CONF_AVAILABILITY_MODE, default=AVAILABILITY_LATEST): vol.All( cv.string, vol.In(AVAILABILITY_MODES) @@ -87,8 +119,8 @@ MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema( } ) -MQTT_AVAILABILITY_SCHEMA = MQTT_AVAILABILITY_SINGLE_SCHEMA.extend( - MQTT_AVAILABILITY_LIST_SCHEMA.schema +_MQTT_AVAILABILITY_SCHEMA = _MQTT_AVAILABILITY_SINGLE_SCHEMA.extend( + _MQTT_AVAILABILITY_LIST_SCHEMA.schema ) @@ -138,7 +170,7 @@ MQTT_ORIGIN_INFO_SCHEMA = vol.All( ), ) -MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend( +MQTT_ENTITY_COMMON_SCHEMA = _MQTT_AVAILABILITY_SCHEMA.extend( { vol.Optional(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, vol.Optional(CONF_ENTITY_PICTURE): cv.url, @@ -152,3 +184,35 @@ MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend( vol.Optional(CONF_UNIQUE_ID): cv.string, } ) + +_UNIQUE_ID_SCHEMA = vol.Schema( + {vol.Required(CONF_UNIQUE_ID): cv.string}, +).extend({}, extra=True) + + +def check_unique_id(config: dict[str, Any]) -> dict[str, Any]: + """Check if a unique ID is set in case an entity platform is configured.""" + platform = config[CONF_PLATFORM] + if platform in ENTITY_PLATFORMS and len(config.keys()) > 1: + _UNIQUE_ID_SCHEMA(config) + return config + + +_COMPONENT_CONFIG_SCHEMA = vol.All( + vol.Schema( + {vol.Required(CONF_PLATFORM): vol.In(SUPPORTED_COMPONENTS)}, + ).extend({}, extra=True), + check_unique_id, +) + +DEVICE_DISCOVERY_SCHEMA = _MQTT_AVAILABILITY_SCHEMA.extend( + { + vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, + vol.Required(CONF_COMPONENTS): vol.Schema({str: _COMPONENT_CONFIG_SCHEMA}), + vol.Required(CONF_ORIGIN): MQTT_ORIGIN_INFO_SCHEMA, + vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic, + vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_QOS): valid_qos_schema, + vol.Optional(CONF_ENCODING): cv.string, + } +) diff --git a/tests/components/mqtt/conftest.py b/tests/components/mqtt/conftest.py index e22ae297498..22f0416a2c6 100644 --- a/tests/components/mqtt/conftest.py +++ b/tests/components/mqtt/conftest.py @@ -4,7 +4,7 @@ import asyncio from collections.abc import AsyncGenerator, Generator from random import getrandbits from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -122,3 +122,10 @@ def record_calls(recorded_calls: list[ReceiveMessage]) -> MessageCallbackType: recorded_calls.append(msg) return record_calls + + +@pytest.fixture +def tag_mock() -> Generator[AsyncMock]: + """Fixture to mock tag.""" + with patch("homeassistant.components.tag.async_scan_tag") as mock_tag: + yield mock_tag diff --git a/tests/components/mqtt/test_client.py b/tests/components/mqtt/test_client.py index f2af337bc5e..164c164cdfc 100644 --- a/tests/components/mqtt/test_client.py +++ b/tests/components/mqtt/test_client.py @@ -1716,6 +1716,64 @@ async def test_mqtt_subscribes_topics_on_connect( assert ("still/pending", 1) in subscribe_calls +@pytest.mark.parametrize("mqtt_config_entry_data", [ENTRY_DEFAULT_BIRTH_MESSAGE]) +async def test_mqtt_subscribes_wildcard_topics_in_correct_order( + hass: HomeAssistant, + mock_debouncer: asyncio.Event, + setup_with_birth_msg_client_mock: MqttMockPahoClient, + record_calls: MessageCallbackType, +) -> None: + """Test subscription to wildcard topics on connect in the order of subscription.""" + mqtt_client_mock = setup_with_birth_msg_client_mock + + mock_debouncer.clear() + await mqtt.async_subscribe(hass, "integration/test#", record_calls) + await mqtt.async_subscribe(hass, "integration/kitchen_sink#", record_calls) + await mock_debouncer.wait() + + def _assert_subscription_order(): + discovery_subscribes = [ + f"homeassistant/{platform}/+/config" for platform in SUPPORTED_COMPONENTS + ] + discovery_subscribes.extend( + [ + f"homeassistant/{platform}/+/+/config" + for platform in SUPPORTED_COMPONENTS + ] + ) + discovery_subscribes.extend( + ["homeassistant/device/+/config", "homeassistant/device/+/+/config"] + ) + discovery_subscribes.extend(["integration/test#", "integration/kitchen_sink#"]) + + expected_discovery_subscribes = discovery_subscribes.copy() + + # Assert we see the expected subscribes and in the correct order + actual_subscribes = [ + discovery_subscribes.pop(0) + for call in help_all_subscribe_calls(mqtt_client_mock) + if discovery_subscribes and discovery_subscribes[0] == call[0] + ] + + # Assert we have processed all items and that they are in the correct order + assert len(discovery_subscribes) == 0 + assert actual_subscribes == expected_discovery_subscribes + + # Assert the initial wildcard topic subscription order + _assert_subscription_order() + + mqtt_client_mock.on_disconnect(Mock(), None, 0) + + mqtt_client_mock.reset_mock() + + mock_debouncer.clear() + mqtt_client_mock.on_connect(Mock(), None, 0, 0) + await mock_debouncer.wait() + + # Assert the wildcard topic subscription order after a reconnect + _assert_subscription_order() + + @pytest.mark.parametrize( "mqtt_config_entry_data", [ENTRY_DEFAULT_BIRTH_MESSAGE | {mqtt.CONF_DISCOVERY: False}], diff --git a/tests/components/mqtt/test_common.py b/tests/components/mqtt/test_common.py index 82d90f2cee7..95a26daf562 100644 --- a/tests/components/mqtt/test_common.py +++ b/tests/components/mqtt/test_common.py @@ -69,6 +69,7 @@ DEFAULT_CONFIG_DEVICE_INFO_MAC = { _SENTINEL = object() DISCOVERY_COUNT = len(MQTT) +DEVICE_DISCOVERY_COUNT = 2 type _MqttMessageType = list[tuple[str, str]] type _AttributesType = list[tuple[str, Any]] @@ -1189,7 +1190,10 @@ async def help_test_entity_id_update_subscriptions( assert state is not None assert ( mqtt_mock.async_subscribe.call_count - == len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT + == len(topics) + + 2 * len(SUPPORTED_COMPONENTS) + + DISCOVERY_COUNT + + DEVICE_DISCOVERY_COUNT ) for topic in topics: mqtt_mock.async_subscribe.assert_any_call( diff --git a/tests/components/mqtt/test_device_trigger.py b/tests/components/mqtt/test_device_trigger.py index fd2bf46f828..009a0315029 100644 --- a/tests/components/mqtt/test_device_trigger.py +++ b/tests/components/mqtt/test_device_trigger.py @@ -26,22 +26,42 @@ def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None: """Stub copying the blueprints to the config folder.""" +@pytest.mark.parametrize( + ("discovery_topic", "data"), + [ + ( + "homeassistant/device_automation/0AFFD2/bla/config", + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }', + ), + ( + "homeassistant/device/0AFFD2/config", + '{ "device":{"identifiers":["0AFFD2"]},' + ' "o": {"name": "foobar"}, "cmps": ' + '{ "bla": {' + ' "automation_type":"trigger", ' + ' "payload": "short_press",' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1",' + ' "platform":"device_automation"}}}', + ), + ], +) async def test_get_triggers( hass: HomeAssistant, device_registry: dr.DeviceRegistry, mqtt_mock_entry: MqttMockHAClientGenerator, + discovery_topic: str, + data: str, ) -> None: """Test we get the expected triggers from a discovered mqtt device.""" await mqtt_mock_entry() - data1 = ( - '{ "automation_type":"trigger",' - ' "device":{"identifiers":["0AFFD2"]},' - ' "payload": "short_press",' - ' "topic": "foobar/triggers/button1",' - ' "type": "button_short_press",' - ' "subtype": "button_1" }' - ) - async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data1) + async_fire_mqtt_message(hass, discovery_topic, data) await hass.async_block_till_done() device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 6b8feac4e48..e49e7a27c8d 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -6,12 +6,14 @@ import json import logging from pathlib import Path import re -from unittest.mock import AsyncMock, call, patch +from typing import Any +from unittest.mock import ANY, AsyncMock, call, patch import pytest from homeassistant import config_entries from homeassistant.components import mqtt +from homeassistant.components.device_automation import DeviceAutomationType from homeassistant.components.mqtt.abbreviations import ( ABBREVIATIONS, DEVICE_ABBREVIATIONS, @@ -46,12 +48,14 @@ from homeassistant.util.signal_type import SignalTypeFormat from .conftest import ENTRY_DEFAULT_BIRTH_MESSAGE from .test_common import help_all_subscribe_calls, help_test_unload_config_entry +from .test_tag import DEFAULT_TAG_ID, DEFAULT_TAG_SCAN from tests.common import ( MockConfigEntry, MockModule, async_capture_events, async_fire_mqtt_message, + async_get_device_automations, mock_config_flow, mock_integration, mock_platform, @@ -62,6 +66,86 @@ from tests.typing import ( WebSocketGenerator, ) +TEST_SINGLE_CONFIGS = [ + ( + "homeassistant/device_automation/0AFFD2/bla1/config", + { + "device": {"identifiers": ["0AFFD2"], "name": "test_device"}, + "o": {"name": "Foo2Mqtt", "sw": "1.40.2", "url": "https://www.foo2mqtt.io"}, + "automation_type": "trigger", + "payload": "short_press", + "topic": "foobar/triggers/button1", + "type": "button_short_press", + "subtype": "button_1", + }, + ), + ( + "homeassistant/sensor/0AFFD2/bla2/config", + { + "device": {"identifiers": ["0AFFD2"], "name": "test_device"}, + "o": {"name": "Foo2Mqtt", "sw": "1.40.2", "url": "https://www.foo2mqtt.io"}, + "state_topic": "foobar/sensors/bla2/state", + "unique_id": "bla002", + }, + ), + ( + "homeassistant/tag/0AFFD2/bla3/config", + { + "device": {"identifiers": ["0AFFD2"], "name": "test_device"}, + "o": {"name": "Foo2Mqtt", "sw": "1.40.2", "url": "https://www.foo2mqtt.io"}, + "topic": "foobar/tags/bla3/see", + }, + ), +] +TEST_DEVICE_CONFIG = { + "device": {"identifiers": ["0AFFD2"], "name": "test_device"}, + "o": {"name": "Foo2Mqtt", "sw": "1.50.0", "url": "https://www.foo2mqtt.io"}, + "cmps": { + "bla1": { + "platform": "device_automation", + "automation_type": "trigger", + "payload": "short_press", + "topic": "foobar/triggers/button1", + "type": "button_short_press", + "subtype": "button_1", + }, + "bla2": { + "platform": "sensor", + "state_topic": "foobar/sensors/bla2/state", + "unique_id": "bla002", + "name": "mqtt_sensor", + }, + "bla3": { + "platform": "tag", + "topic": "foobar/tags/bla3/see", + }, + }, +} +TEST_DEVICE_DISCOVERY_TOPIC = "homeassistant/device/0AFFD2/config" + + +async def help_check_discovered_items( + hass: HomeAssistant, device_registry: dr.DeviceRegistry, tag_mock: AsyncMock +) -> None: + """Help checking discovered test items are still available.""" + + # Check the device_trigger was discovered + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + assert device_entry is not None + triggers = await async_get_device_automations( + hass, DeviceAutomationType.TRIGGER, device_entry.id + ) + assert len(triggers) == 1 + # Check the sensor was discovered + state = hass.states.get("sensor.test_device_mqtt_sensor") + assert state is not None + + # Check the tag works + async_fire_mqtt_message(hass, "foobar/tags/bla3/see", DEFAULT_TAG_SCAN) + await hass.async_block_till_done() + tag_mock.assert_called_once_with(ANY, DEFAULT_TAG_ID, device_entry.id) + tag_mock.reset_mock() + @pytest.fixture def mqtt_data_flow_calls() -> list[MqttServiceInfo]: @@ -135,6 +219,8 @@ async def test_subscribing_config_topic( [ ("homeassistant/binary_sensor/bla/not_config", False), ("homeassistant/binary_sensor/rörkrökare/config", True), + ("homeassistant/device/bla/not_config", False), + ("homeassistant/device/rörkrökare/config", True), ], ) async def test_invalid_topic( @@ -163,10 +249,15 @@ async def test_invalid_topic( caplog.clear() +@pytest.mark.parametrize( + "discovery_topic", + ["homeassistant/binary_sensor/bla/config", "homeassistant/device/bla/config"], +) async def test_invalid_json( hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator, caplog: pytest.LogCaptureFixture, + discovery_topic: str, ) -> None: """Test sending in invalid JSON.""" await mqtt_mock_entry() @@ -175,9 +266,7 @@ async def test_invalid_json( ) as mock_dispatcher_send: mock_dispatcher_send = AsyncMock(return_value=None) - async_fire_mqtt_message( - hass, "homeassistant/binary_sensor/bla/config", "not json" - ) + async_fire_mqtt_message(hass, discovery_topic, "not json") await hass.async_block_till_done() assert "Unable to parse JSON" in caplog.text assert not mock_dispatcher_send.called @@ -226,6 +315,56 @@ async def test_invalid_config( assert "Error 'expected int for dictionary value @ data['qos']'" in caplog.text +async def test_invalid_device_discovery_config( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test sending in JSON that violates the discovery schema if device or platform key is missing.""" + await mqtt_mock_entry() + async_fire_mqtt_message( + hass, + "homeassistant/device/bla/config", + '{ "o": {"name": "foobar"}, "cmps": ' + '{ "acp1": {"name": "abc", "state_topic": "home/alarm", ' + '"unique_id": "very_unique",' + '"command_topic": "home/alarm/set", ' + '"platform":"alarm_control_panel"}}}', + ) + await hass.async_block_till_done() + assert ( + "Invalid MQTT device discovery payload for bla, " + "required key not provided @ data['device']" in caplog.text + ) + + caplog.clear() + async_fire_mqtt_message( + hass, + "homeassistant/device/bla/config", + '{ "o": {"name": "foobar"}, "dev": {"identifiers": ["ABDE03"]}, ' + '"cmps": { "acp1": {"name": "abc", "state_topic": "home/alarm", ' + '"command_topic": "home/alarm/set" }}}', + ) + await hass.async_block_till_done() + assert ( + "Invalid MQTT device discovery payload for bla, " + "required key not provided @ data['components']['acp1']['platform']" + in caplog.text + ) + + caplog.clear() + async_fire_mqtt_message( + hass, + "homeassistant/device/bla/config", + '{ "o": {"name": "foobar"}, "dev": {"identifiers": ["ABDE03"]}, ' '"cmps": ""}', + ) + await hass.async_block_till_done() + assert ( + "Invalid MQTT device discovery payload for bla, " + "expected a dictionary for dictionary value @ data['components']" in caplog.text + ) + + async def test_only_valid_components( hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator, @@ -268,27 +407,70 @@ async def test_correct_config_discovery( assert ("binary_sensor", "bla") in hass.data["mqtt"].discovery_already_discovered +@pytest.mark.parametrize( + ("discovery_topic", "payloads", "discovery_id"), + [ + ( + "homeassistant/binary_sensor/bla/config", + ( + '{"name":"Beer","state_topic": "test-topic",' + '"unique_id": "very_unique1",' + '"o":{"name":"bla2mqtt","sw":"1.0"},' + '"dev":{"identifiers":["bla"],"name": "bla"}}', + '{"name":"Milk","state_topic": "test-topic",' + '"unique_id": "very_unique1",' + '"o":{"name":"bla2mqtt","sw":"1.1",' + '"url":"https://bla2mqtt.example.com/support"},' + '"dev":{"identifiers":["bla"],"name": "bla"}}', + ), + "bla", + ), + ( + "homeassistant/device/bla/config", + ( + '{"cmps":{"bin_sens1":{"platform":"binary_sensor",' + '"unique_id": "very_unique1",' + '"name":"Beer","state_topic": "test-topic"}},' + '"o":{"name":"bla2mqtt","sw":"1.0"},' + '"dev":{"identifiers":["bla"],"name": "bla"}}', + '{"cmps":{"bin_sens1":{"platform":"binary_sensor",' + '"unique_id": "very_unique1",' + '"name":"Milk","state_topic": "test-topic"}},' + '"o":{"name":"bla2mqtt","sw":"1.1",' + '"url":"https://bla2mqtt.example.com/support"},' + '"dev":{"identifiers":["bla"],"name": "bla"}}', + ), + "bla bin_sens1", + ), + ], +) async def test_discovery_integration_info( hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator, caplog: pytest.LogCaptureFixture, + discovery_topic: str, + payloads: tuple[str, str], + discovery_id: str, ) -> None: - """Test logging discovery of new and updated items.""" + """Test discovery of integration info.""" await mqtt_mock_entry() async_fire_mqtt_message( hass, - "homeassistant/binary_sensor/bla/config", - '{ "name": "Beer", "state_topic": "test-topic", "o": {"name": "bla2mqtt", "sw": "1.0" } }', + discovery_topic, + payloads[0], ) await hass.async_block_till_done() - state = hass.states.get("binary_sensor.beer") + state = hass.states.get("binary_sensor.bla_beer") assert state is not None - assert state.name == "Beer" + assert state.name == "bla Beer" assert ( - "Found new component: binary_sensor bla from external application bla2mqtt, version: 1.0" + "Processing device discovery for 'bla' from external " + "application bla2mqtt, version: 1.0" + in caplog.text + or f"Found new component: binary_sensor {discovery_id} from external application bla2mqtt, version: 1.0" in caplog.text ) caplog.clear() @@ -296,47 +478,635 @@ async def test_discovery_integration_info( # Send an update and add support url async_fire_mqtt_message( hass, - "homeassistant/binary_sensor/bla/config", - '{ "name": "Milk", "state_topic": "test-topic", "o": {"name": "bla2mqtt", "sw": "1.1", "url": "https://bla2mqtt.example.com/support" } }', + discovery_topic, + payloads[1], ) await hass.async_block_till_done() - state = hass.states.get("binary_sensor.beer") + state = hass.states.get("binary_sensor.bla_beer") assert state is not None - assert state.name == "Milk" + assert state.name == "bla Milk" assert ( - "Component has already been discovered: binary_sensor bla, sending update from external application bla2mqtt, version: 1.1, support URL: https://bla2mqtt.example.com/support" + f"Component has already been discovered: binary_sensor {discovery_id}" in caplog.text ) @pytest.mark.parametrize( - "config_message", + ("single_configs", "device_discovery_topic", "device_config"), + [(TEST_SINGLE_CONFIGS, TEST_DEVICE_DISCOVERY_TOPIC, TEST_DEVICE_CONFIG)], +) +async def test_discovery_migration_to_device_base( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, + mqtt_mock_entry: MqttMockHAClientGenerator, + tag_mock: AsyncMock, + caplog: pytest.LogCaptureFixture, + single_configs: list[tuple[str, dict[str, Any]]], + device_discovery_topic: str, + device_config: dict[str, Any], +) -> None: + """Test the migration of single discovery to device discovery.""" + await mqtt_mock_entry() + + # Discovery single config schema + for discovery_topic, config in single_configs: + payload = json.dumps(config) + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + await help_check_discovered_items(hass, device_registry, tag_mock) + + # Try to migrate to device based discovery without migrate_discovery flag + payload = json.dumps(device_config) + async_fire_mqtt_message( + hass, + device_discovery_topic, + payload, + ) + await hass.async_block_till_done() + assert ( + "Received a conflicting MQTT discovery message for device_automation " + "'0AFFD2 bla1' which was previously discovered on topic homeassistant/" + "device_automation/0AFFD2/bla1/config from external application Foo2Mqtt, " + "version: 1.40.2; the conflicting discovery message was received on topic " + "homeassistant/device/0AFFD2/config from external application Foo2Mqtt, " + "version: 1.50.0; for support visit https://www.foo2mqtt.io" in caplog.text + ) + assert ( + "Received a conflicting MQTT discovery message for entity sensor." + "test_device_mqtt_sensor; the entity was previously discovered on topic " + "homeassistant/sensor/0AFFD2/bla2/config from external application Foo2Mqtt, " + "version: 1.40.2; the conflicting discovery message was received on topic " + "homeassistant/device/0AFFD2/config from external application Foo2Mqtt, " + "version: 1.50.0; for support visit https://www.foo2mqtt.io" in caplog.text + ) + assert ( + "Received a conflicting MQTT discovery message for tag '0AFFD2 bla3' which " + "was previously discovered on topic homeassistant/tag/0AFFD2/bla3/config " + "from external application Foo2Mqtt, version: 1.40.2; the conflicting " + "discovery message was received on topic homeassistant/device/0AFFD2/config " + "from external application Foo2Mqtt, version: 1.50.0; for support visit " + "https://www.foo2mqtt.io" in caplog.text + ) + + # Check we still have our mqtt items + await help_check_discovered_items(hass, device_registry, tag_mock) + + # Test Enable discovery migration + # Discovery single config schema + caplog.clear() + for discovery_topic, _ in single_configs: + # migr_discvry is abbreviation for migrate_discovery + payload = json.dumps({"migr_discvry": True}) + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Assert we still have our device entry + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + assert device_entry is not None + # Check our trigger was unloaden + triggers = await async_get_device_automations( + hass, DeviceAutomationType.TRIGGER, device_entry.id + ) + assert len(triggers) == 0 + # Check the sensor was unloaded + state = hass.states.get("sensor.test_device_mqtt_sensor") + assert state is None + # Check the entity registry entry is retained + assert entity_registry.async_is_registered("sensor.test_device_mqtt_sensor") + + assert ( + "Migration to MQTT device discovery schema started for device_automation " + "'0AFFD2 bla1' from external application Foo2Mqtt, version: 1.40.2 on topic " + "homeassistant/device_automation/0AFFD2/bla1/config. To complete migration, " + "publish a device discovery message with device_automation '0AFFD2 bla1'. " + "After completed migration, publish an empty (retained) payload to " + "homeassistant/device_automation/0AFFD2/bla1/config" in caplog.text + ) + assert ( + "Migration to MQTT device discovery schema started for entity sensor." + "test_device_mqtt_sensor from external application Foo2Mqtt, version: 1.40.2 " + "on topic homeassistant/sensor/0AFFD2/bla2/config. To complete migration, " + "publish a device discovery message with sensor entity '0AFFD2 bla2'. After " + "completed migration, publish an empty (retained) payload to " + "homeassistant/sensor/0AFFD2/bla2/config" in caplog.text + ) + + # Migrate to device based discovery + caplog.clear() + payload = json.dumps(device_config) + async_fire_mqtt_message( + hass, + device_discovery_topic, + payload, + ) + await hass.async_block_till_done() + + caplog.clear() + for _ in range(2): + # Test publishing an empty payload twice to the migrated discovery topics + # does not remove the migrated items + for discovery_topic, _ in single_configs: + async_fire_mqtt_message( + hass, + discovery_topic, + "", + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Check we still have our mqtt items after publishing an + # empty payload to the old discovery topics + await help_check_discovered_items(hass, device_registry, tag_mock) + + # Check we cannot accidentally migrate back and remove the items + caplog.clear() + for discovery_topic, config in single_configs: + payload = json.dumps(config) + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + assert ( + "Received a conflicting MQTT discovery message for device_automation " + "'0AFFD2 bla1' which was previously discovered on topic homeassistant/device" + "/0AFFD2/config from external application Foo2Mqtt, version: 1.50.0; the " + "conflicting discovery message was received on topic homeassistant/" + "device_automation/0AFFD2/bla1/config from external application Foo2Mqtt, " + "version: 1.40.2; for support visit https://www.foo2mqtt.io" in caplog.text + ) + assert ( + "Received a conflicting MQTT discovery message for entity sensor." + "test_device_mqtt_sensor; the entity was previously discovered on topic " + "homeassistant/device/0AFFD2/config from external application Foo2Mqtt, " + "version: 1.50.0; the conflicting discovery message was received on topic " + "homeassistant/sensor/0AFFD2/bla2/config from external application Foo2Mqtt, " + "version: 1.40.2; for support visit https://www.foo2mqtt.io" in caplog.text + ) + assert ( + "Received a conflicting MQTT discovery message for tag '0AFFD2 bla3' which was " + "previously discovered on topic homeassistant/device/0AFFD2/config from " + "external application Foo2Mqtt, version: 1.50.0; the conflicting discovery " + "message was received on topic homeassistant/tag/0AFFD2/bla3/config from " + "external application Foo2Mqtt, version: 1.40.2; for support visit " + "https://www.foo2mqtt.io" in caplog.text + ) + + caplog.clear() + for discovery_topic, config in single_configs: + payload = json.dumps(config) + async_fire_mqtt_message( + hass, + discovery_topic, + "", + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Check we still have our mqtt items after publishing an + # empty payload to the old discovery topics + await help_check_discovered_items(hass, device_registry, tag_mock) + + # Check we can remove the config using the new discovery topic + async_fire_mqtt_message( + hass, + device_discovery_topic, + "", + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + # Check the device was removed as all device components were removed + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + assert device_entry is None + await hass.async_block_till_done(wait_background_tasks=True) + + +@pytest.mark.parametrize( + "config", [ - '{ "name": "Beer", "state_topic": "test-topic", "o": "bla2mqtt" }', - '{ "name": "Beer", "state_topic": "test-topic", "o": 2.0 }', - '{ "name": "Beer", "state_topic": "test-topic", "o": null }', - '{ "name": "Beer", "state_topic": "test-topic", "o": {"sw": "bla2mqtt"} }', + {"state_topic": "foobar/sensors/bla2/state", "name": "none_test"}, + { + "state_topic": "foobar/sensors/bla2/state", + "name": "none_test", + "unique_id": "very_unique", + }, + { + "state_topic": "foobar/sensors/bla2/state", + "device": {"identifiers": ["0AFFD2"], "name": "none_test"}, + }, + ], +) +async def test_discovery_migration_unique_id( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + caplog: pytest.LogCaptureFixture, + config: dict[str, Any], +) -> None: + """Test entity has a unique_id and device context when migrating.""" + await mqtt_mock_entry() + + discovery_topic = "homeassistant/sensor/0AFFD2/bla2/config" + + # Discovery with single config schema + payload = json.dumps(config) + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Try discovery migration + payload = json.dumps({"migr_discvry": True}) + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Assert the migration attempt fails + assert "Discovery migration is not possible" in caplog.text + + +@pytest.mark.parametrize( + ("single_configs", "device_discovery_topic", "device_config"), + [(TEST_SINGLE_CONFIGS, TEST_DEVICE_DISCOVERY_TOPIC, TEST_DEVICE_CONFIG)], +) +async def test_discovery_rollback_to_single_base( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, + mqtt_mock_entry: MqttMockHAClientGenerator, + tag_mock: AsyncMock, + caplog: pytest.LogCaptureFixture, + single_configs: list[tuple[str, dict[str, Any]]], + device_discovery_topic: str, + device_config: dict[str, Any], +) -> None: + """Test the rollback of device discovery to a single component discovery.""" + await mqtt_mock_entry() + + # Start device based discovery + # any single component discovery will be migrated + payload = json.dumps(device_config) + async_fire_mqtt_message( + hass, + device_discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + await help_check_discovered_items(hass, device_registry, tag_mock) + + # Migrate to single component discovery + # Test the schema + caplog.clear() + payload = json.dumps({"migrate_discovery": "invalid"}) + async_fire_mqtt_message( + hass, + device_discovery_topic, + payload, + ) + await hass.async_block_till_done() + assert "Invalid MQTT device discovery payload for 0AFFD2" in caplog.text + + # Set the correct migrate_discovery flag in the device payload + # to allow rollback + payload = json.dumps({"migrate_discovery": True}) + async_fire_mqtt_message( + hass, + device_discovery_topic, + payload, + ) + await hass.async_block_till_done() + + # Check the log messages + assert ( + "Rollback to MQTT platform discovery schema started for entity sensor." + "test_device_mqtt_sensor from external application Foo2Mqtt, version: 1.50.0 " + "on topic homeassistant/device/0AFFD2/config. To complete rollback, publish a " + "platform discovery message with sensor entity '0AFFD2 bla2'. After completed " + "rollback, publish an empty (retained) payload to " + "homeassistant/device/0AFFD2/config" in caplog.text + ) + assert ( + "Rollback to MQTT platform discovery schema started for device_automation " + "'0AFFD2 bla1' from external application Foo2Mqtt, version: 1.50.0 on topic " + "homeassistant/device/0AFFD2/config. To complete rollback, publish a platform " + "discovery message with device_automation '0AFFD2 bla1'. After completed " + "rollback, publish an empty (retained) payload to " + "homeassistant/device/0AFFD2/config" in caplog.text + ) + + # Assert we still have our device entry + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + assert device_entry is not None + # Check our trigger was unloaded + triggers = await async_get_device_automations( + hass, DeviceAutomationType.TRIGGER, device_entry.id + ) + assert len(triggers) == 0 + # Check the sensor was unloaded + state = hass.states.get("sensor.test_device_mqtt_sensor") + assert state is None + # Check the entity registry entry is retained + assert entity_registry.async_is_registered("sensor.test_device_mqtt_sensor") + + # Publish the new component based payloads + # to switch back to component based discovery + for discovery_topic, config in single_configs: + payload = json.dumps(config) + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Check we still have our mqtt items + # await help_check_discovered_items(hass, device_registry, tag_mock) + + for _ in range(2): + # Test publishing an empty payload twice to the migrated discovery topic + # does not remove the migrated items + async_fire_mqtt_message( + hass, + device_discovery_topic, + "", + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Check we still have our mqtt items after publishing an + # empty payload to the old discovery topics + await help_check_discovered_items(hass, device_registry, tag_mock) + + # Check we cannot accidentally migrate back and remove the items + payload = json.dumps(device_config) + async_fire_mqtt_message( + hass, + device_discovery_topic, + payload, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Check we still have our mqtt items after publishing an + # empty payload to the old discovery topics + await help_check_discovered_items(hass, device_registry, tag_mock) + + # Check we can remove the the config using the new discovery topics + for discovery_topic, config in single_configs: + payload = json.dumps(config) + async_fire_mqtt_message( + hass, + discovery_topic, + "", + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + # Check the device was removed as all device components were removed + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + assert device_entry is None + + +@pytest.mark.parametrize( + ("discovery_topic", "payload"), + [ + ( + "homeassistant/binary_sensor/bla/config", + '{"state_topic": "test-topic",' + '"name":"bla","unique_id":"very_unique1",' + '"avty": {"topic": "avty-topic"},' + '"o":{"name":"bla2mqtt","sw":"1.0"},' + '"dev":{"identifiers":["bla"],"name":"Beer"}}', + ), + ( + "homeassistant/device/bla/config", + '{"cmps":{"bin_sens1":{"platform":"binary_sensor",' + '"name":"bla","unique_id":"very_unique1",' + '"state_topic": "test-topic"}},' + '"avty": {"topic": "avty-topic"},' + '"o":{"name":"bla2mqtt","sw":"1.0"},' + '"dev":{"identifiers":["bla"],"name":"Beer"}}', + ), + ], + ids=["component", "device"], +) +async def test_discovery_availability( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + discovery_topic: str, + payload: str, +) -> None: + """Test device discovery with shared availability mapping.""" + await mqtt_mock_entry() + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer_bla") + assert state is not None + assert state.name == "Beer bla" + assert state.state == STATE_UNAVAILABLE + + async_fire_mqtt_message( + hass, + "avty-topic", + "online", + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer_bla") + assert state is not None + assert state.state == STATE_UNKNOWN + + async_fire_mqtt_message( + hass, + "test-topic", + "ON", + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer_bla") + assert state is not None + assert state.state == STATE_ON + + +@pytest.mark.parametrize( + ("discovery_topic", "payload"), + [ + ( + "homeassistant/device/bla/config", + '{"cmps":{"bin_sens1":{"platform":"binary_sensor",' + '"unique_id":"very_unique",' + '"avty": {"topic": "avty-topic-component"},' + '"name":"Beer","state_topic": "test-topic"}},' + '"avty": {"topic": "avty-topic-device"},' + '"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}', + ), + ( + "homeassistant/device/bla/config", + '{"cmps":{"bin_sens1":{"platform":"binary_sensor",' + '"unique_id":"very_unique",' + '"availability_topic": "avty-topic-component",' + '"name":"Beer","state_topic": "test-topic"}},' + '"availability_topic": "avty-topic-device",' + '"o":{"name":"bla2mqtt","sw":"1.0"},"dev":{"identifiers":["bla"]}}', + ), + ], + ids=["test1", "test2"], +) +async def test_discovery_component_availability_overridden( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + discovery_topic: str, + payload: str, +) -> None: + """Test device discovery with overridden shared availability mapping.""" + await mqtt_mock_entry() + async_fire_mqtt_message( + hass, + discovery_topic, + payload, + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.none_beer") + assert state is not None + assert state.name == "Beer" + assert state.state == STATE_UNAVAILABLE + + async_fire_mqtt_message( + hass, + "avty-topic-device", + "online", + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.none_beer") + assert state is not None + assert state.state == STATE_UNAVAILABLE + + async_fire_mqtt_message( + hass, + "avty-topic-component", + "online", + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.none_beer") + assert state is not None + assert state.state == STATE_UNKNOWN + + async_fire_mqtt_message( + hass, + "test-topic", + "ON", + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.none_beer") + assert state is not None + assert state.state == STATE_ON + + +@pytest.mark.parametrize( + ("discovery_topic", "config_message", "error_message"), + [ + ( + "homeassistant/binary_sensor/bla/config", + '{ "name": "Beer", "unique_id": "very_unique", ' + '"state_topic": "test-topic", "o": "bla2mqtt" }', + "Unable to parse origin information from discovery message", + ), + ( + "homeassistant/binary_sensor/bla/config", + '{ "name": "Beer", "unique_id": "very_unique", ' + '"state_topic": "test-topic", "o": 2.0 }', + "Unable to parse origin information from discovery message", + ), + ( + "homeassistant/binary_sensor/bla/config", + '{ "name": "Beer", "unique_id": "very_unique", ' + '"state_topic": "test-topic", "o": null }', + "Unable to parse origin information from discovery message", + ), + ( + "homeassistant/binary_sensor/bla/config", + '{ "name": "Beer", "unique_id": "very_unique", ' + '"state_topic": "test-topic", "o": {"sw": "bla2mqtt"} }', + "Unable to parse origin information from discovery message", + ), + ( + "homeassistant/device/bla/config", + '{"dev":{"identifiers":["bs1"]},"cmps":{"bs1":' + '{"platform":"binary_sensor","name":"Beer","unique_id": "very_unique",' + '"state_topic":"test-topic"}},"o": "bla2mqtt"}', + "Invalid MQTT device discovery payload for bla, " + "expected a dictionary for dictionary value @ data['origin']", + ), + ( + "homeassistant/device/bla/config", + '{"dev":{"identifiers":["bs1"]},"cmps":{"bs1":' + '{"platform":"binary_sensor","name":"Beer","unique_id": "very_unique",' + '"state_topic":"test-topic"}},"o": 2.0}', + "Invalid MQTT device discovery payload for bla, " + "expected a dictionary for dictionary value @ data['origin']", + ), + ( + "homeassistant/device/bla/config", + '{"dev":{"identifiers":["bs1"]},"cmps":{"bs1":' + '{"platform":"binary_sensor","name":"Beer","unique_id": "very_unique",' + '"state_topic":"test-topic"}},"o": null}', + "Invalid MQTT device discovery payload for bla, " + "expected a dictionary for dictionary value @ data['origin']", + ), + ( + "homeassistant/device/bla/config", + '{"dev":{"identifiers":["bs1"]},"cmps":{"bs1":' + '{"platform":"binary_sensor","name":"Beer","unique_id": "very_unique",' + '"state_topic":"test-topic"}},"o": {"sw": "bla2mqtt"}}', + "Invalid MQTT device discovery payload for bla, " + "required key not provided @ data['origin']['name']", + ), ], ) async def test_discovery_with_invalid_integration_info( hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator, caplog: pytest.LogCaptureFixture, + discovery_topic: str, config_message: str, + error_message: str, ) -> None: """Test sending in correct JSON.""" await mqtt_mock_entry() - async_fire_mqtt_message( - hass, "homeassistant/binary_sensor/bla/config", config_message - ) + async_fire_mqtt_message(hass, discovery_topic, config_message) await hass.async_block_till_done() - state = hass.states.get("binary_sensor.beer") + state = hass.states.get("binary_sensor.none_beer") assert state is None - assert "Unable to parse origin information from discovery message" in caplog.text + assert error_message in caplog.text async def test_discover_fan( @@ -855,43 +1625,86 @@ async def test_duplicate_removal( assert "Component has already been discovered: binary_sensor bla" not in caplog.text +@pytest.mark.parametrize( + ("discovery_payloads", "entity_ids"), + [ + ( + { + "homeassistant/sensor/sens1/config": "{" + '"device":{"identifiers":["0AFFD2"]},' + '"state_topic": "foobar/sensor1",' + '"unique_id": "unique1",' + '"name": "sensor1"' + "}", + "homeassistant/sensor/sens2/config": "{" + '"device":{"identifiers":["0AFFD2"]},' + '"state_topic": "foobar/sensor2",' + '"unique_id": "unique2",' + '"name": "sensor2"' + "}", + }, + ["sensor.none_sensor1", "sensor.none_sensor2"], + ), + ( + { + "homeassistant/device/bla/config": "{" + '"device":{"identifiers":["0AFFD2"]},' + '"o": {"name": "foobar"},' + '"cmps": {"sens1": {' + '"platform": "sensor",' + '"name": "sensor1",' + '"state_topic": "foobar/sensor1",' + '"unique_id": "unique1"' + '},"sens2": {' + '"platform": "sensor",' + '"name": "sensor2",' + '"state_topic": "foobar/sensor2",' + '"unique_id": "unique2"' + "}}}" + }, + ["sensor.none_sensor1", "sensor.none_sensor2"], + ), + ], +) async def test_cleanup_device_manual( hass: HomeAssistant, + mock_debouncer: asyncio.Event, hass_ws_client: WebSocketGenerator, device_registry: dr.DeviceRegistry, entity_registry: er.EntityRegistry, mqtt_mock_entry: MqttMockHAClientGenerator, + discovery_payloads: dict[str, str], + entity_ids: list[str], ) -> None: """Test discovered device is cleaned up when entry removed from device.""" mqtt_mock = await mqtt_mock_entry() assert await async_setup_component(hass, "config", {}) ws_client = await hass_ws_client(hass) - data = ( - '{ "device":{"identifiers":["0AFFD2"]},' - ' "state_topic": "foobar/sensor",' - ' "unique_id": "unique" }' - ) - - async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) - await hass.async_block_till_done() + mock_debouncer.clear() + for discovery_topic, discovery_payload in discovery_payloads.items(): + async_fire_mqtt_message(hass, discovery_topic, discovery_payload) + await mock_debouncer.wait() # Verify device and registry entries are created device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) assert device_entry is not None - entity_entry = entity_registry.async_get("sensor.none_mqtt_sensor") - assert entity_entry is not None - state = hass.states.get("sensor.none_mqtt_sensor") - assert state is not None + for entity_id in entity_ids: + entity_entry = entity_registry.async_get(entity_id) + assert entity_entry is not None + + state = hass.states.get(entity_id) + assert state is not None # Remove MQTT from the device mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] + mock_debouncer.clear() response = await ws_client.remove_device( device_entry.id, mqtt_config_entry.entry_id ) assert response["success"] - await hass.async_block_till_done() + await mock_debouncer.wait() await hass.async_block_till_done() # Verify device and registry entries are cleared @@ -901,60 +1714,224 @@ async def test_cleanup_device_manual( assert entity_entry is None # Verify state is removed - state = hass.states.get("sensor.none_mqtt_sensor") - assert state is None - await hass.async_block_till_done() + for entity_id in entity_ids: + state = hass.states.get(entity_id) + assert state is None - # Verify retained discovery topic has been cleared - mqtt_mock.async_publish.assert_called_once_with( - "homeassistant/sensor/bla/config", None, 0, True + # Verify retained discovery topics have been cleared + mqtt_mock.async_publish.assert_has_calls( + [call(discovery_topic, None, 0, True) for discovery_topic in discovery_payloads] ) + await hass.async_block_till_done(wait_background_tasks=True) + +@pytest.mark.parametrize( + ("discovery_topic", "discovery_payload", "entity_ids"), + [ + ( + "homeassistant/sensor/bla/config", + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }', + ["sensor.none_mqtt_sensor"], + ), + ( + "homeassistant/device/bla/config", + '{ "device":{"identifiers":["0AFFD2"]},' + ' "o": {"name": "foobar"},' + ' "cmps": {"sens1": {' + ' "platform": "sensor",' + ' "name": "sensor1",' + ' "state_topic": "foobar/sensor1",' + ' "unique_id": "unique1"' + ' },"sens2": {' + ' "platform": "sensor",' + ' "name": "sensor2",' + ' "state_topic": "foobar/sensor2",' + ' "unique_id": "unique2"' + "}}}", + ["sensor.none_sensor1", "sensor.none_sensor2"], + ), + ], +) async def test_cleanup_device_mqtt( hass: HomeAssistant, device_registry: dr.DeviceRegistry, entity_registry: er.EntityRegistry, mqtt_mock_entry: MqttMockHAClientGenerator, + discovery_topic: str, + discovery_payload: str, + entity_ids: list[str], ) -> None: - """Test discvered device is cleaned up when removed through MQTT.""" + """Test discovered device is cleaned up when removed through MQTT.""" mqtt_mock = await mqtt_mock_entry() - data = ( - '{ "device":{"identifiers":["0AFFD2"]},' - ' "state_topic": "foobar/sensor",' - ' "unique_id": "unique" }' - ) - async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) + # set up an existing sensor first + data = ( + '{ "device":{"identifiers":["0AFFD3"]},' + ' "name": "sensor_base",' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique_base" }' + ) + base_discovery_topic = "homeassistant/sensor/bla_base/config" + base_entity_id = "sensor.none_sensor_base" + async_fire_mqtt_message(hass, base_discovery_topic, data) + await hass.async_block_till_done() + + # Verify the base entity has been created and it has a state + base_device_entry = device_registry.async_get_device( + identifiers={("mqtt", "0AFFD3")} + ) + assert base_device_entry is not None + entity_entry = entity_registry.async_get(base_entity_id) + assert entity_entry is not None + state = hass.states.get(base_entity_id) + assert state is not None + + async_fire_mqtt_message(hass, discovery_topic, discovery_payload) await hass.async_block_till_done() # Verify device and registry entries are created device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) assert device_entry is not None - entity_entry = entity_registry.async_get("sensor.none_mqtt_sensor") - assert entity_entry is not None + for entity_id in entity_ids: + entity_entry = entity_registry.async_get(entity_id) + assert entity_entry is not None - state = hass.states.get("sensor.none_mqtt_sensor") - assert state is not None + state = hass.states.get(entity_id) + assert state is not None - async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", "") + async_fire_mqtt_message(hass, discovery_topic, "") await hass.async_block_till_done() await hass.async_block_till_done() # Verify device and registry entries are cleared device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) assert device_entry is None - entity_entry = entity_registry.async_get("sensor.none_mqtt_sensor") - assert entity_entry is None - # Verify state is removed - state = hass.states.get("sensor.none_mqtt_sensor") - assert state is None - await hass.async_block_till_done() + for entity_id in entity_ids: + entity_entry = entity_registry.async_get(entity_id) + assert entity_entry is None + + # Verify state is removed + state = hass.states.get(entity_id) + assert state is None + await hass.async_block_till_done() # Verify retained discovery topics have not been cleared again mqtt_mock.async_publish.assert_not_called() + # Verify the base entity still exists and it has a state + base_device_entry = device_registry.async_get_device( + identifiers={("mqtt", "0AFFD3")} + ) + assert base_device_entry is not None + entity_entry = entity_registry.async_get(base_entity_id) + assert entity_entry is not None + state = hass.states.get(base_entity_id) + assert state is not None + + +async def test_cleanup_device_mqtt_device_discovery( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, + mqtt_mock_entry: MqttMockHAClientGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test discovered device is cleaned up partly when removed through MQTT.""" + await mqtt_mock_entry() + + discovery_topic = "homeassistant/device/bla/config" + discovery_payload = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "o": {"name": "foobar"},' + ' "cmps": {"sens1": {' + ' "p": "sensor",' + ' "name": "sensor1",' + ' "state_topic": "foobar/sensor1",' + ' "unique_id": "unique1"' + ' },"sens2": {' + ' "p": "sensor",' + ' "name": "sensor2",' + ' "state_topic": "foobar/sensor2",' + ' "unique_id": "unique2"' + "}}}" + ) + entity_ids = ["sensor.none_sensor1", "sensor.none_sensor2"] + async_fire_mqtt_message(hass, discovery_topic, discovery_payload) + await hass.async_block_till_done() + + # Verify device and registry entries are created + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + assert device_entry is not None + for entity_id in entity_ids: + entity_entry = entity_registry.async_get(entity_id) + assert entity_entry is not None + + state = hass.states.get(entity_id) + assert state is not None + + # Do update and remove sensor 2 from device + discovery_payload_update1 = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "o": {"name": "foobar"},' + ' "cmps": {"sens1": {' + ' "p": "sensor",' + ' "name": "sensor1",' + ' "state_topic": "foobar/sensor1",' + ' "unique_id": "unique1"' + ' },"sens2": {' + ' "p": "sensor"' + "}}}" + ) + async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update1) + await hass.async_block_till_done() + state = hass.states.get(entity_ids[0]) + assert state is not None + state = hass.states.get(entity_ids[1]) + assert state is None + + # Repeating the update + async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update1) + await hass.async_block_till_done() + state = hass.states.get(entity_ids[0]) + assert state is not None + state = hass.states.get(entity_ids[1]) + assert state is None + + # Removing last sensor + discovery_payload_update2 = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "o": {"name": "foobar"},' + ' "cmps": {"sens1": {' + ' "p": "sensor"' + ' },"sens2": {' + ' "p": "sensor"' + "}}}" + ) + async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update2) + await hass.async_block_till_done() + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + # Verify the device entry was removed with the last sensor + assert device_entry is None + for entity_id in entity_ids: + entity_entry = entity_registry.async_get(entity_id) + assert entity_entry is None + + state = hass.states.get(entity_id) + assert state is None + + # Repeating the update + async_fire_mqtt_message(hass, discovery_topic, discovery_payload_update2) + await hass.async_block_till_done() + + # Clear the empty discovery payload and verify there was nothing to cleanup + async_fire_mqtt_message(hass, discovery_topic, "") + await hass.async_block_till_done() + assert "No device components to cleanup" in caplog.text + async def test_cleanup_device_multiple_config_entries( hass: HomeAssistant, @@ -1936,3 +2913,77 @@ async def test_discovery_dispatcher_signal_type_messages( assert len(calls) == 1 assert calls[0] == test_data unsub() + + +@pytest.mark.parametrize( + ("discovery_topic", "discovery_payload", "entity_ids"), + [ + ( + "homeassistant/device/bla/config", + '{ "device":{"identifiers":["0AFFD2"]},' + ' "o": {"name": "foobar"},' + ' "state_topic": "foobar/sensor-shared",' + ' "cmps": {"sens1": {' + ' "platform": "sensor",' + ' "name": "sensor1",' + ' "unique_id": "unique1"' + ' },"sens2": {' + ' "platform": "sensor",' + ' "name": "sensor2",' + ' "unique_id": "unique2"' + ' },"sens3": {' + ' "platform": "sensor",' + ' "name": "sensor3",' + ' "state_topic": "foobar/sensor3",' + ' "unique_id": "unique3"' + "}}}", + ["sensor.none_sensor1", "sensor.none_sensor2", "sensor.none_sensor3"], + ), + ], +) +async def test_shared_state_topic( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, + mqtt_mock_entry: MqttMockHAClientGenerator, + discovery_topic: str, + discovery_payload: str, + entity_ids: list[str], +) -> None: + """Test a shared state_topic can be used.""" + await mqtt_mock_entry() + + async_fire_mqtt_message(hass, discovery_topic, discovery_payload) + await hass.async_block_till_done() + + # Verify device and registry entries are created + device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")}) + assert device_entry is not None + for entity_id in entity_ids: + entity_entry = entity_registry.async_get(entity_id) + assert entity_entry is not None + + state = hass.states.get(entity_id) + assert state is not None + assert state.state == STATE_UNKNOWN + + async_fire_mqtt_message(hass, "foobar/sensor-shared", "New state") + + entity_id = entity_ids[0] + state = hass.states.get(entity_id) + assert state is not None + assert state.state == "New state" + entity_id = entity_ids[1] + state = hass.states.get(entity_id) + assert state is not None + assert state.state == "New state" + entity_id = entity_ids[2] + state = hass.states.get(entity_id) + assert state is not None + assert state.state == STATE_UNKNOWN + + async_fire_mqtt_message(hass, "foobar/sensor3", "New state3") + entity_id = entity_ids[2] + state = hass.states.get(entity_id) + assert state is not None + assert state.state == "New state3" diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 396d3477bad..145016751e7 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1197,7 +1197,6 @@ async def test_mqtt_ws_get_device_debug_info( } data_sensor = json.dumps(config_sensor) data_trigger = json.dumps(config_trigger) - config_sensor["platform"] = config_trigger["platform"] = mqtt.DOMAIN async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data_sensor) async_fire_mqtt_message( @@ -1254,7 +1253,6 @@ async def test_mqtt_ws_get_device_debug_info_binary( "unique_id": "unique", } data = json.dumps(config) - config["platform"] = mqtt.DOMAIN async_fire_mqtt_message(hass, "homeassistant/camera/bla/config", data) await hass.async_block_till_done() diff --git a/tests/components/mqtt/test_tag.py b/tests/components/mqtt/test_tag.py index ff407d29e1e..41c417fe3e9 100644 --- a/tests/components/mqtt/test_tag.py +++ b/tests/components/mqtt/test_tag.py @@ -1,10 +1,9 @@ """The tests for MQTT tag scanner.""" -from collections.abc import Generator import copy import json from typing import Any -from unittest.mock import ANY, AsyncMock, patch +from unittest.mock import ANY, AsyncMock import pytest @@ -47,13 +46,6 @@ DEFAULT_TAG_SCAN_JSON = ( ) -@pytest.fixture -def tag_mock() -> Generator[AsyncMock]: - """Fixture to mock tag.""" - with patch("homeassistant.components.tag.async_scan_tag") as mock_tag: - yield mock_tag - - @pytest.mark.no_fail_on_log_exception async def test_discover_bad_tag( hass: HomeAssistant,