Allow re-discovery of mqtt integration config payloads (#127362)

This commit is contained in:
Jan Bouwhuis 2024-10-26 07:21:52 +02:00 committed by GitHub
parent d8b618f7c3
commit d237180a98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 166 additions and 35 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from dataclasses import dataclass
import functools import functools
from itertools import chain from itertools import chain
import logging import logging
@ -11,9 +12,14 @@ import re
import time import time
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import (
SOURCE_MQTT,
ConfigEntry,
signal_discovered_config_entry_removed,
)
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HassJobType, HomeAssistant, callback from homeassistant.core import HassJobType, HomeAssistant, callback
from homeassistant.helpers import discovery_flow
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
@ -71,6 +77,14 @@ class MQTTDiscoveryPayload(dict[str, Any]):
discovery_data: DiscoveryInfoType discovery_data: DiscoveryInfoType
@dataclass(frozen=True)
class MQTTIntegrationDiscoveryConfig:
"""Class to hold an integration discovery playload."""
integration: str
msg: ReceiveMessage
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None: def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry from already discovered list.""" """Clear entry from already discovered list."""
hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash) hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash)
@ -191,7 +205,7 @@ async def async_start( # noqa: C901
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_data = hass.data[DATA_MQTT] mqtt_data = hass.data[DATA_MQTT]
platform_setup_lock: dict[str, asyncio.Lock] = {} platform_setup_lock: dict[str, asyncio.Lock] = {}
integration_discovery_messages: dict[str, int] = {} integration_discovery_messages: dict[str, MQTTIntegrationDiscoveryConfig] = {}
@callback @callback
def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None: def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None:
@ -364,13 +378,39 @@ async def async_start( # noqa: C901
mqtt_integrations = await async_get_mqtt(hass) mqtt_integrations = await async_get_mqtt(hass)
integration_unsubscribe = mqtt_data.integration_unsubscribe integration_unsubscribe = mqtt_data.integration_unsubscribe
async def _async_handle_config_entry_removed(entry: ConfigEntry) -> None:
"""Handle integration config entry changes."""
for discovery_key in entry.discovery_keys[DOMAIN]:
if (
discovery_key.version != 1
or not isinstance(discovery_key.key, str)
or discovery_key.key not in integration_discovery_messages
):
continue
topic = discovery_key.key
discovery_message = integration_discovery_messages[topic]
del integration_discovery_messages[topic]
_LOGGER.debug("Rediscover service on topic %s", topic)
# Initiate re-discovery
await async_integration_message_received(
discovery_message.integration, discovery_message.msg
)
mqtt_data.discovery_unsubscribe.append(
async_dispatcher_connect(
hass,
signal_discovered_config_entry_removed(DOMAIN),
_async_handle_config_entry_removed,
)
)
async def async_integration_message_received( async def async_integration_message_received(
integration: str, msg: ReceiveMessage integration: str, msg: ReceiveMessage
) -> None: ) -> None:
"""Process the received message.""" """Process the received message."""
if ( if (
msg.topic in integration_discovery_messages msg.topic in integration_discovery_messages
and integration_discovery_messages[msg.topic] == hash(msg.payload) and integration_discovery_messages[msg.topic].msg.payload == msg.payload
): ):
_LOGGER.debug( _LOGGER.debug(
"Ignoring already processed discovery message for '%s' on topic %s: %s", "Ignoring already processed discovery message for '%s' on topic %s: %s",
@ -393,14 +433,23 @@ async def async_start( # noqa: C901
subscribed_topic=msg.subscribed_topic, subscribed_topic=msg.subscribed_topic,
timestamp=msg.timestamp, timestamp=msg.timestamp,
) )
await hass.config_entries.flow.async_init( discovery_key = discovery_flow.DiscoveryKey(
integration, context={"source": DOMAIN}, data=data domain=DOMAIN, key=msg.topic, version=1
)
discovery_flow.async_create_flow(
hass,
integration,
{"source": SOURCE_MQTT},
data,
discovery_key=discovery_key,
) )
if msg.payload: if msg.payload:
# Update the last discovered config message # Update the last discovered config message
integration_discovery_messages[msg.topic] = hash(msg.payload) integration_discovery_messages[msg.topic] = (
MQTTIntegrationDiscoveryConfig(integration=integration, msg=msg)
)
elif msg.topic in integration_discovery_messages: elif msg.topic in integration_discovery_messages:
# Cleanup hash if discovery payload is empty # Cleanup cache if discovery payload is empty
del integration_discovery_messages[msg.topic] del integration_discovery_messages[msg.topic]
integration_unsubscribe.update( integration_unsubscribe.update(

View File

@ -34,7 +34,7 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
@ -63,6 +63,53 @@ from tests.typing import (
) )
@pytest.fixture
def mqtt_data_flow_calls() -> list[MqttServiceInfo]:
"""Return list to capture MQTT data data flow calls."""
return []
@pytest.fixture
async def mock_mqtt_flow(
hass: HomeAssistant, mqtt_data_flow_calls: list[MqttServiceInfo]
) -> config_entries.ConfigFlow:
"""Test fixure for mqtt integration flow.
The topic is used as a unique ID.
The component test domain used is: `comp`.
Creates an entry if does not exist.
Updates an entry if it exists, and there is an updated payload.
"""
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step."""
await asyncio.sleep(0)
mqtt_data_flow_calls.append(discovery_info)
# Abort a flow if there is an update for the existing entry
if entry := self.hass.config_entries.async_entry_for_domain_unique_id(
"comp", discovery_info.topic
):
hass.config_entries.async_update_entry(
entry,
data={
"name": discovery_info.topic,
"payload": discovery_info.payload,
},
)
raise AbortFlow("already_configured")
await self.async_set_unique_id(discovery_info.topic)
return self.async_create_entry(
title="Test",
data={"name": discovery_info.topic, "payload": discovery_info.payload},
)
return TestFlow
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mqtt_config_entry_data", "mqtt_config_entry_data",
[{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}], [{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],
@ -1518,20 +1565,14 @@ async def test_mqtt_discovery_flow_starts_once(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
mock_mqtt_flow: config_entries.ConfigFlow,
mqtt_data_flow_calls: list[MqttServiceInfo],
) -> None: ) -> None:
"""Check MQTT integration discovery starts a flow once.""" """Check MQTT integration discovery starts a flow once.
flow_calls: list[MqttServiceInfo] = []
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step."""
await asyncio.sleep(0)
flow_calls.append(discovery_info)
return self.async_create_entry(title="Test", data={})
A flow should be started once after discovery,
and after an entry was removed, to trigger re-discovery.
"""
mock_integration( mock_integration(
hass, MockModule(domain="comp", async_setup_entry=AsyncMock(return_value=True)) hass, MockModule(domain="comp", async_setup_entry=AsyncMock(return_value=True))
) )
@ -1552,7 +1593,7 @@ async def test_mqtt_discovery_flow_starts_once(
"homeassistant.components.mqtt.discovery.async_get_mqtt", "homeassistant.components.mqtt.discovery.async_get_mqtt",
return_value={"comp": ["comp/discovery/#"]}, return_value={"comp": ["comp/discovery/#"]},
), ),
mock_config_flow("comp", TestFlow), mock_config_flow("comp", mock_mqtt_flow),
): ):
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
@ -1561,41 +1602,82 @@ async def test_mqtt_discovery_flow_starts_once(
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock) assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
# Test the initial flow
async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "initial message") async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "initial message")
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 1 assert len(mqtt_data_flow_calls) == 1
assert flow_calls[0].topic == "comp/discovery/bla/config1" assert mqtt_data_flow_calls[0].topic == "comp/discovery/bla/config1"
assert flow_calls[0].payload == "initial message" assert mqtt_data_flow_calls[0].payload == "initial message"
# Test we can ignore updates if they are the same
with caplog.at_level(logging.DEBUG): with caplog.at_level(logging.DEBUG):
async_fire_mqtt_message( async_fire_mqtt_message(
hass, "comp/discovery/bla/config1", "initial message" hass, "comp/discovery/bla/config1", "initial message"
) )
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert "Ignoring already processed discovery message" in caplog.text assert "Ignoring already processed discovery message" in caplog.text
assert len(flow_calls) == 1 assert len(mqtt_data_flow_calls) == 1
# Test we can apply updates
async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "update message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(mqtt_data_flow_calls) == 2
assert mqtt_data_flow_calls[1].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[1].payload == "update message"
# Test we set up multiple entries
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "initial message") async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "initial message")
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 2 assert len(mqtt_data_flow_calls) == 3
assert flow_calls[1].topic == "comp/discovery/bla/config2" assert mqtt_data_flow_calls[2].topic == "comp/discovery/bla/config2"
assert flow_calls[1].payload == "initial message" assert mqtt_data_flow_calls[2].payload == "initial message"
# Test we update multiple entries
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "update message") async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "update message")
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 3 assert len(mqtt_data_flow_calls) == 4
assert flow_calls[2].topic == "comp/discovery/bla/config2" assert mqtt_data_flow_calls[3].topic == "comp/discovery/bla/config2"
assert flow_calls[2].payload == "update message" assert mqtt_data_flow_calls[3].payload == "update message"
# An empty message triggers a flow to allow cleanup # Test an empty message triggers a flow to allow cleanup (if needed)
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "") async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "")
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 4 assert len(mqtt_data_flow_calls) == 5
assert flow_calls[3].topic == "comp/discovery/bla/config2" assert mqtt_data_flow_calls[4].topic == "comp/discovery/bla/config2"
assert flow_calls[3].payload == "" assert mqtt_data_flow_calls[4].payload == ""
# Cleanup the the second entry
assert (
entry := hass.config_entries.async_entry_for_domain_unique_id(
"comp", "comp/discovery/bla/config2"
)
) is not None
await hass.config_entries.async_remove(entry.entry_id)
assert len(hass.config_entries.async_entries(domain="comp")) == 1
# Remove remaining entry1 and assert this triggers an
# automatic re-discovery flow with latest config
assert (
entry := hass.config_entries.async_entry_for_domain_unique_id(
"comp", "comp/discovery/bla/config1"
)
) is not None
assert entry.unique_id == "comp/discovery/bla/config1"
await hass.config_entries.async_remove(entry.entry_id)
assert len(hass.config_entries.async_entries(domain="comp")) == 0
# Wait for re-discovery flow to complete
await hass.async_block_till_done(wait_background_tasks=True)
assert len(mqtt_data_flow_calls) == 6
assert mqtt_data_flow_calls[5].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[5].payload == "update message"
# Re-discovery triggered the config flow
assert len(hass.config_entries.async_entries(domain="comp")) == 1
assert not mqtt_client_mock.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called