From d237180a987ce80a454b2ca1b11353c32888775b Mon Sep 17 00:00:00 2001
From: Jan Bouwhuis
Date: Sat, 26 Oct 2024 07:21:52 +0200
Subject: [PATCH] Allow re-discovery of mqtt integration config payloads
 (#127362)

---
 homeassistant/components/mqtt/discovery.py |  63 ++++++++--
 tests/components/mqtt/test_discovery.py    | 138 ++++++++++++++++-----
 2 files changed, 166 insertions(+), 35 deletions(-)

diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py
index af27615e2c0..bdaf71f8740 100644
--- a/homeassistant/components/mqtt/discovery.py
+++ b/homeassistant/components/mqtt/discovery.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import asyncio
 from collections import deque
+from dataclasses import dataclass
 import functools
 from itertools import chain
 import logging
@@ -11,9 +12,14 @@ import re
 import time
 from typing import TYPE_CHECKING, Any
 
-from homeassistant.config_entries import ConfigEntry
+from homeassistant.config_entries import (
+    SOURCE_MQTT,
+    ConfigEntry,
+    signal_discovered_config_entry_removed,
+)
 from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
 from homeassistant.core import HassJobType, HomeAssistant, callback
+from homeassistant.helpers import discovery_flow
 import homeassistant.helpers.config_validation as cv
 from homeassistant.helpers.dispatcher import (
     async_dispatcher_connect,
@@ -71,6 +77,14 @@ class MQTTDiscoveryPayload(dict[str, Any]):
     discovery_data: DiscoveryInfoType
 
 
+@dataclass(frozen=True)
+class MQTTIntegrationDiscoveryConfig:
+    """Class to hold an integration discovery payload."""
+
+    integration: str
+    msg: ReceiveMessage
+
+
 def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
     """Clear entry from already discovered list."""
     hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash)
@@ -191,7 +205,7 @@ async def async_start(  # noqa: C901
     """Start MQTT Discovery."""
     mqtt_data = hass.data[DATA_MQTT]
     platform_setup_lock: dict[str, asyncio.Lock] = {}
-    integration_discovery_messages: dict[str, int] = {}
+    integration_discovery_messages: dict[str, MQTTIntegrationDiscoveryConfig] = {}
 
     @callback
     def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None:
@@ -364,13 +378,39 @@ async def async_start(  # noqa: C901
     mqtt_integrations = await async_get_mqtt(hass)
     integration_unsubscribe = mqtt_data.integration_unsubscribe
 
+    async def _async_handle_config_entry_removed(entry: ConfigEntry) -> None:
+        """Handle removal of an integration config entry."""
+        for discovery_key in entry.discovery_keys[DOMAIN]:
+            if (
+                discovery_key.version != 1
+                or not isinstance(discovery_key.key, str)
+                or discovery_key.key not in integration_discovery_messages
+            ):
+                continue
+            topic = discovery_key.key
+            discovery_message = integration_discovery_messages[topic]
+            del integration_discovery_messages[topic]
+            _LOGGER.debug("Rediscover service on topic %s", topic)
+            # Initiate re-discovery
+            await async_integration_message_received(
+                discovery_message.integration, discovery_message.msg
+            )
+
+    mqtt_data.discovery_unsubscribe.append(
+        async_dispatcher_connect(
+            hass,
+            signal_discovered_config_entry_removed(DOMAIN),
+            _async_handle_config_entry_removed,
+        )
+    )
+
     async def async_integration_message_received(
         integration: str, msg: ReceiveMessage
     ) -> None:
         """Process the received message."""
         if (
             msg.topic in integration_discovery_messages
+            and integration_discovery_messages[msg.topic].msg.payload == msg.payload
         ):
             _LOGGER.debug(
                 "Ignoring already processed discovery message for '%s' on topic %s: %s",
@@ -393,14 +433,23 @@ async def async_start(  # noqa: C901
             subscribed_topic=msg.subscribed_topic,
             timestamp=msg.timestamp,
         )
-        await hass.config_entries.flow.async_init(
-            integration, context={"source": DOMAIN}, data=data
+        discovery_key = discovery_flow.DiscoveryKey(
+            domain=DOMAIN, key=msg.topic, version=1
+        )
+        discovery_flow.async_create_flow(
+            hass,
+            integration,
+            {"source": SOURCE_MQTT},
+            data,
+            discovery_key=discovery_key,
         )
         if msg.payload:
             # Update the last discovered config message
-            integration_discovery_messages[msg.topic] = hash(msg.payload)
+            integration_discovery_messages[msg.topic] = (
+                MQTTIntegrationDiscoveryConfig(integration=integration, msg=msg)
+            )
         elif msg.topic in integration_discovery_messages:
-            # Cleanup hash if discovery payload is empty
+            # Cleanup cache if discovery payload is empty
             del integration_discovery_messages[msg.topic]
 
     integration_unsubscribe.update(
diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py
index cc7142236d0..6b8feac4e48 100644
--- a/tests/components/mqtt/test_discovery.py
+++ b/tests/components/mqtt/test_discovery.py
@@ -34,7 +34,7 @@ from homeassistant.const import (
     Platform,
 )
 from homeassistant.core import Event, HomeAssistant, callback
-from homeassistant.data_entry_flow import FlowResult
+from homeassistant.data_entry_flow import AbortFlow, FlowResult
 from homeassistant.helpers import device_registry as dr, entity_registry as er
 from homeassistant.helpers.dispatcher import (
     async_dispatcher_connect,
@@ -63,6 +63,53 @@ from tests.typing import (
 )
 
 
+@pytest.fixture
+def mqtt_data_flow_calls() -> list[MqttServiceInfo]:
+    """Return a list to capture MQTT data flow calls."""
+    return []
+
+
+@pytest.fixture
+async def mock_mqtt_flow(
+    hass: HomeAssistant, mqtt_data_flow_calls: list[MqttServiceInfo]
+) -> config_entries.ConfigFlow:
+    """Test fixture for the mqtt integration flow.
+
+    The topic is used as a unique ID.
+    The component test domain used is: `comp`.
+
+    Creates an entry if it does not exist.
+    Updates an entry if it exists and there is an updated payload.
+    """
+
+    class TestFlow(config_entries.ConfigFlow):
+        """Test flow."""
+
+        async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
+            """Test mqtt step."""
+            await asyncio.sleep(0)
+            mqtt_data_flow_calls.append(discovery_info)
+            # Abort the flow if there is an update for an existing entry
+            if entry := self.hass.config_entries.async_entry_for_domain_unique_id(
+                "comp", discovery_info.topic
+            ):
+                hass.config_entries.async_update_entry(
+                    entry,
+                    data={
+                        "name": discovery_info.topic,
+                        "payload": discovery_info.payload,
+                    },
+                )
+                raise AbortFlow("already_configured")
+            await self.async_set_unique_id(discovery_info.topic)
+            return self.async_create_entry(
+                title="Test",
+                data={"name": discovery_info.topic, "payload": discovery_info.payload},
+            )
+
+    return TestFlow
+
+
 @pytest.mark.parametrize(
     "mqtt_config_entry_data",
     [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],
@@ -1518,20 +1565,14 @@ async def test_mqtt_discovery_flow_starts_once(
     hass: HomeAssistant,
     mqtt_client_mock: MqttMockPahoClient,
     caplog: pytest.LogCaptureFixture,
+    mock_mqtt_flow: config_entries.ConfigFlow,
+    mqtt_data_flow_calls: list[MqttServiceInfo],
 ) -> None:
-    """Check MQTT integration discovery starts a flow once."""
-
-    flow_calls: list[MqttServiceInfo] = []
-
-    class TestFlow(config_entries.ConfigFlow):
-        """Test flow."""
-
-        async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
-            """Test mqtt step."""
-            await asyncio.sleep(0)
-            flow_calls.append(discovery_info)
-            return self.async_create_entry(title="Test", data={})
+    """Check MQTT integration discovery starts a flow once.
+    A flow should be started once after discovery,
+    and again after the config entry is removed, to trigger re-discovery.
+    """
 
     mock_integration(
         hass, MockModule(domain="comp", async_setup_entry=AsyncMock(return_value=True))
     )
@@ -1552,7 +1593,7 @@
             "homeassistant.components.mqtt.discovery.async_get_mqtt",
             return_value={"comp": ["comp/discovery/#"]},
         ),
-        mock_config_flow("comp", TestFlow),
+        mock_config_flow("comp", mock_mqtt_flow),
     ):
         assert await hass.config_entries.async_setup(entry.entry_id)
         await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
@@ -1561,41 +1602,82 @@
 
     assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
 
+    # Test the initial flow
     async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "initial message")
     await hass.async_block_till_done(wait_background_tasks=True)
 
-    assert len(flow_calls) == 1
-    assert flow_calls[0].topic == "comp/discovery/bla/config1"
-    assert flow_calls[0].payload == "initial message"
+    assert len(mqtt_data_flow_calls) == 1
+    assert mqtt_data_flow_calls[0].topic == "comp/discovery/bla/config1"
+    assert mqtt_data_flow_calls[0].payload == "initial message"
 
+    # Test we can ignore updates if they are the same
     with caplog.at_level(logging.DEBUG):
         async_fire_mqtt_message(
             hass, "comp/discovery/bla/config1", "initial message"
         )
         await hass.async_block_till_done(wait_background_tasks=True)
     assert "Ignoring already processed discovery message" in caplog.text
-    assert len(flow_calls) == 1
+    assert len(mqtt_data_flow_calls) == 1
 
+    # Test we can apply updates
+    async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "update message")
+    await hass.async_block_till_done(wait_background_tasks=True)
+
+    assert len(mqtt_data_flow_calls) == 2
+    assert mqtt_data_flow_calls[1].topic == "comp/discovery/bla/config1"
+    assert mqtt_data_flow_calls[1].payload == "update message"
+
+    # Test we set up multiple entries
     async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "initial message")
     await hass.async_block_till_done(wait_background_tasks=True)
 
-    assert len(flow_calls) == 2
-    assert flow_calls[1].topic == "comp/discovery/bla/config2"
-    assert flow_calls[1].payload == "initial message"
+    assert len(mqtt_data_flow_calls) == 3
+    assert mqtt_data_flow_calls[2].topic == "comp/discovery/bla/config2"
+    assert mqtt_data_flow_calls[2].payload == "initial message"
 
+    # Test we update multiple entries
     async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "update message")
     await hass.async_block_till_done(wait_background_tasks=True)
 
-    assert len(flow_calls) == 3
-    assert flow_calls[2].topic == "comp/discovery/bla/config2"
-    assert flow_calls[2].payload == "update message"
+    assert len(mqtt_data_flow_calls) == 4
+    assert mqtt_data_flow_calls[3].topic == "comp/discovery/bla/config2"
+    assert mqtt_data_flow_calls[3].payload == "update message"
 
-    # An empty message triggers a flow to allow cleanup
+    # Test an empty message triggers a flow to allow cleanup (if needed)
    async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "")
     await hass.async_block_till_done(wait_background_tasks=True)
 
-    assert len(flow_calls) == 4
-    assert flow_calls[3].topic == "comp/discovery/bla/config2"
-    assert flow_calls[3].payload == ""
+    assert len(mqtt_data_flow_calls) == 5
+    assert mqtt_data_flow_calls[4].topic == "comp/discovery/bla/config2"
+    assert mqtt_data_flow_calls[4].payload == ""
+
+    # Cleanup the second entry
+    assert (
+        entry := hass.config_entries.async_entry_for_domain_unique_id(
+            "comp", "comp/discovery/bla/config2"
+        )
+    ) is not None
+    await hass.config_entries.async_remove(entry.entry_id)
+    assert len(hass.config_entries.async_entries(domain="comp")) == 1
+
+    # Remove the remaining entry and assert this triggers an
+    # automatic re-discovery flow with the latest config
+    assert (
+        entry := hass.config_entries.async_entry_for_domain_unique_id(
+            "comp", "comp/discovery/bla/config1"
+        )
+    ) is not None
+    assert entry.unique_id == "comp/discovery/bla/config1"
+    await hass.config_entries.async_remove(entry.entry_id)
+    assert len(hass.config_entries.async_entries(domain="comp")) == 0
+
+    # Wait for the re-discovery flow to complete
+    await hass.async_block_till_done(wait_background_tasks=True)
+    assert len(mqtt_data_flow_calls) == 6
+    assert mqtt_data_flow_calls[5].topic == "comp/discovery/bla/config1"
+    assert mqtt_data_flow_calls[5].payload == "update message"
+
+    # Re-discovery triggered the config flow
+    assert len(hass.config_entries.async_entries(domain="comp")) == 1
 
     assert not mqtt_client_mock.unsubscribe.called
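
For reviewers, a minimal sketch (not part of the patch) of the kind of integration config flow that consumes MQTT integration discovery and benefits from this change; it mirrors the TestFlow fixture used in the tests above. The "example" domain, the entry data layout, and the MqttServiceInfo import path are illustrative assumptions. Because discovery.py now caches the full discovery message per topic instead of a payload hash, removing a config entry created from that discovery key replays the cached message and starts the flow again with the latest payload.

# Illustrative sketch; the "example" domain and entry data layout are assumptions,
# not part of the patch. Import paths may differ between Home Assistant versions.
from homeassistant.config_entries import ConfigFlow
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo


class ExampleConfigFlow(ConfigFlow, domain="example"):
    """Config flow for an integration discovered over MQTT."""

    async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
        """Handle a flow initialized by MQTT discovery.

        With this patch, the mqtt integration caches the last discovery payload
        per topic, so removing the resulting config entry replays that payload
        and runs this step again.
        """
        await self.async_set_unique_id(discovery_info.topic)
        self._abort_if_unique_id_configured()
        return self.async_create_entry(
            title="Example",
            data={"topic": discovery_info.topic, "payload": discovery_info.payload},
        )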