Start mqtt integration discovery config flow only once if config has not changed (#126966)

* Start mqtt integration config flow only once

* Remember last config message

* Filter out instead of unsubscribing the intehration discovery topic

* Follow up comments from code review
This commit is contained in:
Jan Bouwhuis 2024-09-30 12:44:40 +02:00 committed by GitHub
parent e8fd97e355
commit f99b7d8b78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 87 additions and 52 deletions

View File

@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Any
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HassJobType, HomeAssistant, callback from homeassistant.core import HassJobType, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResultType
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
@ -192,6 +191,7 @@ async def async_start( # noqa: C901
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_data = hass.data[DATA_MQTT] mqtt_data = hass.data[DATA_MQTT]
platform_setup_lock: dict[str, asyncio.Lock] = {} platform_setup_lock: dict[str, asyncio.Lock] = {}
integration_discovery_messages: dict[str, int] = {}
@callback @callback
def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None: def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None:
@ -368,17 +368,23 @@ async def async_start( # noqa: C901
integration: str, msg: ReceiveMessage integration: str, msg: ReceiveMessage
) -> None: ) -> None:
"""Process the received message.""" """Process the received message."""
if (
msg.topic in integration_discovery_messages
and integration_discovery_messages[msg.topic] == hash(msg.payload)
):
_LOGGER.debug(
"Ignoring already processed discovery message for '%s' on topic %s: %s",
integration,
msg.topic,
msg.payload,
)
return
if TYPE_CHECKING: if TYPE_CHECKING:
assert mqtt_data.data_config_flow_lock assert mqtt_data.data_config_flow_lock
key = f"{integration}_{msg.subscribed_topic}"
# Lock to prevent initiating many parallel config flows. # Lock to prevent initiating many parallel config flows.
# Note: The lock is not intended to prevent a race, only for performance # Note: The lock is not intended to prevent a race, only for performance
async with mqtt_data.data_config_flow_lock: async with mqtt_data.data_config_flow_lock:
# Already unsubscribed
if key not in integration_unsubscribe:
return
data = MqttServiceInfo( data = MqttServiceInfo(
topic=msg.topic, topic=msg.topic,
payload=msg.payload, payload=msg.payload,
@ -387,15 +393,15 @@ async def async_start( # noqa: C901
subscribed_topic=msg.subscribed_topic, subscribed_topic=msg.subscribed_topic,
timestamp=msg.timestamp, timestamp=msg.timestamp,
) )
result = await hass.config_entries.flow.async_init( await hass.config_entries.flow.async_init(
integration, context={"source": DOMAIN}, data=data integration, context={"source": DOMAIN}, data=data
) )
if ( if msg.payload:
result # Update the last discovered config message
and result["type"] == FlowResultType.ABORT integration_discovery_messages[msg.topic] = hash(msg.payload)
and result["reason"] == "single_instance_allowed" elif msg.topic in integration_discovery_messages:
): # Cleanup hash if discovery payload is empty
integration_unsubscribe.pop(key)() del integration_discovery_messages[msg.topic]
integration_unsubscribe.update( integration_unsubscribe.update(
{ {

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
import copy import copy
import json import json
import logging
from pathlib import Path from pathlib import Path
import re import re
from unittest.mock import AsyncMock, call, patch from unittest.mock import AsyncMock, call, patch
@ -48,9 +49,11 @@ from .test_common import help_all_subscribe_calls, help_test_unload_config_entry
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
MockModule,
async_capture_events, async_capture_events,
async_fire_mqtt_message, async_fire_mqtt_message,
mock_config_flow, mock_config_flow,
mock_integration,
mock_platform, mock_platform,
) )
from tests.typing import ( from tests.typing import (
@ -1445,26 +1448,20 @@ async def test_complex_discovery_topic_prefix(
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("reason", "unsubscribes"), "reason", ["single_instance_allowed", "already_configured", "some_abort_error"]
[
("single_instance_allowed", True),
("already_configured", False),
("some_abort_error", False),
],
) )
async def test_mqtt_integration_discovery_subscribe_unsubscribe( async def test_mqtt_integration_discovery_flow_fitering_on_redundant_payload(
hass: HomeAssistant, hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient, reason: str
mqtt_client_mock: MqttMockPahoClient,
reason: str,
unsubscribes: bool,
) -> None: ) -> None:
"""Check MQTT integration discovery subscribe and unsubscribe.""" """Check MQTT integration discovery starts a flow once."""
flow_calls: list[MqttServiceInfo] = []
class TestFlow(config_entries.ConfigFlow): class TestFlow(config_entries.ConfigFlow):
"""Test flow.""" """Test flow."""
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult: async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step.""" """Test mqtt step."""
flow_calls.append(discovery_info)
return self.async_abort(reason=reason) return self.async_abort(reason=reason)
mock_platform(hass, "comp.config_flow", None) mock_platform(hass, "comp.config_flow", None)
@ -1493,30 +1490,38 @@ async def test_mqtt_integration_discovery_subscribe_unsubscribe(
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
mqtt_client_mock.reset_mock() mqtt_client_mock.reset_mock()
assert len(flow_calls) == 0
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config", "initial message")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 1
assert ( # A redundant message gets does not start a new flow
unsubscribes
and call(["comp/discovery/#"]) in mqtt_client_mock.unsubscribe.mock_calls
or not unsubscribes
and call(["comp/discovery/#"])
not in mqtt_client_mock.unsubscribe.mock_calls
)
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "initial message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 1
# An updated message gets starts a new flow
await hass.async_block_till_done(wait_background_tasks=True)
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "update message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 2
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.0)
async def test_mqtt_discovery_unsubscribe_once( async def test_mqtt_discovery_flow_starts_once(
hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Check MQTT integration discovery unsubscribe once.""" """Check MQTT integration discovery starts a flow once."""
flow_calls: list[MqttServiceInfo] = []
class TestFlow(config_entries.ConfigFlow): class TestFlow(config_entries.ConfigFlow):
"""Test flow.""" """Test flow."""
@ -1524,8 +1529,12 @@ async def test_mqtt_discovery_unsubscribe_once(
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult: async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step.""" """Test mqtt step."""
await asyncio.sleep(0) await asyncio.sleep(0)
return self.async_abort(reason="single_instance_allowed") flow_calls.append(discovery_info)
return self.async_create_entry(title="Test", data={})
mock_integration(
hass, MockModule(domain="comp", async_setup_entry=AsyncMock(return_value=True))
)
mock_platform(hass, "comp.config_flow", None) mock_platform(hass, "comp.config_flow", None)
birth = asyncio.Event() birth = asyncio.Event()
@ -1535,13 +1544,6 @@ async def test_mqtt_discovery_unsubscribe_once(
"""Handle birth message.""" """Handle birth message."""
birth.set() birth.set()
wait_unsub = asyncio.Event()
@callback
def _mock_unsubscribe(topics: list[str]) -> tuple[int, int]:
wait_unsub.set()
return (0, 0)
entry = MockConfigEntry(domain=mqtt.DOMAIN, data=ENTRY_DEFAULT_BIRTH_MESSAGE) entry = MockConfigEntry(domain=mqtt.DOMAIN, data=ENTRY_DEFAULT_BIRTH_MESSAGE)
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -1551,7 +1553,6 @@ async def test_mqtt_discovery_unsubscribe_once(
return_value={"comp": ["comp/discovery/#"]}, return_value={"comp": ["comp/discovery/#"]},
), ),
mock_config_flow("comp", TestFlow), mock_config_flow("comp", TestFlow),
patch.object(mqtt_client_mock, "unsubscribe", side_effect=_mock_unsubscribe),
): ):
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
@ -1559,17 +1560,45 @@ async def test_mqtt_discovery_unsubscribe_once(
await birth.wait() await birth.wait()
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
assert not mqtt_client_mock.unsubscribe.called
async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "initial message")
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") assert len(flow_calls) == 1
async_fire_mqtt_message(hass, "comp/discovery/bla/config", "") assert flow_calls[0].topic == "comp/discovery/bla/config1"
await wait_unsub.wait() assert flow_calls[0].payload == "initial message"
await asyncio.sleep(0)
with caplog.at_level(logging.DEBUG):
async_fire_mqtt_message(
hass, "comp/discovery/bla/config1", "initial message"
)
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
mqtt_client_mock.unsubscribe.assert_called_once_with(["comp/discovery/#"]) assert "Ignoring already processed discovery message" in caplog.text
assert len(flow_calls) == 1
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "initial message")
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 2
assert flow_calls[1].topic == "comp/discovery/bla/config2"
assert flow_calls[1].payload == "initial message"
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "update message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 3
assert flow_calls[2].topic == "comp/discovery/bla/config2"
assert flow_calls[2].payload == "update message"
# An empty message triggers a flow to allow cleanup
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 4
assert flow_calls[3].topic == "comp/discovery/bla/config2"
assert flow_calls[3].payload == ""
assert not mqtt_client_mock.unsubscribe.called
async def test_clear_config_topic_disabled_entity( async def test_clear_config_topic_disabled_entity(
hass: HomeAssistant, hass: HomeAssistant,