diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py
index 8a907a8d9fa..333d955ab63 100644
--- a/homeassistant/components/recorder/__init__.py
+++ b/homeassistant/components/recorder/__init__.py
@@ -1,6 +1,7 @@
 """Support for recording details."""
 from __future__ import annotations
 
+import abc
 import asyncio
 from collections.abc import Callable, Iterable
 import concurrent.futures
@@ -11,7 +12,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any, NamedTuple
+from typing import Any
 
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
 from sqlalchemy.exc import SQLAlchemyError
@@ -327,66 +328,161 @@ def _async_register_services(hass, instance):
     )
 
 
-class ClearStatisticsTask(NamedTuple):
+class RecorderTask(abc.ABC):
+    """ABC for recorder tasks."""
+
+    @abc.abstractmethod
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+
+
+@dataclass
+class ClearStatisticsTask(RecorderTask):
     """Object to store statistics_ids which for which to remove statistics."""
 
     statistic_ids: list[str]
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        statistics.clear_statistics(instance, self.statistic_ids)
 
-class UpdateStatisticsMetadataTask(NamedTuple):
+
+@dataclass
+class UpdateStatisticsMetadataTask(RecorderTask):
     """Object to store statistics_id and unit for update of statistics metadata."""
 
     statistic_id: str
     unit_of_measurement: str | None
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        statistics.update_statistics_metadata(
+            instance, self.statistic_id, self.unit_of_measurement
+        )
 
-class PurgeTask(NamedTuple):
+
+@dataclass
+class PurgeTask(RecorderTask):
     """Object to store information about purge task."""
 
     purge_before: datetime
     repack: bool
     apply_filter: bool
 
+    def run(self, instance: Recorder) -> None:
+        """Purge the database."""
+        if purge.purge_old_data(
+            instance, self.purge_before, self.repack, self.apply_filter
+        ):
+            # We always need to do the db cleanups after a purge
+            # is finished to ensure the WAL checkpoint and other
+            # tasks happen after a vacuum.
+            perodic_db_cleanups(instance)
+            return
+        # Schedule a new purge task if this one didn't finish
+        instance.queue.put(PurgeTask(self.purge_before, self.repack, self.apply_filter))
 
-class PurgeEntitiesTask(NamedTuple):
+
+@dataclass
+class PurgeEntitiesTask(RecorderTask):
     """Object to store entity information about purge task."""
 
     entity_filter: Callable[[str], bool]
 
+    def run(self, instance: Recorder) -> None:
+        """Purge entities from the database."""
+        if purge.purge_entity_data(instance, self.entity_filter):
+            return
+        # Schedule a new purge task if this one didn't finish
+        instance.queue.put(PurgeEntitiesTask(self.entity_filter))
 
-class PerodicCleanupTask:
+
+@dataclass
+class PerodicCleanupTask(RecorderTask):
     """An object to insert into the recorder to trigger cleanup tasks when auto purge is disabled."""
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        perodic_db_cleanups(instance)
 
-class StatisticsTask(NamedTuple):
+
+@dataclass
+class StatisticsTask(RecorderTask):
     """An object to insert into the recorder queue to run a statistics task."""
 
     start: datetime
 
+    def run(self, instance: Recorder) -> None:
+        """Run statistics task."""
+        if statistics.compile_statistics(instance, self.start):
+            return
+        # Schedule a new statistics task if this one didn't finish
+        instance.queue.put(StatisticsTask(self.start))
 
-class ExternalStatisticsTask(NamedTuple):
+
+@dataclass
+class ExternalStatisticsTask(RecorderTask):
     """An object to insert into the recorder queue to run an external statistics task."""
 
     metadata: dict
     statistics: Iterable[dict]
 
-
-class WaitTask:
-    """An object to insert into the recorder queue to tell it set the _queue_watch event."""
+    def run(self, instance: Recorder) -> None:
+        """Run statistics task."""
+        if statistics.add_external_statistics(instance, self.metadata, self.statistics):
+            return
+        # Schedule a new statistics task if this one didn't finish
+        instance.queue.put(ExternalStatisticsTask(self.metadata, self.statistics))
 
 
 @dataclass
-class DatabaseLockTask:
+class WaitTask(RecorderTask):
+    """An object to insert into the recorder queue to tell it to set the _queue_watch event."""
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance._queue_watch.set()  # pylint: disable=[protected-access]
+
+
+@dataclass
+class DatabaseLockTask(RecorderTask):
     """An object to insert into the recorder queue to prevent writes to the database."""
 
     database_locked: asyncio.Event
     database_unlock: threading.Event
     queue_overflow: bool
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance._lock_database(self)  # pylint: disable=[protected-access]
+
+
+@dataclass
+class StopTask(RecorderTask):
+    """An object to insert into the recorder queue to stop the event handler."""
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance.stop_requested = True
+
+
+@dataclass
+class EventTask(RecorderTask):
+    """An object to insert into the recorder queue to process an event."""
+
+    event: bool
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        # pylint: disable-next=[protected-access]
+        instance._process_one_event(self.event)
+
 
 class Recorder(threading.Thread):
     """A threaded recorder class."""
 
+    stop_requested: bool
+
     def __init__(
         self,
         hass: HomeAssistant,
@@ -406,7 +502,7 @@ class Recorder(threading.Thread):
         self.auto_purge = auto_purge
         self.keep_days = keep_days
         self.commit_interval = commit_interval
-        self.queue: Any = queue.SimpleQueue()
+        self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue()
         self.recording_start = dt_util.utcnow()
         self.db_url = uri
         self.db_max_retries = db_max_retries
@@ -537,7 +633,7 @@ class Recorder(threading.Thread):
                 self.queue.get_nowait()
             except queue.Empty:
                 break
-        self.queue.put(None)
+        self.queue.put(StopTask())
 
         self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_FINAL_WRITE, _empty_queue)
 
@@ -545,7 +641,7 @@ class Recorder(threading.Thread):
         """Shut down the Recorder."""
         if not hass_started.done():
             hass_started.set_result(shutdown_task)
-        self.queue.put(None)
+        self.queue.put(StopTask())
         self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
         self.join()
 
@@ -691,31 +787,28 @@ class Recorder(threading.Thread):
         # Use a session for the event read loop
         # with a commit every time the event time
        # has changed. This reduces the disk io.
-        while event := self.queue.get():
+        self.stop_requested = False
+        while not self.stop_requested:
+            task = self.queue.get()
             try:
-                self._process_one_event_or_recover(event)
+                self._process_one_task_or_recover(task)
             except Exception as err:  # pylint: disable=broad-except
-                _LOGGER.exception("Error while processing event %s: %s", event, err)
+                _LOGGER.exception("Error while processing event %s: %s", task, err)
 
         self._shutdown()
 
-    def _process_one_event_or_recover(self, event):
+    def _process_one_task_or_recover(self, task: RecorderTask):
        """Process an event, reconnect, or recover a malformed database."""
         try:
-            if self._process_one_task(event):
-                return
-            self._process_one_event(event)
-            return
+            return task.run(self)
         except exc.DatabaseError as err:
             if self._handle_database_error(err):
                 return
             _LOGGER.exception(
-                "Unhandled database error while processing event %s: %s", event, err
+                "Unhandled database error while processing task %s: %s", task, err
             )
         except SQLAlchemyError as err:
-            _LOGGER.exception(
-                "SQLAlchemyError error processing event %s: %s", event, err
-            )
+            _LOGGER.exception("SQLAlchemyError error processing task %s: %s", task, err)
 
         # Reset the session if an SQLAlchemyError (including DatabaseError)
         # happens to rollback and recover
@@ -773,38 +866,6 @@ class Recorder(threading.Thread):
         self.migration_in_progress = False
         persistent_notification.dismiss(self.hass, "recorder_database_migration")
 
-    def _run_purge(self, purge_before, repack, apply_filter):
-        """Purge the database."""
-        if purge.purge_old_data(self, purge_before, repack, apply_filter):
-            # We always need to do the db cleanups after a purge
-            # is finished to ensure the WAL checkpoint and other
-            # tasks happen after a vacuum.
-            perodic_db_cleanups(self)
-            return
-        # Schedule a new purge task if this one didn't finish
-        self.queue.put(PurgeTask(purge_before, repack, apply_filter))
-
-    def _run_purge_entities(self, entity_filter):
-        """Purge entities from the database."""
-        if purge.purge_entity_data(self, entity_filter):
-            return
-        # Schedule a new purge task if this one didn't finish
-        self.queue.put(PurgeEntitiesTask(entity_filter))
-
-    def _run_statistics(self, start):
-        """Run statistics task."""
-        if statistics.compile_statistics(self, start):
-            return
-        # Schedule a new statistics task if this one didn't finish
-        self.queue.put(StatisticsTask(start))
-
-    def _run_external_statistics(self, metadata, stats):
-        """Run statistics task."""
-        if statistics.add_external_statistics(self, metadata, stats):
-            return
-        # Schedule a new statistics task if this one didn't finish
-        self.queue.put(ExternalStatisticsTask(metadata, stats))
-
     def _lock_database(self, task: DatabaseLockTask):
         @callback
         def _async_set_database_locked(task: DatabaseLockTask):
@@ -828,39 +889,6 @@ class Recorder(threading.Thread):
             self.queue.qsize(),
         )
 
-    def _process_one_task(self, event) -> bool:
-        """Process one event."""
-        if isinstance(event, PurgeTask):
-            self._run_purge(event.purge_before, event.repack, event.apply_filter)
-            return True
-        if isinstance(event, PurgeEntitiesTask):
-            self._run_purge_entities(event.entity_filter)
-            return True
-        if isinstance(event, PerodicCleanupTask):
-            perodic_db_cleanups(self)
-            return True
-        if isinstance(event, StatisticsTask):
-            self._run_statistics(event.start)
-            return True
-        if isinstance(event, ClearStatisticsTask):
-            statistics.clear_statistics(self, event.statistic_ids)
-            return True
-        if isinstance(event, UpdateStatisticsMetadataTask):
-            statistics.update_statistics_metadata(
-                self, event.statistic_id, event.unit_of_measurement
-            )
-            return True
-        if isinstance(event, ExternalStatisticsTask):
-            self._run_external_statistics(event.metadata, event.statistics)
-            return True
-        if isinstance(event, WaitTask):
-            self._queue_watch.set()
-            return True
-        if isinstance(event, DatabaseLockTask):
-            self._lock_database(event)
-            return True
-        return False
-
     def _process_one_event(self, event):
         if event.event_type == EVENT_TIME_CHANGED:
             self._keepalive_count += 1
@@ -1010,7 +1038,7 @@ class Recorder(threading.Thread):
     @callback
     def event_listener(self, event):
         """Listen for new events and put them in the process queue."""
-        self.queue.put(event)
+        self.queue.put(EventTask(event))
 
     def block_till_done(self):
         """Block till all events processed.
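
Aside for reviewers: the refactor above replaces the recorder's long isinstance() dispatch chain (`_process_one_task`) with polymorphic tasks. Every queue item is a dataclass carrying its own `run()` method, and a `StopTask` sentinel replaces the bare `None` that previously terminated the loop. Below is a minimal, self-contained sketch of that pattern; the `Recorder` class and its `do_purge` helper are simplified stand-ins invented for illustration, not the component's real API.

```python
from __future__ import annotations

import abc
import queue
from dataclasses import dataclass
from datetime import datetime


class RecorderTask(abc.ABC):
    """ABC for recorder tasks: each queue item knows how to run itself."""

    @abc.abstractmethod
    def run(self, instance: Recorder) -> None:
        """Handle the task."""


@dataclass
class PurgeTask(RecorderTask):
    """Purge the database, retrying until the purge finishes."""

    purge_before: datetime
    repack: bool
    apply_filter: bool

    def run(self, instance: Recorder) -> None:
        # If one pass could not delete everything, re-enqueue the task
        # instead of blocking the worker thread until the purge is done.
        if instance.do_purge(self.purge_before, self.repack, self.apply_filter):
            return
        instance.queue.put(
            PurgeTask(self.purge_before, self.repack, self.apply_filter)
        )


@dataclass
class StopTask(RecorderTask):
    """Sentinel task replacing the old `None`: stops the worker loop."""

    def run(self, instance: Recorder) -> None:
        instance.stop_requested = True


class Recorder:
    """Simplified stand-in for the threaded recorder."""

    def __init__(self) -> None:
        self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue()
        self.stop_requested = False

    def do_purge(self, purge_before: datetime, repack: bool, apply_filter: bool) -> bool:
        """Placeholder for purge.purge_old_data; pretend one pass suffices."""
        return True

    def run(self) -> None:
        # The dispatch loop needs no isinstance() chain: every task,
        # including the stop sentinel, goes through the same run() hook.
        while not self.stop_requested:
            task = self.queue.get()
            task.run(self)
```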
diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py
index 7d7c3f27fb6..64560e4d33b 100644
--- a/tests/components/recorder/test_init.py
+++ b/tests/components/recorder/test_init.py
@@ -261,7 +261,7 @@ def test_saving_state_with_sqlalchemy_exception(hass, hass_recorder, caplog):
         hass.states.set(entity_id, "fail", attributes)
         wait_recording_done(hass)
 
-    assert "SQLAlchemyError error processing event" in caplog.text
+    assert "SQLAlchemyError error processing task" in caplog.text
 
     caplog.clear()
     hass.states.set(entity_id, state, attributes)
@@ -273,7 +273,7 @@ def test_saving_state_with_sqlalchemy_exception(hass, hass_recorder, caplog):
     assert "Error executing query" not in caplog.text
     assert "Error saving events" not in caplog.text
 
-    assert "SQLAlchemyError error processing event" not in caplog.text
+    assert "SQLAlchemyError error processing task" not in caplog.text
 
 
 async def test_force_shutdown_with_queue_of_writes_that_generate_exceptions(
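
Continuing the illustrative sketch between the two file diffs (still using the hypothetical `Recorder` and `do_purge` stand-ins), a short usage example shows the lifecycle the diff encodes: a task that cannot finish in one pass re-enqueues itself, and shutdown is just another task rather than a `None` sentinel.

```python
from datetime import datetime, timedelta, timezone

rec = Recorder()
rec.queue.put(
    PurgeTask(
        purge_before=datetime.now(timezone.utc) - timedelta(days=10),
        repack=False,
        apply_filter=False,
    )
)
rec.queue.put(StopTask())  # sentinel instead of the old queue.put(None)
rec.run()  # drains the queue: runs the purge, then stops cleanly
```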