diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 18ce89beb9b..7788c1db641 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -45,7 +45,6 @@ from homeassistant.helpers.start import async_at_started from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from homeassistant.setup import SetupPhases, async_pause_setup -from homeassistant.util.async_ import create_eager_task from homeassistant.util.collection import chunked_or_all from homeassistant.util.logging import catch_log_exception, log_exception @@ -85,7 +84,7 @@ from .models import ( PublishPayloadType, ReceiveMessage, ) -from .util import get_file_path, mqtt_config_entry_enabled +from .util import EnsureJobAfterCooldown, get_file_path, mqtt_config_entry_enabled if TYPE_CHECKING: # Only import for paho-mqtt type checking here, imports are done locally @@ -358,103 +357,6 @@ class MqttClientSetup: return self._client -class EnsureJobAfterCooldown: - """Ensure a cool down period before executing a job. - - When a new execute request arrives we cancel the current request - and start a new one. - """ - - def __init__( - self, timeout: float, callback_job: Callable[[], Coroutine[Any, None, None]] - ) -> None: - """Initialize the timer.""" - self._loop = asyncio.get_running_loop() - self._timeout = timeout - self._callback = callback_job - self._task: asyncio.Task | None = None - self._timer: asyncio.TimerHandle | None = None - self._next_execute_time = 0.0 - - def set_timeout(self, timeout: float) -> None: - """Set a new timeout period.""" - self._timeout = timeout - - async def _async_job(self) -> None: - """Execute after a cooldown period.""" - try: - await self._callback() - except HomeAssistantError as ha_error: - _LOGGER.error("%s", ha_error) - - @callback - def _async_task_done(self, task: asyncio.Task) -> None: - """Handle task done.""" - self._task = None - - @callback - def async_execute(self) -> asyncio.Task: - """Execute the job.""" - if self._task: - # Task already running, - # so we schedule another run - self.async_schedule() - return self._task - - self._async_cancel_timer() - self._task = create_eager_task(self._async_job()) - self._task.add_done_callback(self._async_task_done) - return self._task - - @callback - def _async_cancel_timer(self) -> None: - """Cancel any pending task.""" - if self._timer: - self._timer.cancel() - self._timer = None - - @callback - def async_schedule(self) -> None: - """Ensure we execute after a cooldown period.""" - # We want to reschedule the timer in the future - # every time this is called. - next_when = self._loop.time() + self._timeout - if not self._timer: - self._timer = self._loop.call_at(next_when, self._async_timer_reached) - return - - if self._timer.when() < next_when: - # Timer already running, set the next execute time - # if it fires too early, it will get rescheduled - self._next_execute_time = next_when - - @callback - def _async_timer_reached(self) -> None: - """Handle timer fire.""" - self._timer = None - if self._loop.time() >= self._next_execute_time: - self.async_execute() - return - # Timer fired too early because there were multiple - # calls async_schedule. Reschedule the timer. - self._timer = self._loop.call_at( - self._next_execute_time, self._async_timer_reached - ) - - async def async_cleanup(self) -> None: - """Cleanup any pending task.""" - self._async_cancel_timer() - if not self._task: - return - self._task.cancel() - try: - await self._task - except asyncio.CancelledError: - pass - except Exception: - _LOGGER.exception("Error cleaning up task") - - class MQTT: """Home Assistant MQTT client.""" diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py index 256bad71ba6..97fa616fdd1 100644 --- a/homeassistant/components/mqtt/util.py +++ b/homeassistant/components/mqtt/util.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable, Coroutine from functools import lru_cache import logging import os @@ -14,7 +15,8 @@ import voluptuous as vol from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import MAX_LENGTH_STATE_STATE, STATE_UNKNOWN, Platform -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers.typing import ConfigType from homeassistant.util.async_ import create_eager_task @@ -40,6 +42,108 @@ TEMP_DIR_NAME = f"home-assistant-{DOMAIN}" _VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2])) +_LOGGER = logging.getLogger(__name__) + + +class EnsureJobAfterCooldown: + """Ensure a cool down period before executing a job. + + When a new execute request arrives we cancel the current request + and start a new one. + + We allow patching this util, as we generally have exceptions + for sleeps/waits/debouncers/timers causing long run times in tests. + """ + + def __init__( + self, timeout: float, callback_job: Callable[[], Coroutine[Any, None, None]] + ) -> None: + """Initialize the timer.""" + self._loop = asyncio.get_running_loop() + self._timeout = timeout + self._callback = callback_job + self._task: asyncio.Task | None = None + self._timer: asyncio.TimerHandle | None = None + self._next_execute_time = 0.0 + + def set_timeout(self, timeout: float) -> None: + """Set a new timeout period.""" + self._timeout = timeout + + async def _async_job(self) -> None: + """Execute after a cooldown period.""" + try: + await self._callback() + except HomeAssistantError as ha_error: + _LOGGER.error("%s", ha_error) + + @callback + def _async_task_done(self, task: asyncio.Task) -> None: + """Handle task done.""" + self._task = None + + @callback + def async_execute(self) -> asyncio.Task: + """Execute the job.""" + if self._task: + # Task already running, + # so we schedule another run + self.async_schedule() + return self._task + + self._async_cancel_timer() + self._task = create_eager_task(self._async_job()) + self._task.add_done_callback(self._async_task_done) + return self._task + + @callback + def _async_cancel_timer(self) -> None: + """Cancel any pending task.""" + if self._timer: + self._timer.cancel() + self._timer = None + + @callback + def async_schedule(self) -> None: + """Ensure we execute after a cooldown period.""" + # We want to reschedule the timer in the future + # every time this is called. + next_when = self._loop.time() + self._timeout + if not self._timer: + self._timer = self._loop.call_at(next_when, self._async_timer_reached) + return + + if self._timer.when() < next_when: + # Timer already running, set the next execute time + # if it fires too early, it will get rescheduled + self._next_execute_time = next_when + + @callback + def _async_timer_reached(self) -> None: + """Handle timer fire.""" + self._timer = None + if self._loop.time() >= self._next_execute_time: + self.async_execute() + return + # Timer fired too early because there were multiple + # calls async_schedule. Reschedule the timer. + self._timer = self._loop.call_at( + self._next_execute_time, self._async_timer_reached + ) + + async def async_cleanup(self) -> None: + """Cleanup any pending task.""" + self._async_cancel_timer() + if not self._task: + return + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + except Exception: + _LOGGER.exception("Error cleaning up task") + def platforms_from_config(config: list[ConfigType]) -> set[Platform | str]: """Return the platforms to be set up.""" diff --git a/tests/components/mqtt/conftest.py b/tests/components/mqtt/conftest.py index 39b9f122f75..5a1f65667cf 100644 --- a/tests/components/mqtt/conftest.py +++ b/tests/components/mqtt/conftest.py @@ -9,7 +9,7 @@ import pytest from typing_extensions import AsyncGenerator, Generator from homeassistant.components import mqtt -from homeassistant.components.mqtt.models import ReceiveMessage +from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import HomeAssistant, callback @@ -79,3 +79,21 @@ async def setup_with_birth_msg_client_mock( await hass.async_block_till_done() await birth.wait() yield mqtt_client_mock + + +@pytest.fixture +def recorded_calls() -> list[ReceiveMessage]: + """Fixture to hold recorded calls.""" + return [] + + +@pytest.fixture +def record_calls(recorded_calls: list[ReceiveMessage]) -> MessageCallbackType: + """Fixture to record calls.""" + + @callback + def record_calls(msg: ReceiveMessage) -> None: + """Record calls.""" + recorded_calls.append(msg) + + return record_calls diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 8a76c71f1f3..2c3ca31bff9 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -24,7 +24,6 @@ from homeassistant.components.mqtt import debug_info from homeassistant.components.mqtt.client import ( _LOGGER as CLIENT_LOGGER, RECONNECT_INTERVAL_SECONDS, - EnsureJobAfterCooldown, ) from homeassistant.components.mqtt.models import ( MessageCallbackType, @@ -101,24 +100,6 @@ def mock_storage(hass_storage: dict[str, Any]) -> None: """Autouse hass_storage for the TestCase tests.""" -@pytest.fixture -def recorded_calls() -> list[ReceiveMessage]: - """Fixture to hold recorded calls.""" - return [] - - -@pytest.fixture -def record_calls(recorded_calls: list[ReceiveMessage]) -> MessageCallbackType: - """Fixture to record calls.""" - - @callback - def record_calls(msg: ReceiveMessage) -> None: - """Record calls.""" - recorded_calls.append(msg) - - return record_calls - - @pytest.fixture def client_debug_log() -> Generator[None]: """Set the mqtt client log level to DEBUG.""" @@ -1070,6 +1051,7 @@ async def test_subscribe_topic( async def test_subscribe_topic_not_initialize( hass: HomeAssistant, + record_calls: MessageCallbackType, mqtt_mock_entry: MqttMockHAClientGenerator, ) -> None: """Test the subscription of a topic when MQTT was not initialized.""" @@ -1080,7 +1062,7 @@ async def test_subscribe_topic_not_initialize( async def test_subscribe_mqtt_config_entry_disabled( - hass: HomeAssistant, mqtt_mock: MqttMockHAClient + hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType ) -> None: """Test the subscription of a topic when MQTT config entry is disabled.""" mqtt_mock.connected = True @@ -2016,84 +1998,6 @@ async def test_reload_entry_with_restored_subscriptions( assert recorded_calls[1].payload == "wild-card-payload3" -async def test_canceling_debouncer_on_shutdown( - hass: HomeAssistant, - record_calls: MessageCallbackType, - setup_with_birth_msg_client_mock: MqttMockPahoClient, -) -> None: - """Test canceling the debouncer when HA shuts down.""" - mqtt_client_mock = setup_with_birth_msg_client_mock - - with patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 2): - await hass.async_block_till_done() - async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) - await hass.async_block_till_done() - await mqtt.async_subscribe(hass, "test/state1", record_calls) - async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) - # Stop HA so the scheduled debouncer task will be canceled - mqtt_client_mock.subscribe.reset_mock() - hass.bus.fire(EVENT_HOMEASSISTANT_STOP) - await mqtt.async_subscribe(hass, "test/state2", record_calls) - async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) - await mqtt.async_subscribe(hass, "test/state3", record_calls) - async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) - await mqtt.async_subscribe(hass, "test/state4", record_calls) - async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) - await mqtt.async_subscribe(hass, "test/state5", record_calls) - async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) - await hass.async_block_till_done() - - mqtt_client_mock.subscribe.assert_not_called() - - # Note thet the broker connection will not be disconnected gracefully - await hass.async_block_till_done() - async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) - await asyncio.sleep(0) - await hass.async_block_till_done(wait_background_tasks=True) - mqtt_client_mock.subscribe.assert_not_called() - mqtt_client_mock.disconnect.assert_not_called() - - -async def test_canceling_debouncer_normal( - hass: HomeAssistant, - caplog: pytest.LogCaptureFixture, -) -> None: - """Test canceling the debouncer before completion.""" - - async def _async_myjob() -> None: - await asyncio.sleep(1.0) - - debouncer = EnsureJobAfterCooldown(0.0, _async_myjob) - debouncer.async_schedule() - await asyncio.sleep(0.01) - assert debouncer._task is not None - await debouncer.async_cleanup() - assert debouncer._task is None - - -async def test_canceling_debouncer_throws( - hass: HomeAssistant, - caplog: pytest.LogCaptureFixture, -) -> None: - """Test canceling the debouncer when HA shuts down.""" - - async def _async_myjob() -> None: - await asyncio.sleep(1.0) - - debouncer = EnsureJobAfterCooldown(0.0, _async_myjob) - debouncer.async_schedule() - await asyncio.sleep(0.01) - assert debouncer._task is not None - # let debouncer._task fail by mocking it - with patch.object(debouncer, "_task") as task: - task.cancel = MagicMock(return_value=True) - await debouncer.async_cleanup() - assert "Error cleaning up task" in caplog.text - await hass.async_block_till_done() - async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) - await hass.async_block_till_done() - - async def test_initial_setup_logs_error( hass: HomeAssistant, caplog: pytest.LogCaptureFixture, diff --git a/tests/components/mqtt/test_util.py b/tests/components/mqtt/test_util.py index 290f561e1ad..955fc88448c 100644 --- a/tests/components/mqtt/test_util.py +++ b/tests/components/mqtt/test_util.py @@ -1,22 +1,106 @@ """Test MQTT utils.""" +import asyncio from collections.abc import Callable +from datetime import timedelta from pathlib import Path from random import getrandbits import shutil import tempfile -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from homeassistant.components import mqtt +from homeassistant.components.mqtt.models import MessageCallbackType +from homeassistant.components.mqtt.util import EnsureJobAfterCooldown from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState +from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import CoreState, HomeAssistant +from homeassistant.util.dt import utcnow -from tests.common import MockConfigEntry +from tests.common import MockConfigEntry, async_fire_time_changed from tests.typing import MqttMockHAClient, MqttMockPahoClient +async def test_canceling_debouncer_on_shutdown( + hass: HomeAssistant, + record_calls: MessageCallbackType, + setup_with_birth_msg_client_mock: MqttMockPahoClient, +) -> None: + """Test canceling the debouncer when HA shuts down.""" + mqtt_client_mock = setup_with_birth_msg_client_mock + + with patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 2): + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() + await mqtt.async_subscribe(hass, "test/state1", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) + # Stop HA so the scheduled debouncer task will be canceled + mqtt_client_mock.subscribe.reset_mock() + hass.bus.fire(EVENT_HOMEASSISTANT_STOP) + await mqtt.async_subscribe(hass, "test/state2", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) + await mqtt.async_subscribe(hass, "test/state3", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) + await mqtt.async_subscribe(hass, "test/state4", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) + await mqtt.async_subscribe(hass, "test/state5", record_calls) + async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1)) + await hass.async_block_till_done() + + mqtt_client_mock.subscribe.assert_not_called() + + # Note thet the broker connection will not be disconnected gracefully + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) + await asyncio.sleep(0) + await hass.async_block_till_done(wait_background_tasks=True) + mqtt_client_mock.subscribe.assert_not_called() + mqtt_client_mock.disconnect.assert_not_called() + + +async def test_canceling_debouncer_normal( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test canceling the debouncer before completion.""" + + async def _async_myjob() -> None: + await asyncio.sleep(1.0) + + debouncer = EnsureJobAfterCooldown(0.0, _async_myjob) + debouncer.async_schedule() + await asyncio.sleep(0.01) + assert debouncer._task is not None + await debouncer.async_cleanup() + assert debouncer._task is None + + +async def test_canceling_debouncer_throws( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test canceling the debouncer when HA shuts down.""" + + async def _async_myjob() -> None: + await asyncio.sleep(1.0) + + debouncer = EnsureJobAfterCooldown(0.0, _async_myjob) + debouncer.async_schedule() + await asyncio.sleep(0.01) + assert debouncer._task is not None + # let debouncer._task fail by mocking it + with patch.object(debouncer, "_task") as task: + task.cancel = MagicMock(return_value=True) + await debouncer.async_cleanup() + assert "Error cleaning up task" in caplog.text + await hass.async_block_till_done() + async_fire_time_changed(hass, utcnow() + timedelta(seconds=5)) + await hass.async_block_till_done() + + async def help_create_test_certificate_file( hass: HomeAssistant, mock_temp_dir: str,