From f99b7d8b78025889c9b2fd48a6ba5580dbc3c767 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Mon, 30 Sep 2024 12:44:40 +0200 Subject: [PATCH] Start mqtt integration discovery config flow only once if config has not changed (#126966) * Start mqtt integration config flow only once * Remember last config message * Filter out instead of unsubscribing the intehration discovery topic * Follow up comments from code review --- homeassistant/components/mqtt/discovery.py | 32 +++--- tests/components/mqtt/test_discovery.py | 107 +++++++++++++-------- 2 files changed, 87 insertions(+), 52 deletions(-) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index e2a726e2915..af27615e2c0 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Any from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.core import HassJobType, HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResultType import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, @@ -192,6 +191,7 @@ async def async_start( # noqa: C901 """Start MQTT Discovery.""" mqtt_data = hass.data[DATA_MQTT] platform_setup_lock: dict[str, asyncio.Lock] = {} + integration_discovery_messages: dict[str, int] = {} @callback def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None: @@ -368,17 +368,23 @@ async def async_start( # noqa: C901 integration: str, msg: ReceiveMessage ) -> None: """Process the received message.""" + if ( + msg.topic in integration_discovery_messages + and integration_discovery_messages[msg.topic] == hash(msg.payload) + ): + _LOGGER.debug( + "Ignoring already processed discovery message for '%s' on topic %s: %s", + integration, + msg.topic, + msg.payload, + ) + return if TYPE_CHECKING: assert mqtt_data.data_config_flow_lock - key = f"{integration}_{msg.subscribed_topic}" # Lock to prevent initiating many parallel config flows. # Note: The lock is not intended to prevent a race, only for performance async with mqtt_data.data_config_flow_lock: - # Already unsubscribed - if key not in integration_unsubscribe: - return - data = MqttServiceInfo( topic=msg.topic, payload=msg.payload, @@ -387,15 +393,15 @@ async def async_start( # noqa: C901 subscribed_topic=msg.subscribed_topic, timestamp=msg.timestamp, ) - result = await hass.config_entries.flow.async_init( + await hass.config_entries.flow.async_init( integration, context={"source": DOMAIN}, data=data ) - if ( - result - and result["type"] == FlowResultType.ABORT - and result["reason"] == "single_instance_allowed" - ): - integration_unsubscribe.pop(key)() + if msg.payload: + # Update the last discovered config message + integration_discovery_messages[msg.topic] = hash(msg.payload) + elif msg.topic in integration_discovery_messages: + # Cleanup hash if discovery payload is empty + del integration_discovery_messages[msg.topic] integration_unsubscribe.update( { diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 2f83c1138b9..cc7142236d0 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -3,6 +3,7 @@ import asyncio import copy import json +import logging from pathlib import Path import re from unittest.mock import AsyncMock, call, patch @@ -48,9 +49,11 @@ from .test_common import help_all_subscribe_calls, help_test_unload_config_entry from tests.common import ( MockConfigEntry, + MockModule, async_capture_events, async_fire_mqtt_message, mock_config_flow, + mock_integration, mock_platform, ) from tests.typing import ( @@ -1445,26 +1448,20 @@ async def test_complex_discovery_topic_prefix( @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0) @pytest.mark.parametrize( - ("reason", "unsubscribes"), - [ - ("single_instance_allowed", True), - ("already_configured", False), - ("some_abort_error", False), - ], + "reason", ["single_instance_allowed", "already_configured", "some_abort_error"] ) -async def test_mqtt_integration_discovery_subscribe_unsubscribe( - hass: HomeAssistant, - mqtt_client_mock: MqttMockPahoClient, - reason: str, - unsubscribes: bool, +async def test_mqtt_integration_discovery_flow_fitering_on_redundant_payload( + hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, reason: str ) -> None: - """Check MQTT integration discovery subscribe and unsubscribe.""" + """Check MQTT integration discovery starts a flow once.""" + flow_calls: list[MqttServiceInfo] = [] class TestFlow(config_entries.ConfigFlow): """Test flow.""" async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult: """Test mqtt step.""" + flow_calls.append(discovery_info) return self.async_abort(reason=reason) mock_platform(hass, "comp.config_flow", None) @@ -1493,30 +1490,38 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe( assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert not mqtt_client_mock.unsubscribe.called mqtt_client_mock.reset_mock() + assert len(flow_calls) == 0 await hass.async_block_till_done(wait_background_tasks=True) - async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") - await hass.async_block_till_done() + async_fire_mqtt_message(hass, "comp/discovery/bla/config", "initial message") await hass.async_block_till_done(wait_background_tasks=True) + assert len(flow_calls) == 1 - assert ( - unsubscribes - and call(["comp/discovery/#"]) in mqtt_client_mock.unsubscribe.mock_calls - or not unsubscribes - and call(["comp/discovery/#"]) - not in mqtt_client_mock.unsubscribe.mock_calls - ) + # A redundant message gets does not start a new flow await hass.async_block_till_done(wait_background_tasks=True) + async_fire_mqtt_message(hass, "comp/discovery/bla/config", "initial message") + await hass.async_block_till_done(wait_background_tasks=True) + assert len(flow_calls) == 1 + + # An updated message gets starts a new flow + await hass.async_block_till_done(wait_background_tasks=True) + async_fire_mqtt_message(hass, "comp/discovery/bla/config", "update message") + await hass.async_block_till_done(wait_background_tasks=True) + assert len(flow_calls) == 2 @patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0) -async def test_mqtt_discovery_unsubscribe_once( - hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient +async def test_mqtt_discovery_flow_starts_once( + hass: HomeAssistant, + mqtt_client_mock: MqttMockPahoClient, + caplog: pytest.LogCaptureFixture, ) -> None: - """Check MQTT integration discovery unsubscribe once.""" + """Check MQTT integration discovery starts a flow once.""" + + flow_calls: list[MqttServiceInfo] = [] class TestFlow(config_entries.ConfigFlow): """Test flow.""" @@ -1524,8 +1529,12 @@ async def test_mqtt_discovery_unsubscribe_once( async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult: """Test mqtt step.""" await asyncio.sleep(0) - return self.async_abort(reason="single_instance_allowed") + flow_calls.append(discovery_info) + return self.async_create_entry(title="Test", data={}) + mock_integration( + hass, MockModule(domain="comp", async_setup_entry=AsyncMock(return_value=True)) + ) mock_platform(hass, "comp.config_flow", None) birth = asyncio.Event() @@ -1535,13 +1544,6 @@ async def test_mqtt_discovery_unsubscribe_once( """Handle birth message.""" birth.set() - wait_unsub = asyncio.Event() - - @callback - def _mock_unsubscribe(topics: list[str]) -> tuple[int, int]: - wait_unsub.set() - return (0, 0) - entry = MockConfigEntry(domain=mqtt.DOMAIN, data=ENTRY_DEFAULT_BIRTH_MESSAGE) entry.add_to_hass(hass) @@ -1551,7 +1553,6 @@ async def test_mqtt_discovery_unsubscribe_once( return_value={"comp": ["comp/discovery/#"]}, ), mock_config_flow("comp", TestFlow), - patch.object(mqtt_client_mock, "unsubscribe", side_effect=_mock_unsubscribe), ): assert await hass.config_entries.async_setup(entry.entry_id) await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) @@ -1559,17 +1560,45 @@ async def test_mqtt_discovery_unsubscribe_once( await birth.wait() assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) - assert not mqtt_client_mock.unsubscribe.called + async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "initial message") await hass.async_block_till_done(wait_background_tasks=True) - async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") - async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") - await wait_unsub.wait() - await asyncio.sleep(0) + assert len(flow_calls) == 1 + assert flow_calls[0].topic == "comp/discovery/bla/config1" + assert flow_calls[0].payload == "initial message" + + with caplog.at_level(logging.DEBUG): + async_fire_mqtt_message( + hass, "comp/discovery/bla/config1", "initial message" + ) + await hass.async_block_till_done(wait_background_tasks=True) + assert "Ignoring already processed discovery message" in caplog.text + assert len(flow_calls) == 1 + + async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "initial message") await hass.async_block_till_done(wait_background_tasks=True) - mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"]) + + assert len(flow_calls) == 2 + assert flow_calls[1].topic == "comp/discovery/bla/config2" + assert flow_calls[1].payload == "initial message" + + async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "update message") await hass.async_block_till_done(wait_background_tasks=True) + assert len(flow_calls) == 3 + assert flow_calls[2].topic == "comp/discovery/bla/config2" + assert flow_calls[2].payload == "update message" + + # An empty message triggers a flow to allow cleanup + async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "") + await hass.async_block_till_done(wait_background_tasks=True) + + assert len(flow_calls) == 4 + assert flow_calls[3].topic == "comp/discovery/bla/config2" + assert flow_calls[3].payload == "" + + assert not mqtt_client_mock.unsubscribe.called + async def test_clear_config_topic_disabled_entity( hass: HomeAssistant,