From 392d5c673ae0cf0419afc3d4699dee223ea7e4f4 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 8 Oct 2020 08:52:23 +0200 Subject: [PATCH] Cache matching MQTT subscriptions (#41433) --- homeassistant/components/mqtt/__init__.py | 18 ++++++++++++++---- tests/conftest.py | 6 +++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 8da9a642bf2..17fd31e81c5 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -1,6 +1,6 @@ """Support for MQTT message handling.""" import asyncio -from functools import partial, wraps +from functools import lru_cache, partial, wraps import inspect from itertools import groupby import json @@ -842,6 +842,7 @@ class MQTT: topic, _matcher_for_topic(topic), msg_callback, qos, encoding ) self.subscriptions.append(subscription) + self._matching_subscriptions.cache_clear() # Only subscribe if currently connected. if self.connected: @@ -854,6 +855,7 @@ class MQTT: if subscription not in self.subscriptions: raise HomeAssistantError("Can't remove subscription twice") self.subscriptions.remove(subscription) + self._matching_subscriptions.cache_clear() if any(other.topic == topic for other in self.subscriptions): # Other subscriptions on topic remaining - don't unsubscribe. @@ -944,6 +946,14 @@ class MQTT: """Message received callback.""" self.hass.add_job(self._mqtt_handle_message, msg) + @lru_cache(2048) + def _matching_subscriptions(self, topic): + subscriptions = [] + for subscription in self.subscriptions: + if subscription.matcher(topic): + subscriptions.append(subscription) + return subscriptions + @callback def _mqtt_handle_message(self, msg) -> None: _LOGGER.debug( @@ -954,9 +964,9 @@ class MQTT: ) timestamp = dt_util.utcnow() - for subscription in self.subscriptions: - if not subscription.matcher(msg.topic): - continue + subscriptions = self._matching_subscriptions(msg.topic) + + for subscription in subscriptions: payload: SubscribePayloadType = msg.payload if subscription.encoding is not None: diff --git a/tests/conftest.py b/tests/conftest.py index 64bcb8dc951..1b4d54a6fdb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -384,9 +384,13 @@ async def mqtt_mock(hass, mqtt_client_mock, mqtt_config): assert result await hass.async_block_till_done() + # Workaround: asynctest==0.13 fails on @functools.lru_cache + spec = dir(hass.data["mqtt"]) + spec.remove("_matching_subscriptions") + mqtt_component_mock = MagicMock( return_value=hass.data["mqtt"], - spec_set=hass.data["mqtt"], + spec_set=spec, wraps=hass.data["mqtt"], ) mqtt_component_mock._mqttc = mqtt_client_mock