Compare commits

...

14 Commits

Author SHA1 Message Date
Jan Bouwhuis 1e97e49cba Merge branch 'dev' into mqtt-subscribe-identifier 2026-05-12 18:18:34 +02:00
jbouwh 406a6dda9b remove local import 2026-05-12 10:00:20 +00:00
Jan Bouwhuis 7633cc7252 Merge branch 'dev' into mqtt-subscribe-identifier 2026-05-11 15:52:29 +02:00
jbouwh 5169b6554a Follow up on code review 2026-05-05 13:06:08 +00:00
Jan Bouwhuis 2d9cf87e44 Merge branch 'dev' into mqtt-subscribe-identifier 2026-05-05 14:37:06 +02:00
jbouwh b5f04ef502 Break up long line 2026-05-03 12:37:01 +00:00
jbouwh d073d4fe4a Cache subscriptions for topic and subscription_id 2026-05-03 12:33:37 +00:00
jbouwh 2e6d8e3aea Set subscription ID for restored subscriptiions 2026-05-02 20:46:45 +00:00
jbouwh 0d9f5a32f4 Fix packet type 2026-05-02 20:28:41 +00:00
jbouwh feea4925cd Improve testcase labels 2026-05-02 20:14:33 +00:00
Jan Bouwhuis 3be267a94c Potential fix for pull request finding
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-02 22:10:21 +02:00
jbouwh 9da3788c52 Expand unsubscribe race tests with wildcard subscriptions and multiple protocols 2026-05-02 13:31:18 +00:00
Jan Bouwhuis cb6a38b10f Merge branch 'dev' into mqtt-subscribe-identifier 2026-05-02 14:29:12 +02:00
jbouwh 03e22e1cb2 Set subscription identifier to allow filtering duplicate payloads with overlapping subscriptions 2026-05-02 11:40:36 +00:00
5 changed files with 304 additions and 32 deletions
+90 -21
View File
@@ -110,7 +110,6 @@ TIMEOUT_ACK = 10
SUBSCRIBE_TIMEOUT = 10
RECONNECT_INTERVAL_SECONDS = 10
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
MAX_SUBSCRIBES_PER_CALL = 500
MAX_UNSUBSCRIBES_PER_CALL = 500
@@ -330,8 +329,9 @@ class Subscription:
is_simple_match: bool
complex_matcher: Callable[[str], bool] | None
job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None]
qos: int = 0
encoding: str | None = "utf-8"
qos: int
encoding: str | None
subscription_id: int
class MqttClientSetup:
@@ -479,6 +479,7 @@ class MQTT:
self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos
self._registered_subscriptions: dict[str, int] = {} # topic, subscription_id
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
)
@@ -853,6 +854,9 @@ class MQTT:
) -> None:
"""Restore tracked subscriptions after reload."""
for subscription in subscriptions:
self._registered_subscriptions[subscription.topic] = (
subscription.subscription_id
)
self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear()
@@ -958,7 +962,19 @@ class MQTT:
is_simple_match = not ("+" in topic or "#" in topic)
matcher = None if is_simple_match else _matcher_for_topic(topic)
subscription = Subscription(topic, is_simple_match, matcher, job, qos, encoding)
if is_simple_match:
subscription_id = 1
elif topic in self._registered_subscriptions:
subscription_id = self._registered_subscriptions[topic]
else:
subscription_id = self._registered_subscriptions[topic] = (
self._mqtt_data.subscription_id_generator.generate()
)
subscription = Subscription(
topic, is_simple_match, matcher, job, qos, encoding, subscription_id
)
self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear()
@@ -977,15 +993,15 @@ class MQTT:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(subscription.topic)
self._async_unsubscribe(subscription.topic, subscription.subscription_id)
@callback
def _async_unsubscribe(self, topic: str) -> None:
def _async_unsubscribe(self, topic: str, subscription_id: int) -> None:
"""Unsubscribe from a topic."""
if self.is_active_subscription(topic):
if self._max_qos[topic] == 0:
return
subs = self._matching_subscriptions(topic)
subs = self._matching_subscriptions(topic, (subscription_id,))
self._max_qos[topic] = max(sub.qos for sub in subs)
# Other subscriptions on topic remaining - don't unsubscribe.
return
@@ -1011,33 +1027,60 @@ class MQTT:
#
# Since we do not know if a published value is retained we need to
# (re)subscribe, to ensure retained messages are replayed
if not self._pending_subscriptions:
return
# Split out the wildcard subscriptions, we subscribe to them one by one
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
pending_subscriptions: dict[str, int] = self._pending_subscriptions
pending_wildcard_subscriptions = {
subscription.topic: pending_subscriptions.pop(subscription.topic)
for subscription in self._wildcard_subscriptions
if subscription.topic in pending_subscriptions
}
subscribe_chain = chunked_or_all(
pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL
)
if self.is_mqttv5 and pending_subscriptions:
bulk_properties = mqtt.Properties(packetType=mqtt.PacketTypes.SUBSCRIBE) # type: ignore[no-untyped-call]
bulk_properties.SubscriptionIdentifier = 1
else:
bulk_properties = None
self._pending_subscriptions = {}
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
for topic, qos in pending_wildcard_subscriptions.items():
if self.is_mqttv5:
properties = mqtt.Properties(packetType=mqtt.PacketTypes.SUBSCRIBE) # type: ignore[no-untyped-call]
properties.SubscriptionIdentifier = self._registered_subscriptions[
topic
]
else:
properties = None
for chunk in chain(
chunked_or_all(
pending_wildcard_subscriptions.items(), MAX_WILDCARD_SUBSCRIBES_PER_CALL
),
chunked_or_all(pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL),
):
result, mid = self._mqttc.subscribe(topic, qos, properties=properties)
if debug_enabled:
_LOGGER.debug(
"Subscribing with mid: %s to topic %s "
"with qos: %s and properties: %s",
mid,
topic,
qos,
properties,
)
self._last_subscribe = time.monotonic()
await self._async_wait_for_mid_or_raise(mid, result)
async_dispatcher_send(
self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, [(topic, qos)]
)
for chunk in subscribe_chain:
chunk_list = list(chunk)
if not chunk_list:
continue
result, mid = self._mqttc.subscribe(chunk_list)
result, mid = self._mqttc.subscribe(chunk_list, properties=bulk_properties)
if debug_enabled:
_LOGGER.debug(
@@ -1068,6 +1111,10 @@ class MQTT:
await self._async_wait_for_mid_or_raise(mid, result)
# Flush subscription identifiers if they are available
for topic in topics:
self._registered_subscriptions.pop(topic, None)
async def _async_resubscribe_and_publish_birth_message(
self, birth_message: PublishMessage
) -> None:
@@ -1166,16 +1213,27 @@ class MQTT:
)
@lru_cache(None) # pylint: disable=method-cache-max-size-none
def _matching_subscriptions(self, topic: str) -> list[Subscription]:
def _matching_subscriptions(
self, topic: str, identifiers: tuple[int, ...] | None
) -> list[Subscription]:
subscriptions: list[Subscription] = []
if topic in self._simple_subscriptions:
subscriptions.extend(self._simple_subscriptions[topic])
simple_subscriptions_for_topic = self._simple_subscriptions[topic]
if identifiers is None:
subscriptions.extend(simple_subscriptions_for_topic)
else:
subscriptions.extend(
subscription
for subscription in simple_subscriptions_for_topic
if subscription.subscription_id in identifiers
)
subscriptions.extend(
subscription
for subscription in self._wildcard_subscriptions
# mypy doesn't know that complex_matcher is always set when
# is_simple_match is False
if subscription.complex_matcher(topic) # type: ignore[misc]
and (identifiers is None or subscription.subscription_id in identifiers)
)
return subscriptions
@@ -1183,6 +1241,17 @@ class MQTT:
def _async_mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None:
identifiers: tuple[int,] | None = None
if self.is_mqttv5:
# It is possible we have multiple messages if there
# are overlapping wildcard subscriptions.
# So we assigned all wildcard subscriptions with a
# unique SubscriptionIdentifier. Simple subscriptions are assigned
# with SubscriptionIdentifier 1.
if msg.properties is not None and hasattr(
msg.properties, "SubscriptionIdentifier"
):
identifiers = tuple(msg.properties.SubscriptionIdentifier)
try:
# msg.topic is a property that decodes the topic to a string
# every time it is accessed. Save the result to avoid
@@ -1199,16 +1268,16 @@ class MQTT:
)
return
_LOGGER.debug(
"Received%s message on %s (qos=%s): %s",
"Received%s message on %s (qos=%s) IDs=%s: %s",
" retained" if msg.retain else "",
topic,
msg.qos,
identifiers,
msg.payload[0:8192],
)
subscriptions = self._matching_subscriptions(topic)
msg_cache_by_subscription_topic: dict[str, ReceiveMessage] = {}
for subscription in subscriptions:
for subscription in self._matching_subscriptions(topic, identifiers):
if msg.retain:
retained_topics = self._retained_topics[subscription]
# Skip if the subscription already received a retained message
+17
View File
@@ -42,6 +42,22 @@ class PayloadSentinel(StrEnum):
DEFAULT = "default"
class SubscriptionID:
"""ID generator for wildcard subscriptions."""
_id: int = 1
def generate(self) -> int:
"""Generate a new subscription ID.
ID 0 is reserved.
ID 1 is used for non wildcard topics.
Generator starts at ID 2.
"""
self._id = self._id + 1
return self._id
_LOGGER = logging.getLogger(__name__)
ATTR_THIS = "this"
@@ -421,6 +437,7 @@ class MqttData:
state_write_requests: EntityTopicState = field(default_factory=EntityTopicState)
subscriptions_to_restore: set[Subscription] = field(default_factory=set)
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
subscription_id_generator: SubscriptionID = field(default_factory=SubscriptionID)
@dataclass(slots=True)
+6 -1
View File
@@ -23,7 +23,7 @@ import os
import pathlib
import time
from types import FrameType, ModuleType
from typing import Any, Literal, NoReturn
from typing import TYPE_CHECKING, Any, Literal, NoReturn
from unittest.mock import AsyncMock, Mock, patch
from aiohttp.test_utils import unused_port as get_test_instance_port
@@ -122,6 +122,9 @@ from .testing_config.custom_components.test_constant_deprecation import (
import_deprecated_constant,
)
if TYPE_CHECKING:
import paho.mqtt.client as mqtt
__all__ = [
"async_get_device_automation_capabilities",
"get_test_instance_port",
@@ -452,6 +455,7 @@ def async_fire_mqtt_message(
payload: bytes | str,
qos: int = 0,
retain: bool = False,
properties: mqtt.Properties | None = None,
) -> None:
"""Fire the MQTT message."""
from homeassistant.components.mqtt import MqttData # noqa: PLC0415
@@ -464,6 +468,7 @@ def async_fire_mqtt_message(
msg.qos = qos
msg.retain = retain
msg.timestamp = time.monotonic()
msg.properties = properties
mqtt_data: MqttData = hass.data["mqtt"]
assert mqtt_data.client
+190 -9
View File
@@ -2,6 +2,7 @@
import asyncio
from datetime import timedelta
import json
import socket
import ssl
import time
@@ -1071,12 +1072,33 @@ async def test_not_calling_subscribe_when_unsubscribed_within_cooldown(
assert not mqtt_client_mock.subscribe.called
@pytest.mark.parametrize(
"mqtt_config_entry_data",
[
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "3.1",
},
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "3.1.1",
},
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "5",
},
],
ids=["v3.1", "v3.1.1", "v5"],
)
async def test_unsubscribe_race(
hass: HomeAssistant,
mock_debouncer: asyncio.Event,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
) -> None:
"""Test not calling unsubscribe() when other subscribers are active."""
"""Test not calling unsubscribe() when other subscribers are active.
Testing with simple topics.
"""
mqtt_client_mock = setup_with_birth_msg_client_mock
calls_a: list[ReceiveMessage] = []
calls_b: list[ReceiveMessage] = []
@@ -1104,16 +1126,89 @@ async def test_unsubscribe_race(
# We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
# when both subscriptions were combined [subscribe]
expected_calls_1 = [
call.subscribe([("test/state", 0)]),
call.subscribe([("test/state", 0)], properties=ANY),
call.unsubscribe("test/state"),
call.subscribe([("test/state", 0)]),
call.subscribe([("test/state", 0)], properties=ANY),
]
expected_calls_2 = [
call.subscribe([("test/state", 0)]),
call.subscribe([("test/state", 0)]),
call.subscribe([("test/state", 0)], properties=ANY),
call.subscribe([("test/state", 0)], properties=ANY),
]
expected_calls_3 = [
call.subscribe([("test/state", 0)]),
call.subscribe([("test/state", 0)], properties=ANY),
]
assert mqtt_client_mock.mock_calls in (
expected_calls_1,
expected_calls_2,
expected_calls_3,
)
@pytest.mark.parametrize(
"mqtt_config_entry_data",
[
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "3.1",
},
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "3.1.1",
},
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "5",
},
],
ids=["v3.1", "v3.1.1", "v5"],
)
@pytest.mark.parametrize("mqtt_config_entry_options", [ENTRY_DEFAULT_BIRTH_MESSAGE])
async def test_wildcard_unsubscribe_race(
hass: HomeAssistant,
mock_debouncer: asyncio.Event,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
) -> None:
"""Test not calling unsubscribe() when other subscribers are active.
Testing with wildcard topics.
"""
mqtt_client_mock = setup_with_birth_msg_client_mock
calls_a: list[ReceiveMessage] = []
calls_b: list[ReceiveMessage] = []
@callback
def _callback_a(msg: ReceiveMessage) -> None:
calls_a.append(msg)
@callback
def _callback_b(msg: ReceiveMessage) -> None:
calls_b.append(msg)
mqtt_client_mock.reset_mock()
mock_debouncer.clear()
unsub = await mqtt.async_subscribe(hass, "test/#", _callback_a)
unsub()
await mqtt.async_subscribe(hass, "test/#", _callback_b)
await mock_debouncer.wait()
async_fire_mqtt_message(hass, "test/state", "online")
assert not calls_a
assert calls_b
# We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
# when both subscriptions were combined [subscribe]
expected_calls_1 = [
call.subscribe("test/#", 0, properties=ANY),
call.unsubscribe("test/#"),
call.subscribe("test/#", 0, properties=ANY),
]
expected_calls_2 = [
call.subscribe("test/#", 0, properties=ANY),
call.subscribe("test/#", 0, properties=ANY),
]
expected_calls_3 = [
call.subscribe("test/#", 0, properties=ANY),
]
assert mqtt_client_mock.mock_calls in (
expected_calls_1,
@@ -1181,7 +1276,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
# the subscription with the highest QoS should survive
expected = [
call([("test/state", 2)]),
call([("test/state", 2)], properties=None),
]
assert mqtt_client_mock.subscribe.mock_calls == expected
@@ -1195,7 +1290,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
# wait for cooldown
await mock_debouncer.wait()
expected.append(call([("test/state", 1)]))
expected.append(call([("test/state", 1)], properties=None))
for expected_call in expected:
assert mqtt_client_mock.subscribe.hass_call(expected_call)
@@ -1387,7 +1482,7 @@ async def test_subscribe_error(
mqtt_client_mock = setup_with_birth_msg_client_mock
mqtt_client_mock.reset_mock()
# simulate client is not connected error before subscribing
mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None)
mqtt_client_mock.subscribe.side_effect = lambda *args, **kwargs: (4, None)
await mqtt.async_subscribe(hass, "some-topic", record_calls)
while mqtt_client_mock.subscribe.call_count == 0:
await hass.async_block_till_done()
@@ -2384,3 +2479,89 @@ async def test_loop_write_failure(
# Cleanup. Server is closed earlier already.
client.close()
@pytest.mark.parametrize(
("mqtt_config_entry_data", "mqtt_config_entry_options"),
[
(
{
mqtt.CONF_BROKER: "mock-broker",
CONF_PROTOCOL: "5",
},
ENTRY_DEFAULT_BIRTH_MESSAGE,
),
],
ids=["v5"],
)
async def test_overlapping_subscriptions_only_processed_once(
hass: HomeAssistant,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
) -> None:
"""Test messages are only processed once per subscription in case of overlap.
Overlapping subscriptions are only supported with MQTTv5
"""
mqtt_client_mock = setup_with_birth_msg_client_mock
assert mqtt_client_mock.connect.call_count == 1
mock_subscribe: MagicMock = mqtt_client_mock.subscribe
mock_subscribe.reset_mock()
# We create 3 sensors:
# - 2 with an overlapping wildcard subscription
# - 1 with an overlapping simple subscription
config1 = json.dumps(
{
"name": "test1",
"default_entity_id": "sensor.test1",
"unique_id": "test1_veryunique",
"state_topic": "test/+/status",
}
)
config2 = json.dumps(
{
"name": "test2",
"default_entity_id": "sensor.test2",
"unique_id": "test2_veryunique",
"state_topic": "test/#",
}
)
config3 = json.dumps(
{
"name": "test3",
"default_entity_id": "sensor.test3",
"unique_id": "test3_veryunique",
"state_topic": "test/bla/status",
}
)
async_fire_mqtt_message(hass, "homeassistant/sensor/config1/config", config1)
async_fire_mqtt_message(hass, "homeassistant/sensor/config2/config", config2)
async_fire_mqtt_message(hass, "homeassistant/sensor/config3/config", config3)
while len(mock_subscribe.mock_calls) < 3:
await hass.async_block_till_done()
message_identifiers = [
mock_call[2]["properties"].SubscriptionIdentifier[0]
for mock_call in mock_subscribe.mock_calls
]
assert hass.states.get("sensor.test1") is not None
assert hass.states.get("sensor.test2") is not None
assert hass.states.get("sensor.test3") is not None
with patch(
"homeassistant.components.mqtt.entity.MqttEntity.async_write_ha_state"
) as mock_async_ha_write_state:
# Simulate the broker sends a publish message at topic "test/bla/status"
# That matches all three subscriptions
for message_identifier in message_identifiers:
properties = paho_mqtt.Properties(paho_mqtt.PacketTypes.PUBLISH)
properties.SubscriptionIdentifier = message_identifier
async_fire_mqtt_message(
hass, "test/bla/status", "bla", properties=properties
)
await hass.async_block_till_done()
# Each sensor should receive one update, so we should have 3 state write calls
assert len(mock_async_ha_write_state.mock_calls) == 3
+1 -1
View File
@@ -1071,7 +1071,7 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient]:
)
return FakeInfo(mid)
def _subscribe(topic, qos=0):
def _subscribe(topic_or_list, qos=0, **kwargs):
mid = get_mid()
hass.loop.call_soon(
mock_client.on_subscribe, Mock(), 0, mid, [MockMqttReasonCode()], None