mirror of
https://github.com/home-assistant/core.git
synced 2026-05-12 17:04:32 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1e97e49cba | |||
| 406a6dda9b | |||
| 7633cc7252 | |||
| 5169b6554a | |||
| 2d9cf87e44 | |||
| b5f04ef502 | |||
| d073d4fe4a | |||
| 2e6d8e3aea | |||
| 0d9f5a32f4 | |||
| feea4925cd | |||
| 3be267a94c | |||
| 9da3788c52 | |||
| cb6a38b10f | |||
| 03e22e1cb2 |
@@ -110,7 +110,6 @@ TIMEOUT_ACK = 10
|
||||
SUBSCRIBE_TIMEOUT = 10
|
||||
RECONNECT_INTERVAL_SECONDS = 10
|
||||
|
||||
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
|
||||
MAX_SUBSCRIBES_PER_CALL = 500
|
||||
MAX_UNSUBSCRIBES_PER_CALL = 500
|
||||
|
||||
@@ -330,8 +329,9 @@ class Subscription:
|
||||
is_simple_match: bool
|
||||
complex_matcher: Callable[[str], bool] | None
|
||||
job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None]
|
||||
qos: int = 0
|
||||
encoding: str | None = "utf-8"
|
||||
qos: int
|
||||
encoding: str | None
|
||||
subscription_id: int
|
||||
|
||||
|
||||
class MqttClientSetup:
|
||||
@@ -479,6 +479,7 @@ class MQTT:
|
||||
|
||||
self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos
|
||||
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
||||
self._registered_subscriptions: dict[str, int] = {} # topic, subscription_id
|
||||
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
|
||||
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
|
||||
)
|
||||
@@ -853,6 +854,9 @@ class MQTT:
|
||||
) -> None:
|
||||
"""Restore tracked subscriptions after reload."""
|
||||
for subscription in subscriptions:
|
||||
self._registered_subscriptions[subscription.topic] = (
|
||||
subscription.subscription_id
|
||||
)
|
||||
self._async_track_subscription(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
|
||||
@@ -958,7 +962,19 @@ class MQTT:
|
||||
is_simple_match = not ("+" in topic or "#" in topic)
|
||||
matcher = None if is_simple_match else _matcher_for_topic(topic)
|
||||
|
||||
subscription = Subscription(topic, is_simple_match, matcher, job, qos, encoding)
|
||||
if is_simple_match:
|
||||
subscription_id = 1
|
||||
elif topic in self._registered_subscriptions:
|
||||
subscription_id = self._registered_subscriptions[topic]
|
||||
else:
|
||||
subscription_id = self._registered_subscriptions[topic] = (
|
||||
self._mqtt_data.subscription_id_generator.generate()
|
||||
)
|
||||
|
||||
subscription = Subscription(
|
||||
topic, is_simple_match, matcher, job, qos, encoding, subscription_id
|
||||
)
|
||||
|
||||
self._async_track_subscription(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
|
||||
@@ -977,15 +993,15 @@ class MQTT:
|
||||
del self._retained_topics[subscription]
|
||||
# Only unsubscribe if currently connected
|
||||
if self.connected:
|
||||
self._async_unsubscribe(subscription.topic)
|
||||
self._async_unsubscribe(subscription.topic, subscription.subscription_id)
|
||||
|
||||
@callback
|
||||
def _async_unsubscribe(self, topic: str) -> None:
|
||||
def _async_unsubscribe(self, topic: str, subscription_id: int) -> None:
|
||||
"""Unsubscribe from a topic."""
|
||||
if self.is_active_subscription(topic):
|
||||
if self._max_qos[topic] == 0:
|
||||
return
|
||||
subs = self._matching_subscriptions(topic)
|
||||
subs = self._matching_subscriptions(topic, (subscription_id,))
|
||||
self._max_qos[topic] = max(sub.qos for sub in subs)
|
||||
# Other subscriptions on topic remaining - don't unsubscribe.
|
||||
return
|
||||
@@ -1011,33 +1027,60 @@ class MQTT:
|
||||
#
|
||||
# Since we do not know if a published value is retained we need to
|
||||
# (re)subscribe, to ensure retained messages are replayed
|
||||
|
||||
if not self._pending_subscriptions:
|
||||
return
|
||||
|
||||
# Split out the wildcard subscriptions, we subscribe to them one by one
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
pending_subscriptions: dict[str, int] = self._pending_subscriptions
|
||||
pending_wildcard_subscriptions = {
|
||||
subscription.topic: pending_subscriptions.pop(subscription.topic)
|
||||
for subscription in self._wildcard_subscriptions
|
||||
if subscription.topic in pending_subscriptions
|
||||
}
|
||||
subscribe_chain = chunked_or_all(
|
||||
pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL
|
||||
)
|
||||
if self.is_mqttv5 and pending_subscriptions:
|
||||
bulk_properties = mqtt.Properties(packetType=mqtt.PacketTypes.SUBSCRIBE) # type: ignore[no-untyped-call]
|
||||
bulk_properties.SubscriptionIdentifier = 1
|
||||
else:
|
||||
bulk_properties = None
|
||||
|
||||
self._pending_subscriptions = {}
|
||||
|
||||
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
|
||||
for topic, qos in pending_wildcard_subscriptions.items():
|
||||
if self.is_mqttv5:
|
||||
properties = mqtt.Properties(packetType=mqtt.PacketTypes.SUBSCRIBE) # type: ignore[no-untyped-call]
|
||||
properties.SubscriptionIdentifier = self._registered_subscriptions[
|
||||
topic
|
||||
]
|
||||
else:
|
||||
properties = None
|
||||
|
||||
for chunk in chain(
|
||||
chunked_or_all(
|
||||
pending_wildcard_subscriptions.items(), MAX_WILDCARD_SUBSCRIBES_PER_CALL
|
||||
),
|
||||
chunked_or_all(pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL),
|
||||
):
|
||||
result, mid = self._mqttc.subscribe(topic, qos, properties=properties)
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
"Subscribing with mid: %s to topic %s "
|
||||
"with qos: %s and properties: %s",
|
||||
mid,
|
||||
topic,
|
||||
qos,
|
||||
properties,
|
||||
)
|
||||
self._last_subscribe = time.monotonic()
|
||||
|
||||
await self._async_wait_for_mid_or_raise(mid, result)
|
||||
async_dispatcher_send(
|
||||
self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, [(topic, qos)]
|
||||
)
|
||||
|
||||
for chunk in subscribe_chain:
|
||||
chunk_list = list(chunk)
|
||||
if not chunk_list:
|
||||
continue
|
||||
|
||||
result, mid = self._mqttc.subscribe(chunk_list)
|
||||
result, mid = self._mqttc.subscribe(chunk_list, properties=bulk_properties)
|
||||
|
||||
if debug_enabled:
|
||||
_LOGGER.debug(
|
||||
@@ -1068,6 +1111,10 @@ class MQTT:
|
||||
|
||||
await self._async_wait_for_mid_or_raise(mid, result)
|
||||
|
||||
# Flush subscription identifiers if they are available
|
||||
for topic in topics:
|
||||
self._registered_subscriptions.pop(topic, None)
|
||||
|
||||
async def _async_resubscribe_and_publish_birth_message(
|
||||
self, birth_message: PublishMessage
|
||||
) -> None:
|
||||
@@ -1166,16 +1213,27 @@ class MQTT:
|
||||
)
|
||||
|
||||
@lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def _matching_subscriptions(self, topic: str) -> list[Subscription]:
|
||||
def _matching_subscriptions(
|
||||
self, topic: str, identifiers: tuple[int, ...] | None
|
||||
) -> list[Subscription]:
|
||||
subscriptions: list[Subscription] = []
|
||||
if topic in self._simple_subscriptions:
|
||||
subscriptions.extend(self._simple_subscriptions[topic])
|
||||
simple_subscriptions_for_topic = self._simple_subscriptions[topic]
|
||||
if identifiers is None:
|
||||
subscriptions.extend(simple_subscriptions_for_topic)
|
||||
else:
|
||||
subscriptions.extend(
|
||||
subscription
|
||||
for subscription in simple_subscriptions_for_topic
|
||||
if subscription.subscription_id in identifiers
|
||||
)
|
||||
subscriptions.extend(
|
||||
subscription
|
||||
for subscription in self._wildcard_subscriptions
|
||||
# mypy doesn't know that complex_matcher is always set when
|
||||
# is_simple_match is False
|
||||
if subscription.complex_matcher(topic) # type: ignore[misc]
|
||||
and (identifiers is None or subscription.subscription_id in identifiers)
|
||||
)
|
||||
return subscriptions
|
||||
|
||||
@@ -1183,6 +1241,17 @@ class MQTT:
|
||||
def _async_mqtt_on_message(
|
||||
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
|
||||
) -> None:
|
||||
identifiers: tuple[int,] | None = None
|
||||
if self.is_mqttv5:
|
||||
# It is possible we have multiple messages if there
|
||||
# are overlapping wildcard subscriptions.
|
||||
# So we assigned all wildcard subscriptions with a
|
||||
# unique SubscriptionIdentifier. Simple subscriptions are assigned
|
||||
# with SubscriptionIdentifier 1.
|
||||
if msg.properties is not None and hasattr(
|
||||
msg.properties, "SubscriptionIdentifier"
|
||||
):
|
||||
identifiers = tuple(msg.properties.SubscriptionIdentifier)
|
||||
try:
|
||||
# msg.topic is a property that decodes the topic to a string
|
||||
# every time it is accessed. Save the result to avoid
|
||||
@@ -1199,16 +1268,16 @@ class MQTT:
|
||||
)
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"Received%s message on %s (qos=%s): %s",
|
||||
"Received%s message on %s (qos=%s) IDs=%s: %s",
|
||||
" retained" if msg.retain else "",
|
||||
topic,
|
||||
msg.qos,
|
||||
identifiers,
|
||||
msg.payload[0:8192],
|
||||
)
|
||||
subscriptions = self._matching_subscriptions(topic)
|
||||
msg_cache_by_subscription_topic: dict[str, ReceiveMessage] = {}
|
||||
|
||||
for subscription in subscriptions:
|
||||
for subscription in self._matching_subscriptions(topic, identifiers):
|
||||
if msg.retain:
|
||||
retained_topics = self._retained_topics[subscription]
|
||||
# Skip if the subscription already received a retained message
|
||||
|
||||
@@ -42,6 +42,22 @@ class PayloadSentinel(StrEnum):
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
class SubscriptionID:
|
||||
"""ID generator for wildcard subscriptions."""
|
||||
|
||||
_id: int = 1
|
||||
|
||||
def generate(self) -> int:
|
||||
"""Generate a new subscription ID.
|
||||
|
||||
ID 0 is reserved.
|
||||
ID 1 is used for non wildcard topics.
|
||||
Generator starts at ID 2.
|
||||
"""
|
||||
self._id = self._id + 1
|
||||
return self._id
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ATTR_THIS = "this"
|
||||
@@ -421,6 +437,7 @@ class MqttData:
|
||||
state_write_requests: EntityTopicState = field(default_factory=EntityTopicState)
|
||||
subscriptions_to_restore: set[Subscription] = field(default_factory=set)
|
||||
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
|
||||
subscription_id_generator: SubscriptionID = field(default_factory=SubscriptionID)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
||||
+6
-1
@@ -23,7 +23,7 @@ import os
|
||||
import pathlib
|
||||
import time
|
||||
from types import FrameType, ModuleType
|
||||
from typing import Any, Literal, NoReturn
|
||||
from typing import TYPE_CHECKING, Any, Literal, NoReturn
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from aiohttp.test_utils import unused_port as get_test_instance_port
|
||||
@@ -122,6 +122,9 @@ from .testing_config.custom_components.test_constant_deprecation import (
|
||||
import_deprecated_constant,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
__all__ = [
|
||||
"async_get_device_automation_capabilities",
|
||||
"get_test_instance_port",
|
||||
@@ -452,6 +455,7 @@ def async_fire_mqtt_message(
|
||||
payload: bytes | str,
|
||||
qos: int = 0,
|
||||
retain: bool = False,
|
||||
properties: mqtt.Properties | None = None,
|
||||
) -> None:
|
||||
"""Fire the MQTT message."""
|
||||
from homeassistant.components.mqtt import MqttData # noqa: PLC0415
|
||||
@@ -464,6 +468,7 @@ def async_fire_mqtt_message(
|
||||
msg.qos = qos
|
||||
msg.retain = retain
|
||||
msg.timestamp = time.monotonic()
|
||||
msg.properties = properties
|
||||
|
||||
mqtt_data: MqttData = hass.data["mqtt"]
|
||||
assert mqtt_data.client
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import json
|
||||
import socket
|
||||
import ssl
|
||||
import time
|
||||
@@ -1071,12 +1072,33 @@ async def test_not_calling_subscribe_when_unsubscribed_within_cooldown(
|
||||
assert not mqtt_client_mock.subscribe.called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mqtt_config_entry_data",
|
||||
[
|
||||
{
|
||||
mqtt.CONF_BROKER: "mock-broker",
|
||||
CONF_PROTOCOL: "3.1",
|
||||
},
|
||||
{
|
||||
mqtt.CONF_BROKER: "mock-broker",
|
||||
CONF_PROTOCOL: "3.1.1",
|
||||
},
|
||||
{
|
||||
mqtt.CONF_BROKER: "mock-broker",
|
||||
CONF_PROTOCOL: "5",
|
||||
},
|
||||
],
|
||||
ids=["v3.1", "v3.1.1", "v5"],
|
||||
)
|
||||
async def test_unsubscribe_race(
|
||||
hass: HomeAssistant,
|
||||
mock_debouncer: asyncio.Event,
|
||||
setup_with_birth_msg_client_mock: MqttMockPahoClient,
|
||||
) -> None:
|
||||
"""Test not calling unsubscribe() when other subscribers are active."""
|
||||
"""Test not calling unsubscribe() when other subscribers are active.
|
||||
|
||||
Testing with simple topics.
|
||||
"""
|
||||
mqtt_client_mock = setup_with_birth_msg_client_mock
|
||||
calls_a: list[ReceiveMessage] = []
|
||||
calls_b: list[ReceiveMessage] = []
|
||||
@@ -1104,16 +1126,89 @@ async def test_unsubscribe_race(
|
||||
# We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
|
||||
# when both subscriptions were combined [subscribe]
|
||||
expected_calls_1 = [
|
||||
call.subscribe([("test/state", 0)]),
|
||||
call.subscribe([("test/state", 0)], properties=ANY),
|
||||
call.unsubscribe("test/state"),
|
||||
call.subscribe([("test/state", 0)]),
|
||||
call.subscribe([("test/state", 0)], properties=ANY),
|
||||
]
|
||||
expected_calls_2 = [
|
||||
call.subscribe([("test/state", 0)]),
|
||||
call.subscribe([("test/state", 0)]),
|
||||
call.subscribe([("test/state", 0)], properties=ANY),
|
||||
call.subscribe([("test/state", 0)], properties=ANY),
|
||||
]
|
||||
expected_calls_3 = [
|
||||
call.subscribe([("test/state", 0)]),
|
||||
call.subscribe([("test/state", 0)], properties=ANY),
|
||||
]
|
||||
assert mqtt_client_mock.mock_calls in (
|
||||
expected_calls_1,
|
||||
expected_calls_2,
|
||||
expected_calls_3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mqtt_config_entry_data",
|
||||
[
|
||||
{
|
||||
mqtt.CONF_BROKER: "mock-broker",
|
||||
CONF_PROTOCOL: "3.1",
|
||||
},
|
||||
{
|
||||
mqtt.CONF_BROKER: "mock-broker",
|
||||
CONF_PROTOCOL: "3.1.1",
|
||||
},
|
||||
{
|
||||
mqtt.CONF_BROKER: "mock-broker",
|
||||
CONF_PROTOCOL: "5",
|
||||
},
|
||||
],
|
||||
ids=["v3.1", "v3.1.1", "v5"],
|
||||
)
|
||||
@pytest.mark.parametrize("mqtt_config_entry_options", [ENTRY_DEFAULT_BIRTH_MESSAGE])
|
||||
async def test_wildcard_unsubscribe_race(
|
||||
hass: HomeAssistant,
|
||||
mock_debouncer: asyncio.Event,
|
||||
setup_with_birth_msg_client_mock: MqttMockPahoClient,
|
||||
) -> None:
|
||||
"""Test not calling unsubscribe() when other subscribers are active.
|
||||
|
||||
Testing with wildcard topics.
|
||||
"""
|
||||
mqtt_client_mock = setup_with_birth_msg_client_mock
|
||||
calls_a: list[ReceiveMessage] = []
|
||||
calls_b: list[ReceiveMessage] = []
|
||||
|
||||
@callback
|
||||
def _callback_a(msg: ReceiveMessage) -> None:
|
||||
calls_a.append(msg)
|
||||
|
||||
@callback
|
||||
def _callback_b(msg: ReceiveMessage) -> None:
|
||||
calls_b.append(msg)
|
||||
|
||||
mqtt_client_mock.reset_mock()
|
||||
|
||||
mock_debouncer.clear()
|
||||
unsub = await mqtt.async_subscribe(hass, "test/#", _callback_a)
|
||||
unsub()
|
||||
await mqtt.async_subscribe(hass, "test/#", _callback_b)
|
||||
await mock_debouncer.wait()
|
||||
|
||||
async_fire_mqtt_message(hass, "test/state", "online")
|
||||
assert not calls_a
|
||||
assert calls_b
|
||||
|
||||
# We allow either calls [subscribe, unsubscribe, subscribe], [subscribe, subscribe] or
|
||||
# when both subscriptions were combined [subscribe]
|
||||
expected_calls_1 = [
|
||||
call.subscribe("test/#", 0, properties=ANY),
|
||||
call.unsubscribe("test/#"),
|
||||
call.subscribe("test/#", 0, properties=ANY),
|
||||
]
|
||||
expected_calls_2 = [
|
||||
call.subscribe("test/#", 0, properties=ANY),
|
||||
call.subscribe("test/#", 0, properties=ANY),
|
||||
]
|
||||
expected_calls_3 = [
|
||||
call.subscribe("test/#", 0, properties=ANY),
|
||||
]
|
||||
assert mqtt_client_mock.mock_calls in (
|
||||
expected_calls_1,
|
||||
@@ -1181,7 +1276,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
||||
|
||||
# the subscription with the highest QoS should survive
|
||||
expected = [
|
||||
call([("test/state", 2)]),
|
||||
call([("test/state", 2)], properties=None),
|
||||
]
|
||||
assert mqtt_client_mock.subscribe.mock_calls == expected
|
||||
|
||||
@@ -1195,7 +1290,7 @@ async def test_restore_all_active_subscriptions_on_reconnect(
|
||||
# wait for cooldown
|
||||
await mock_debouncer.wait()
|
||||
|
||||
expected.append(call([("test/state", 1)]))
|
||||
expected.append(call([("test/state", 1)], properties=None))
|
||||
for expected_call in expected:
|
||||
assert mqtt_client_mock.subscribe.hass_call(expected_call)
|
||||
|
||||
@@ -1387,7 +1482,7 @@ async def test_subscribe_error(
|
||||
mqtt_client_mock = setup_with_birth_msg_client_mock
|
||||
mqtt_client_mock.reset_mock()
|
||||
# simulate client is not connected error before subscribing
|
||||
mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None)
|
||||
mqtt_client_mock.subscribe.side_effect = lambda *args, **kwargs: (4, None)
|
||||
await mqtt.async_subscribe(hass, "some-topic", record_calls)
|
||||
while mqtt_client_mock.subscribe.call_count == 0:
|
||||
await hass.async_block_till_done()
|
||||
@@ -2384,3 +2479,89 @@ async def test_loop_write_failure(
|
||||
|
||||
# Cleanup. Server is closed earlier already.
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mqtt_config_entry_data", "mqtt_config_entry_options"),
|
||||
[
|
||||
(
|
||||
{
|
||||
mqtt.CONF_BROKER: "mock-broker",
|
||||
CONF_PROTOCOL: "5",
|
||||
},
|
||||
ENTRY_DEFAULT_BIRTH_MESSAGE,
|
||||
),
|
||||
],
|
||||
ids=["v5"],
|
||||
)
|
||||
async def test_overlapping_subscriptions_only_processed_once(
|
||||
hass: HomeAssistant,
|
||||
setup_with_birth_msg_client_mock: MqttMockPahoClient,
|
||||
) -> None:
|
||||
"""Test messages are only processed once per subscription in case of overlap.
|
||||
|
||||
Overlapping subscriptions are only supported with MQTTv5
|
||||
"""
|
||||
mqtt_client_mock = setup_with_birth_msg_client_mock
|
||||
assert mqtt_client_mock.connect.call_count == 1
|
||||
|
||||
mock_subscribe: MagicMock = mqtt_client_mock.subscribe
|
||||
mock_subscribe.reset_mock()
|
||||
|
||||
# We create 3 sensors:
|
||||
# - 2 with an overlapping wildcard subscription
|
||||
# - 1 with an overlapping simple subscription
|
||||
config1 = json.dumps(
|
||||
{
|
||||
"name": "test1",
|
||||
"default_entity_id": "sensor.test1",
|
||||
"unique_id": "test1_veryunique",
|
||||
"state_topic": "test/+/status",
|
||||
}
|
||||
)
|
||||
config2 = json.dumps(
|
||||
{
|
||||
"name": "test2",
|
||||
"default_entity_id": "sensor.test2",
|
||||
"unique_id": "test2_veryunique",
|
||||
"state_topic": "test/#",
|
||||
}
|
||||
)
|
||||
config3 = json.dumps(
|
||||
{
|
||||
"name": "test3",
|
||||
"default_entity_id": "sensor.test3",
|
||||
"unique_id": "test3_veryunique",
|
||||
"state_topic": "test/bla/status",
|
||||
}
|
||||
)
|
||||
|
||||
async_fire_mqtt_message(hass, "homeassistant/sensor/config1/config", config1)
|
||||
async_fire_mqtt_message(hass, "homeassistant/sensor/config2/config", config2)
|
||||
async_fire_mqtt_message(hass, "homeassistant/sensor/config3/config", config3)
|
||||
while len(mock_subscribe.mock_calls) < 3:
|
||||
await hass.async_block_till_done()
|
||||
|
||||
message_identifiers = [
|
||||
mock_call[2]["properties"].SubscriptionIdentifier[0]
|
||||
for mock_call in mock_subscribe.mock_calls
|
||||
]
|
||||
|
||||
assert hass.states.get("sensor.test1") is not None
|
||||
assert hass.states.get("sensor.test2") is not None
|
||||
assert hass.states.get("sensor.test3") is not None
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.mqtt.entity.MqttEntity.async_write_ha_state"
|
||||
) as mock_async_ha_write_state:
|
||||
# Simulate the broker sends a publish message at topic "test/bla/status"
|
||||
# That matches all three subscriptions
|
||||
for message_identifier in message_identifiers:
|
||||
properties = paho_mqtt.Properties(paho_mqtt.PacketTypes.PUBLISH)
|
||||
properties.SubscriptionIdentifier = message_identifier
|
||||
async_fire_mqtt_message(
|
||||
hass, "test/bla/status", "bla", properties=properties
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
# Each sensor should receive one update, so we should have 3 state write calls
|
||||
assert len(mock_async_ha_write_state.mock_calls) == 3
|
||||
|
||||
+1
-1
@@ -1071,7 +1071,7 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient]:
|
||||
)
|
||||
return FakeInfo(mid)
|
||||
|
||||
def _subscribe(topic, qos=0):
|
||||
def _subscribe(topic_or_list, qos=0, **kwargs):
|
||||
mid = get_mid()
|
||||
hass.loop.call_soon(
|
||||
mock_client.on_subscribe, Mock(), 0, mid, [MockMqttReasonCode()], None
|
||||
|
||||
Reference in New Issue
Block a user