Move mqtt debouncer to mqtt utils (#120392)

This commit is contained in:
Jan Bouwhuis 2024-06-25 10:33:58 +02:00 committed by GitHub
parent 46ed76df31
commit 1d16cbec96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 213 additions and 201 deletions

View File

@ -45,7 +45,6 @@ from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.setup import SetupPhases, async_pause_setup
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.collection import chunked_or_all
from homeassistant.util.logging import catch_log_exception, log_exception
@ -85,7 +84,7 @@ from .models import (
PublishPayloadType,
ReceiveMessage,
)
from .util import get_file_path, mqtt_config_entry_enabled
from .util import EnsureJobAfterCooldown, get_file_path, mqtt_config_entry_enabled
if TYPE_CHECKING:
# Only import for paho-mqtt type checking here, imports are done locally
@ -358,103 +357,6 @@ class MqttClientSetup:
return self._client
class EnsureJobAfterCooldown:
"""Ensure a cool down period before executing a job.
When a new execute request arrives we cancel the current request
and start a new one.
"""
def __init__(
self, timeout: float, callback_job: Callable[[], Coroutine[Any, None, None]]
) -> None:
"""Initialize the timer."""
self._loop = asyncio.get_running_loop()
self._timeout = timeout
self._callback = callback_job
self._task: asyncio.Task | None = None
self._timer: asyncio.TimerHandle | None = None
self._next_execute_time = 0.0
def set_timeout(self, timeout: float) -> None:
"""Set a new timeout period."""
self._timeout = timeout
async def _async_job(self) -> None:
"""Execute after a cooldown period."""
try:
await self._callback()
except HomeAssistantError as ha_error:
_LOGGER.error("%s", ha_error)
@callback
def _async_task_done(self, task: asyncio.Task) -> None:
"""Handle task done."""
self._task = None
@callback
def async_execute(self) -> asyncio.Task:
"""Execute the job."""
if self._task:
# Task already running,
# so we schedule another run
self.async_schedule()
return self._task
self._async_cancel_timer()
self._task = create_eager_task(self._async_job())
self._task.add_done_callback(self._async_task_done)
return self._task
@callback
def _async_cancel_timer(self) -> None:
"""Cancel any pending task."""
if self._timer:
self._timer.cancel()
self._timer = None
@callback
def async_schedule(self) -> None:
"""Ensure we execute after a cooldown period."""
# We want to reschedule the timer in the future
# every time this is called.
next_when = self._loop.time() + self._timeout
if not self._timer:
self._timer = self._loop.call_at(next_when, self._async_timer_reached)
return
if self._timer.when() < next_when:
# Timer already running, set the next execute time
# if it fires too early, it will get rescheduled
self._next_execute_time = next_when
@callback
def _async_timer_reached(self) -> None:
"""Handle timer fire."""
self._timer = None
if self._loop.time() >= self._next_execute_time:
self.async_execute()
return
# Timer fired too early because there were multiple
# calls async_schedule. Reschedule the timer.
self._timer = self._loop.call_at(
self._next_execute_time, self._async_timer_reached
)
async def async_cleanup(self) -> None:
"""Cleanup any pending task."""
self._async_cancel_timer()
if not self._task:
return
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
except Exception:
_LOGGER.exception("Error cleaning up task")
class MQTT:
"""Home Assistant MQTT client."""

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from functools import lru_cache
import logging
import os
@ -14,7 +15,8 @@ import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import MAX_LENGTH_STATE_STATE, STATE_UNKNOWN, Platform
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.async_ import create_eager_task
@ -40,6 +42,108 @@ TEMP_DIR_NAME = f"home-assistant-{DOMAIN}"
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
_LOGGER = logging.getLogger(__name__)
class EnsureJobAfterCooldown:
"""Ensure a cool down period before executing a job.
When a new execute request arrives we cancel the current request
and start a new one.
We allow patching this util, as we generally have exceptions
for sleeps/waits/debouncers/timers causing long run times in tests.
"""
def __init__(
self, timeout: float, callback_job: Callable[[], Coroutine[Any, None, None]]
) -> None:
"""Initialize the timer."""
self._loop = asyncio.get_running_loop()
self._timeout = timeout
self._callback = callback_job
self._task: asyncio.Task | None = None
self._timer: asyncio.TimerHandle | None = None
self._next_execute_time = 0.0
def set_timeout(self, timeout: float) -> None:
"""Set a new timeout period."""
self._timeout = timeout
async def _async_job(self) -> None:
"""Execute after a cooldown period."""
try:
await self._callback()
except HomeAssistantError as ha_error:
_LOGGER.error("%s", ha_error)
@callback
def _async_task_done(self, task: asyncio.Task) -> None:
"""Handle task done."""
self._task = None
@callback
def async_execute(self) -> asyncio.Task:
"""Execute the job."""
if self._task:
# Task already running,
# so we schedule another run
self.async_schedule()
return self._task
self._async_cancel_timer()
self._task = create_eager_task(self._async_job())
self._task.add_done_callback(self._async_task_done)
return self._task
@callback
def _async_cancel_timer(self) -> None:
"""Cancel any pending task."""
if self._timer:
self._timer.cancel()
self._timer = None
@callback
def async_schedule(self) -> None:
"""Ensure we execute after a cooldown period."""
# We want to reschedule the timer in the future
# every time this is called.
next_when = self._loop.time() + self._timeout
if not self._timer:
self._timer = self._loop.call_at(next_when, self._async_timer_reached)
return
if self._timer.when() < next_when:
# Timer already running, set the next execute time
# if it fires too early, it will get rescheduled
self._next_execute_time = next_when
@callback
def _async_timer_reached(self) -> None:
"""Handle timer fire."""
self._timer = None
if self._loop.time() >= self._next_execute_time:
self.async_execute()
return
# Timer fired too early because there were multiple
# calls async_schedule. Reschedule the timer.
self._timer = self._loop.call_at(
self._next_execute_time, self._async_timer_reached
)
async def async_cleanup(self) -> None:
"""Cleanup any pending task."""
self._async_cancel_timer()
if not self._task:
return
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
except Exception:
_LOGGER.exception("Error cleaning up task")
def platforms_from_config(config: list[ConfigType]) -> set[Platform | str]:
"""Return the platforms to be set up."""

View File

@ -9,7 +9,7 @@ import pytest
from typing_extensions import AsyncGenerator, Generator
from homeassistant.components import mqtt
from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import HomeAssistant, callback
@ -79,3 +79,21 @@ async def setup_with_birth_msg_client_mock(
await hass.async_block_till_done()
await birth.wait()
yield mqtt_client_mock
@pytest.fixture
def recorded_calls() -> list[ReceiveMessage]:
"""Fixture to hold recorded calls."""
return []
@pytest.fixture
def record_calls(recorded_calls: list[ReceiveMessage]) -> MessageCallbackType:
"""Fixture to record calls."""
@callback
def record_calls(msg: ReceiveMessage) -> None:
"""Record calls."""
recorded_calls.append(msg)
return record_calls

View File

@ -24,7 +24,6 @@ from homeassistant.components.mqtt import debug_info
from homeassistant.components.mqtt.client import (
_LOGGER as CLIENT_LOGGER,
RECONNECT_INTERVAL_SECONDS,
EnsureJobAfterCooldown,
)
from homeassistant.components.mqtt.models import (
MessageCallbackType,
@ -101,24 +100,6 @@ def mock_storage(hass_storage: dict[str, Any]) -> None:
"""Autouse hass_storage for the TestCase tests."""
@pytest.fixture
def recorded_calls() -> list[ReceiveMessage]:
"""Fixture to hold recorded calls."""
return []
@pytest.fixture
def record_calls(recorded_calls: list[ReceiveMessage]) -> MessageCallbackType:
"""Fixture to record calls."""
@callback
def record_calls(msg: ReceiveMessage) -> None:
"""Record calls."""
recorded_calls.append(msg)
return record_calls
@pytest.fixture
def client_debug_log() -> Generator[None]:
"""Set the mqtt client log level to DEBUG."""
@ -1070,6 +1051,7 @@ async def test_subscribe_topic(
async def test_subscribe_topic_not_initialize(
hass: HomeAssistant,
record_calls: MessageCallbackType,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test the subscription of a topic when MQTT was not initialized."""
@ -1080,7 +1062,7 @@ async def test_subscribe_topic_not_initialize(
async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType
) -> None:
"""Test the subscription of a topic when MQTT config entry is disabled."""
mqtt_mock.connected = True
@ -2016,84 +1998,6 @@ async def test_reload_entry_with_restored_subscriptions(
assert recorded_calls[1].payload == "wild-card-payload3"
async def test_canceling_debouncer_on_shutdown(
hass: HomeAssistant,
record_calls: MessageCallbackType,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
) -> None:
"""Test canceling the debouncer when HA shuts down."""
mqtt_client_mock = setup_with_birth_msg_client_mock
with patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 2):
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state1", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
# Stop HA so the scheduled debouncer task will be canceled
mqtt_client_mock.subscribe.reset_mock()
hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
await mqtt.async_subscribe(hass, "test/state2", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await mqtt.async_subscribe(hass, "test/state3", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await mqtt.async_subscribe(hass, "test/state4", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await mqtt.async_subscribe(hass, "test/state5", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await hass.async_block_till_done()
mqtt_client_mock.subscribe.assert_not_called()
# Note thet the broker connection will not be disconnected gracefully
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await asyncio.sleep(0)
await hass.async_block_till_done(wait_background_tasks=True)
mqtt_client_mock.subscribe.assert_not_called()
mqtt_client_mock.disconnect.assert_not_called()
async def test_canceling_debouncer_normal(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test canceling the debouncer before completion."""
async def _async_myjob() -> None:
await asyncio.sleep(1.0)
debouncer = EnsureJobAfterCooldown(0.0, _async_myjob)
debouncer.async_schedule()
await asyncio.sleep(0.01)
assert debouncer._task is not None
await debouncer.async_cleanup()
assert debouncer._task is None
async def test_canceling_debouncer_throws(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test canceling the debouncer when HA shuts down."""
async def _async_myjob() -> None:
await asyncio.sleep(1.0)
debouncer = EnsureJobAfterCooldown(0.0, _async_myjob)
debouncer.async_schedule()
await asyncio.sleep(0.01)
assert debouncer._task is not None
# let debouncer._task fail by mocking it
with patch.object(debouncer, "_task") as task:
task.cancel = MagicMock(return_value=True)
await debouncer.async_cleanup()
assert "Error cleaning up task" in caplog.text
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
async def test_initial_setup_logs_error(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,

View File

@ -1,22 +1,106 @@
"""Test MQTT utils."""
import asyncio
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path
from random import getrandbits
import shutil
import tempfile
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.components import mqtt
from homeassistant.components.mqtt.models import MessageCallbackType
from homeassistant.components.mqtt.util import EnsureJobAfterCooldown
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, HomeAssistant
from homeassistant.util.dt import utcnow
from tests.common import MockConfigEntry
from tests.common import MockConfigEntry, async_fire_time_changed
from tests.typing import MqttMockHAClient, MqttMockPahoClient
async def test_canceling_debouncer_on_shutdown(
hass: HomeAssistant,
record_calls: MessageCallbackType,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
) -> None:
"""Test canceling the debouncer when HA shuts down."""
mqtt_client_mock = setup_with_birth_msg_client_mock
with patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 2):
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
await mqtt.async_subscribe(hass, "test/state1", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
# Stop HA so the scheduled debouncer task will be canceled
mqtt_client_mock.subscribe.reset_mock()
hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
await mqtt.async_subscribe(hass, "test/state2", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await mqtt.async_subscribe(hass, "test/state3", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await mqtt.async_subscribe(hass, "test/state4", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await mqtt.async_subscribe(hass, "test/state5", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=0.1))
await hass.async_block_till_done()
mqtt_client_mock.subscribe.assert_not_called()
# Note thet the broker connection will not be disconnected gracefully
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await asyncio.sleep(0)
await hass.async_block_till_done(wait_background_tasks=True)
mqtt_client_mock.subscribe.assert_not_called()
mqtt_client_mock.disconnect.assert_not_called()
async def test_canceling_debouncer_normal(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test canceling the debouncer before completion."""
async def _async_myjob() -> None:
await asyncio.sleep(1.0)
debouncer = EnsureJobAfterCooldown(0.0, _async_myjob)
debouncer.async_schedule()
await asyncio.sleep(0.01)
assert debouncer._task is not None
await debouncer.async_cleanup()
assert debouncer._task is None
async def test_canceling_debouncer_throws(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test canceling the debouncer when HA shuts down."""
async def _async_myjob() -> None:
await asyncio.sleep(1.0)
debouncer = EnsureJobAfterCooldown(0.0, _async_myjob)
debouncer.async_schedule()
await asyncio.sleep(0.01)
assert debouncer._task is not None
# let debouncer._task fail by mocking it
with patch.object(debouncer, "_task") as task:
task.cancel = MagicMock(return_value=True)
await debouncer.async_cleanup()
assert "Error cleaning up task" in caplog.text
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
async def help_create_test_certificate_file(
hass: HomeAssistant,
mock_temp_dir: str,