From 67b3be84321a3bccb3e81e980a85e27744cd8a46 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 4 Jun 2024 14:21:03 -0500 Subject: [PATCH] Remove useless threading locks in mqtt (#118737) --- homeassistant/components/mqtt/async_client.py | 60 +++++++++++++++++++ homeassistant/components/mqtt/client.py | 16 +++-- tests/components/mqtt/test_config_flow.py | 8 ++- tests/components/mqtt/test_init.py | 33 +++++++--- tests/conftest.py | 4 +- 5 files changed, 106 insertions(+), 15 deletions(-) create mode 100644 homeassistant/components/mqtt/async_client.py diff --git a/homeassistant/components/mqtt/async_client.py b/homeassistant/components/mqtt/async_client.py new file mode 100644 index 00000000000..c0b847f35a1 --- /dev/null +++ b/homeassistant/components/mqtt/async_client.py @@ -0,0 +1,60 @@ +"""Async wrappings for mqtt client.""" + +from __future__ import annotations + +from functools import lru_cache +from types import TracebackType +from typing import Self + +from paho.mqtt.client import Client as MQTTClient + +_MQTT_LOCK_COUNT = 7 + + +class NullLock: + """Null lock.""" + + @lru_cache(maxsize=_MQTT_LOCK_COUNT) + def __enter__(self) -> Self: + """Enter the lock.""" + return self + + @lru_cache(maxsize=_MQTT_LOCK_COUNT) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """Exit the lock.""" + + @lru_cache(maxsize=_MQTT_LOCK_COUNT) + def acquire(self, blocking: bool = False, timeout: int = -1) -> None: + """Acquire the lock.""" + + @lru_cache(maxsize=_MQTT_LOCK_COUNT) + def release(self) -> None: + """Release the lock.""" + + +class AsyncMQTTClient(MQTTClient): + """Async MQTT Client. + + Wrapper around paho.mqtt.client.Client to remove the locking + that is not needed since we are running in an async event loop. + """ + + def async_setup(self) -> None: + """Set up the client. + + All the threading locks are replaced with NullLock + since the client is running in an async event loop + and will never run in multiple threads. + """ + self._in_callback_mutex = NullLock() + self._callback_mutex = NullLock() + self._msgtime_mutex = NullLock() + self._out_message_mutex = NullLock() + self._in_message_mutex = NullLock() + self._reconnect_delay_mutex = NullLock() + self._mid_generate_mutex = NullLock() diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index d36670baef1..f01cb9c948f 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -91,6 +91,8 @@ if TYPE_CHECKING: # because integrations should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt + from .async_client import AsyncMQTTClient + _LOGGER = logging.getLogger(__name__) MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails @@ -281,6 +283,9 @@ class MqttClientSetup: # should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel + # pylint: disable-next=import-outside-toplevel + from .async_client import AsyncMQTTClient + if (protocol := config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)) == PROTOCOL_31: proto = mqtt.MQTTv31 elif protocol == PROTOCOL_5: @@ -293,9 +298,10 @@ class MqttClientSetup: # However, that feature is not mandatory so we generate our own. client_id = mqtt.base62(uuid.uuid4().int, padding=22) transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT) - self._client = mqtt.Client( + self._client = AsyncMQTTClient( client_id, protocol=proto, transport=transport, reconnect_on_failure=False ) + self._client.async_setup() # Enable logging self._client.enable_logger() @@ -329,7 +335,7 @@ class MqttClientSetup: self._client.tls_insecure_set(tls_insecure) @property - def client(self) -> mqtt.Client: + def client(self) -> AsyncMQTTClient: """Return the paho MQTT client.""" return self._client @@ -434,7 +440,7 @@ class EnsureJobAfterCooldown: class MQTT: """Home Assistant MQTT client.""" - _mqttc: mqtt.Client + _mqttc: AsyncMQTTClient _last_subscribe: float _mqtt_data: MqttData @@ -533,7 +539,9 @@ class MQTT: async def async_init_client(self) -> None: """Initialize paho client.""" with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES): - await async_import_module(self.hass, "paho.mqtt.client") + await async_import_module( + self.hass, "homeassistant.components.mqtt.async_client" + ) mqttc = MqttClientSetup(self.conf).client # on_socket_unregister_write and _async_on_socket_close diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index 576ba3f94b2..f218a5b0447 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -121,7 +121,9 @@ def mock_try_connection_success() -> Generator[MqttMockPahoClient, None, None]: mock_client().on_unsubscribe(mock_client, 0, mid) return (0, mid) - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: mock_client().loop_start = loop_start mock_client().subscribe = _subscribe mock_client().unsubscribe = _unsubscribe @@ -135,7 +137,9 @@ def mock_try_connection_time_out() -> Generator[MagicMock, None, None]: # Patch prevent waiting 5 sec for a timeout with ( - patch("paho.mqtt.client.Client") as mock_client, + patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client, patch("homeassistant.components.mqtt.config_flow.MQTT_TIMEOUT", 0), ): mock_client().loop_start = lambda *args: 1 diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 2b9e4260c7e..5189196ac2b 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -180,7 +180,9 @@ async def test_mqtt_await_ack_at_disconnect( mid = 100 rc = 0 - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: mqtt_client = mock_client.return_value mqtt_client.connect = MagicMock( return_value=0, @@ -191,10 +193,15 @@ async def test_mqtt_await_ack_at_disconnect( mqtt_client.publish = MagicMock(return_value=FakeInfo()) entry = MockConfigEntry( domain=mqtt.DOMAIN, - data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"}, + data={ + "certificate": "auto", + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_DISCOVERY: False, + }, ) entry.add_to_hass(hass) assert await hass.config_entries.async_setup(entry.entry_id) + mqtt_client = mock_client.return_value # publish from MQTT client without awaiting @@ -2219,7 +2226,9 @@ async def test_publish_error( entry.add_to_hass(hass) # simulate an Out of memory error - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: mock_client().connect = lambda *args: 1 mock_client().publish().rc = 1 assert await hass.config_entries.async_setup(entry.entry_id) @@ -2354,7 +2363,9 @@ async def test_setup_mqtt_client_protocol( protocol: int, ) -> None: """Test MQTT client protocol setup.""" - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: await mqtt_mock_entry() # check if protocol setup was correctly @@ -2374,7 +2385,9 @@ async def test_handle_mqtt_timeout_on_callback( mid = 100 rc = 0 - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: def _mock_ack(topic: str, qos: int = 0) -> tuple[int, int]: # Handle ACK for subscribe normally @@ -2419,7 +2432,9 @@ async def test_setup_raises_config_entry_not_ready_if_no_connect_broker( entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}) entry.add_to_hass(hass) - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: mock_client().connect = MagicMock(side_effect=OSError("Connection error")) assert await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() @@ -2454,7 +2469,9 @@ async def test_setup_uses_certificate_on_certificate_set_to_auto_and_insecure( def mock_tls_insecure_set(insecure_param) -> None: insecure_check["insecure"] = insecure_param - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: mock_client().tls_set = mock_tls_set mock_client().tls_insecure_set = mock_tls_insecure_set await mqtt_mock_entry() @@ -4023,7 +4040,7 @@ async def test_link_config_entry( assert _check_entities() == 2 # reload entry and assert again - with patch("paho.mqtt.client.Client"): + with patch("homeassistant.components.mqtt.async_client.AsyncMQTTClient"): await hass.config_entries.async_reload(mqtt_config_entry.entry_id) await hass.async_block_till_done() diff --git a/tests/conftest.py b/tests/conftest.py index 13a8daa8ce1..a6f9c34c568 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -920,7 +920,9 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None, self.mid = mid self.rc = 0 - with patch("paho.mqtt.client.Client") as mock_client: + with patch( + "homeassistant.components.mqtt.async_client.AsyncMQTTClient" + ) as mock_client: # The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe # callbacks to simulate the behavior of the real MQTT client which will # not be synchronous.