Cache matching MQTT subscriptions (#41433)

This commit is contained in:
Erik Montnemery 2020-10-08 08:52:23 +02:00 committed by GitHub
parent 85603dcd08
commit 392d5c673a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 5 deletions

View File

@ -1,6 +1,6 @@
"""Support for MQTT message handling.""" """Support for MQTT message handling."""
import asyncio import asyncio
from functools import partial, wraps from functools import lru_cache, partial, wraps
import inspect import inspect
from itertools import groupby from itertools import groupby
import json import json
@ -842,6 +842,7 @@ class MQTT:
topic, _matcher_for_topic(topic), msg_callback, qos, encoding topic, _matcher_for_topic(topic), msg_callback, qos, encoding
) )
self.subscriptions.append(subscription) self.subscriptions.append(subscription)
self._matching_subscriptions.cache_clear()
# Only subscribe if currently connected. # Only subscribe if currently connected.
if self.connected: if self.connected:
@ -854,6 +855,7 @@ class MQTT:
if subscription not in self.subscriptions: if subscription not in self.subscriptions:
raise HomeAssistantError("Can't remove subscription twice") raise HomeAssistantError("Can't remove subscription twice")
self.subscriptions.remove(subscription) self.subscriptions.remove(subscription)
self._matching_subscriptions.cache_clear()
if any(other.topic == topic for other in self.subscriptions): if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe. # Other subscriptions on topic remaining - don't unsubscribe.
@ -944,6 +946,14 @@ class MQTT:
"""Message received callback.""" """Message received callback."""
self.hass.add_job(self._mqtt_handle_message, msg) self.hass.add_job(self._mqtt_handle_message, msg)
@lru_cache(2048)
def _matching_subscriptions(self, topic):
subscriptions = []
for subscription in self.subscriptions:
if subscription.matcher(topic):
subscriptions.append(subscription)
return subscriptions
@callback @callback
def _mqtt_handle_message(self, msg) -> None: def _mqtt_handle_message(self, msg) -> None:
_LOGGER.debug( _LOGGER.debug(
@ -954,9 +964,9 @@ class MQTT:
) )
timestamp = dt_util.utcnow() timestamp = dt_util.utcnow()
for subscription in self.subscriptions: subscriptions = self._matching_subscriptions(msg.topic)
if not subscription.matcher(msg.topic):
continue for subscription in subscriptions:
payload: SubscribePayloadType = msg.payload payload: SubscribePayloadType = msg.payload
if subscription.encoding is not None: if subscription.encoding is not None:

View File

@ -384,9 +384,13 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
assert result assert result
await hass.async_block_till_done() await hass.async_block_till_done()
# Workaround: asynctest==0.13 fails on @functools.lru_cache
spec = dir(hass.data["mqtt"])
spec.remove("_matching_subscriptions")
mqtt_component_mock = MagicMock( mqtt_component_mock = MagicMock(
return_value=hass.data["mqtt"], return_value=hass.data["mqtt"],
spec_set=hass.data["mqtt"], spec_set=spec,
wraps=hass.data["mqtt"], wraps=hass.data["mqtt"],
) )
mqtt_component_mock._mqttc = mqtt_client_mock mqtt_component_mock._mqttc = mqtt_client_mock