Allow re-discovery of mqtt integration config payloads (#127362)

This commit is contained in:
Jan Bouwhuis 2024-10-26 07:21:52 +02:00 committed by GitHub
parent d8b618f7c3
commit d237180a98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 166 additions and 35 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections import deque
from dataclasses import dataclass
import functools
from itertools import chain
import logging
@ -11,9 +12,14 @@ import re
import time
from typing import TYPE_CHECKING, Any
from homeassistant.config_entries import ConfigEntry
from homeassistant.config_entries import (
SOURCE_MQTT,
ConfigEntry,
signal_discovered_config_entry_removed,
)
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HassJobType, HomeAssistant, callback
from homeassistant.helpers import discovery_flow
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
@ -71,6 +77,14 @@ class MQTTDiscoveryPayload(dict[str, Any]):
discovery_data: DiscoveryInfoType
@dataclass(frozen=True)
class MQTTIntegrationDiscoveryConfig:
"""Class to hold an integration discovery playload."""
integration: str
msg: ReceiveMessage
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry from already discovered list."""
hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash)
@ -191,7 +205,7 @@ async def async_start( # noqa: C901
"""Start MQTT Discovery."""
mqtt_data = hass.data[DATA_MQTT]
platform_setup_lock: dict[str, asyncio.Lock] = {}
integration_discovery_messages: dict[str, int] = {}
integration_discovery_messages: dict[str, MQTTIntegrationDiscoveryConfig] = {}
@callback
def _async_add_component(discovery_payload: MQTTDiscoveryPayload) -> None:
@ -364,13 +378,39 @@ async def async_start( # noqa: C901
mqtt_integrations = await async_get_mqtt(hass)
integration_unsubscribe = mqtt_data.integration_unsubscribe
async def _async_handle_config_entry_removed(entry: ConfigEntry) -> None:
"""Handle integration config entry changes."""
for discovery_key in entry.discovery_keys[DOMAIN]:
if (
discovery_key.version != 1
or not isinstance(discovery_key.key, str)
or discovery_key.key not in integration_discovery_messages
):
continue
topic = discovery_key.key
discovery_message = integration_discovery_messages[topic]
del integration_discovery_messages[topic]
_LOGGER.debug("Rediscover service on topic %s", topic)
# Initiate re-discovery
await async_integration_message_received(
discovery_message.integration, discovery_message.msg
)
mqtt_data.discovery_unsubscribe.append(
async_dispatcher_connect(
hass,
signal_discovered_config_entry_removed(DOMAIN),
_async_handle_config_entry_removed,
)
)
async def async_integration_message_received(
integration: str, msg: ReceiveMessage
) -> None:
"""Process the received message."""
if (
msg.topic in integration_discovery_messages
and integration_discovery_messages[msg.topic] == hash(msg.payload)
and integration_discovery_messages[msg.topic].msg.payload == msg.payload
):
_LOGGER.debug(
"Ignoring already processed discovery message for '%s' on topic %s: %s",
@ -393,14 +433,23 @@ async def async_start( # noqa: C901
subscribed_topic=msg.subscribed_topic,
timestamp=msg.timestamp,
)
await hass.config_entries.flow.async_init(
integration, context={"source": DOMAIN}, data=data
discovery_key = discovery_flow.DiscoveryKey(
domain=DOMAIN, key=msg.topic, version=1
)
discovery_flow.async_create_flow(
hass,
integration,
{"source": SOURCE_MQTT},
data,
discovery_key=discovery_key,
)
if msg.payload:
# Update the last discovered config message
integration_discovery_messages[msg.topic] = hash(msg.payload)
integration_discovery_messages[msg.topic] = (
MQTTIntegrationDiscoveryConfig(integration=integration, msg=msg)
)
elif msg.topic in integration_discovery_messages:
# Cleanup hash if discovery payload is empty
# Cleanup cache if discovery payload is empty
del integration_discovery_messages[msg.topic]
integration_unsubscribe.update(

View File

@ -34,7 +34,7 @@ from homeassistant.const import (
Platform,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
@ -63,6 +63,53 @@ from tests.typing import (
)
@pytest.fixture
def mqtt_data_flow_calls() -> list[MqttServiceInfo]:
"""Return list to capture MQTT data data flow calls."""
return []
@pytest.fixture
async def mock_mqtt_flow(
hass: HomeAssistant, mqtt_data_flow_calls: list[MqttServiceInfo]
) -> config_entries.ConfigFlow:
"""Test fixure for mqtt integration flow.
The topic is used as a unique ID.
The component test domain used is: `comp`.
Creates an entry if does not exist.
Updates an entry if it exists, and there is an updated payload.
"""
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step."""
await asyncio.sleep(0)
mqtt_data_flow_calls.append(discovery_info)
# Abort a flow if there is an update for the existing entry
if entry := self.hass.config_entries.async_entry_for_domain_unique_id(
"comp", discovery_info.topic
):
hass.config_entries.async_update_entry(
entry,
data={
"name": discovery_info.topic,
"payload": discovery_info.payload,
},
)
raise AbortFlow("already_configured")
await self.async_set_unique_id(discovery_info.topic)
return self.async_create_entry(
title="Test",
data={"name": discovery_info.topic, "payload": discovery_info.payload},
)
return TestFlow
@pytest.mark.parametrize(
"mqtt_config_entry_data",
[{mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_DISCOVERY: False}],
@ -1518,20 +1565,14 @@ async def test_mqtt_discovery_flow_starts_once(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture,
mock_mqtt_flow: config_entries.ConfigFlow,
mqtt_data_flow_calls: list[MqttServiceInfo],
) -> None:
"""Check MQTT integration discovery starts a flow once."""
flow_calls: list[MqttServiceInfo] = []
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
"""Test mqtt step."""
await asyncio.sleep(0)
flow_calls.append(discovery_info)
return self.async_create_entry(title="Test", data={})
"""Check MQTT integration discovery starts a flow once.
A flow should be started once after discovery,
and after an entry was removed, to trigger re-discovery.
"""
mock_integration(
hass, MockModule(domain="comp", async_setup_entry=AsyncMock(return_value=True))
)
@ -1552,7 +1593,7 @@ async def test_mqtt_discovery_flow_starts_once(
"homeassistant.components.mqtt.discovery.async_get_mqtt",
return_value={"comp": ["comp/discovery/#"]},
),
mock_config_flow("comp", TestFlow),
mock_config_flow("comp", mock_mqtt_flow),
):
assert await hass.config_entries.async_setup(entry.entry_id)
await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth)
@ -1561,41 +1602,82 @@ async def test_mqtt_discovery_flow_starts_once(
assert ("comp/discovery/#", 0) in help_all_subscribe_calls(mqtt_client_mock)
# Test the initial flow
async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "initial message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 1
assert flow_calls[0].topic == "comp/discovery/bla/config1"
assert flow_calls[0].payload == "initial message"
assert len(mqtt_data_flow_calls) == 1
assert mqtt_data_flow_calls[0].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[0].payload == "initial message"
# Test we can ignore updates if they are the same
with caplog.at_level(logging.DEBUG):
async_fire_mqtt_message(
hass, "comp/discovery/bla/config1", "initial message"
)
await hass.async_block_till_done(wait_background_tasks=True)
assert "Ignoring already processed discovery message" in caplog.text
assert len(flow_calls) == 1
assert len(mqtt_data_flow_calls) == 1
# Test we can apply updates
async_fire_mqtt_message(hass, "comp/discovery/bla/config1", "update message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(mqtt_data_flow_calls) == 2
assert mqtt_data_flow_calls[1].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[1].payload == "update message"
# Test we set up multiple entries
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "initial message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 2
assert flow_calls[1].topic == "comp/discovery/bla/config2"
assert flow_calls[1].payload == "initial message"
assert len(mqtt_data_flow_calls) == 3
assert mqtt_data_flow_calls[2].topic == "comp/discovery/bla/config2"
assert mqtt_data_flow_calls[2].payload == "initial message"
# Test we update multiple entries
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "update message")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 3
assert flow_calls[2].topic == "comp/discovery/bla/config2"
assert flow_calls[2].payload == "update message"
assert len(mqtt_data_flow_calls) == 4
assert mqtt_data_flow_calls[3].topic == "comp/discovery/bla/config2"
assert mqtt_data_flow_calls[3].payload == "update message"
# An empty message triggers a flow to allow cleanup
# Test an empty message triggers a flow to allow cleanup (if needed)
async_fire_mqtt_message(hass, "comp/discovery/bla/config2", "")
await hass.async_block_till_done(wait_background_tasks=True)
assert len(flow_calls) == 4
assert flow_calls[3].topic == "comp/discovery/bla/config2"
assert flow_calls[3].payload == ""
assert len(mqtt_data_flow_calls) == 5
assert mqtt_data_flow_calls[4].topic == "comp/discovery/bla/config2"
assert mqtt_data_flow_calls[4].payload == ""
# Cleanup the the second entry
assert (
entry := hass.config_entries.async_entry_for_domain_unique_id(
"comp", "comp/discovery/bla/config2"
)
) is not None
await hass.config_entries.async_remove(entry.entry_id)
assert len(hass.config_entries.async_entries(domain="comp")) == 1
# Remove remaining entry1 and assert this triggers an
# automatic re-discovery flow with latest config
assert (
entry := hass.config_entries.async_entry_for_domain_unique_id(
"comp", "comp/discovery/bla/config1"
)
) is not None
assert entry.unique_id == "comp/discovery/bla/config1"
await hass.config_entries.async_remove(entry.entry_id)
assert len(hass.config_entries.async_entries(domain="comp")) == 0
# Wait for re-discovery flow to complete
await hass.async_block_till_done(wait_background_tasks=True)
assert len(mqtt_data_flow_calls) == 6
assert mqtt_data_flow_calls[5].topic == "comp/discovery/bla/config1"
assert mqtt_data_flow_calls[5].payload == "update message"
# Re-discovery triggered the config flow
assert len(hass.config_entries.async_entries(domain="comp")) == 1
assert not mqtt_client_mock.unsubscribe.called