Refactor recorder queue handling (#61161)

* Refactor recorder queue handling

* Address pylint's concerns

* Implement workaround for mypy bug

* Address review comments
Erik Montnemery, 2021-12-08 16:54:26 +01:00, committed by GitHub
parent bbe6d3c9ae
commit f30eb05870
2 changed files with 123 additions and 95 deletions


@@ -1,6 +1,7 @@
 """Support for recording details."""
 from __future__ import annotations
 
+import abc
 import asyncio
 from collections.abc import Callable, Iterable
 import concurrent.futures
@@ -11,7 +12,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any, NamedTuple
+from typing import Any
 
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
 from sqlalchemy.exc import SQLAlchemyError
@@ -327,66 +328,161 @@ def _async_register_services(hass, instance):
     )
 
 
-class ClearStatisticsTask(NamedTuple):
+class RecorderTask(abc.ABC):
+    """ABC for recorder tasks."""
+
+    @abc.abstractmethod
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+
+
+@dataclass
+class ClearStatisticsTask(RecorderTask):
     """Object to store statistic_ids for which to remove statistics."""
 
     statistic_ids: list[str]
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        statistics.clear_statistics(instance, self.statistic_ids)
+
 
-class UpdateStatisticsMetadataTask(NamedTuple):
+@dataclass
+class UpdateStatisticsMetadataTask(RecorderTask):
     """Object to store statistic_id and unit for updating statistics metadata."""
 
     statistic_id: str
     unit_of_measurement: str | None
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        statistics.update_statistics_metadata(
+            instance, self.statistic_id, self.unit_of_measurement
+        )
+
 
-class PurgeTask(NamedTuple):
+@dataclass
+class PurgeTask(RecorderTask):
     """Object to store information about purge task."""
 
     purge_before: datetime
     repack: bool
     apply_filter: bool
 
+    def run(self, instance: Recorder) -> None:
+        """Purge the database."""
+        if purge.purge_old_data(
+            instance, self.purge_before, self.repack, self.apply_filter
+        ):
+            # We always need to do the db cleanups after a purge
+            # is finished to ensure the WAL checkpoint and other
+            # tasks happen after a vacuum.
+            perodic_db_cleanups(instance)
+            return
+        # Schedule a new purge task if this one didn't finish
+        instance.queue.put(PurgeTask(self.purge_before, self.repack, self.apply_filter))
+
 
-class PurgeEntitiesTask(NamedTuple):
+@dataclass
+class PurgeEntitiesTask(RecorderTask):
     """Object to store entity information about purge task."""
 
     entity_filter: Callable[[str], bool]
 
+    def run(self, instance: Recorder) -> None:
+        """Purge entities from the database."""
+        if purge.purge_entity_data(instance, self.entity_filter):
+            return
+        # Schedule a new purge task if this one didn't finish
+        instance.queue.put(PurgeEntitiesTask(self.entity_filter))
+
 
-class PerodicCleanupTask:
+@dataclass
+class PerodicCleanupTask(RecorderTask):
     """An object to insert into the recorder to trigger cleanup tasks when auto purge is disabled."""
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        perodic_db_cleanups(instance)
+
 
-class StatisticsTask(NamedTuple):
+@dataclass
+class StatisticsTask(RecorderTask):
     """An object to insert into the recorder queue to run a statistics task."""
 
     start: datetime
 
+    def run(self, instance: Recorder) -> None:
+        """Run statistics task."""
+        if statistics.compile_statistics(instance, self.start):
+            return
+        # Schedule a new statistics task if this one didn't finish
+        instance.queue.put(StatisticsTask(self.start))
+
 
-class ExternalStatisticsTask(NamedTuple):
+@dataclass
+class ExternalStatisticsTask(RecorderTask):
     """An object to insert into the recorder queue to run an external statistics task."""
 
     metadata: dict
     statistics: Iterable[dict]
 
-
-class WaitTask:
-    """An object to insert into the recorder queue to tell it to set the _queue_watch event."""
+    def run(self, instance: Recorder) -> None:
+        """Run statistics task."""
+        if statistics.add_external_statistics(instance, self.metadata, self.statistics):
+            return
+        # Schedule a new statistics task if this one didn't finish
+        instance.queue.put(ExternalStatisticsTask(self.metadata, self.statistics))
 
 
 @dataclass
-class DatabaseLockTask:
+class WaitTask(RecorderTask):
+    """An object to insert into the recorder queue to tell it to set the _queue_watch event."""
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance._queue_watch.set()  # pylint: disable=[protected-access]
+
+
+@dataclass
+class DatabaseLockTask(RecorderTask):
     """An object to insert into the recorder queue to prevent writes to the database."""
 
     database_locked: asyncio.Event
     database_unlock: threading.Event
     queue_overflow: bool
 
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance._lock_database(self)  # pylint: disable=[protected-access]
+
+
+@dataclass
+class StopTask(RecorderTask):
+    """An object to insert into the recorder queue to stop the event handler."""
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance.stop_requested = True
+
+
+@dataclass
+class EventTask(RecorderTask):
+    """An object to insert into the recorder queue to process an event."""
+
+    event: bool
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        # pylint: disable-next=[protected-access]
+        instance._process_one_event(self.event)
+
 
 class Recorder(threading.Thread):
     """A threaded recorder class."""
 
+    stop_requested: bool
+
     def __init__(
         self,
         hass: HomeAssistant,
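
The heart of the refactor: each queue payload becomes a @dataclass deriving from the RecorderTask ABC and carries its own run() method, so the consumer can dispatch polymorphically. PurgeTask and StatisticsTask also use a self-rescheduling idiom: a task that cannot finish in one pass re-enqueues a fresh copy of itself instead of looping in place, so other queued work interleaves. A minimal runnable sketch of both ideas; Task, Worker, purge_batch and batch_size are illustrative stand-ins, not the recorder's API:

    from __future__ import annotations

    import abc
    import queue
    from dataclasses import dataclass


    class Task(abc.ABC):
        """ABC for queue tasks, mirroring RecorderTask."""

        @abc.abstractmethod
        def run(self, worker: Worker) -> None:
            """Handle the task."""


    @dataclass
    class PurgeTask(Task):
        """Deletes rows in batches, re-enqueueing itself until done."""

        rows_left: int

        def run(self, worker: Worker) -> None:
            if worker.purge_batch(self.rows_left):
                return  # finished in this pass
            # Not finished: enqueue a fresh task so other queued work runs first.
            worker.queue.put(PurgeTask(self.rows_left - worker.batch_size))


    class Worker:
        """Stand-in for the recorder thread."""

        batch_size = 100

        def __init__(self) -> None:
            self.queue: queue.SimpleQueue[Task] = queue.SimpleQueue()

        def purge_batch(self, rows_left: int) -> bool:
            """Pretend to delete one batch; True means nothing is left."""
            return rows_left <= self.batch_size
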
@@ -406,7 +502,7 @@ class Recorder(threading.Thread):
         self.auto_purge = auto_purge
         self.keep_days = keep_days
         self.commit_interval = commit_interval
-        self.queue: Any = queue.SimpleQueue()
+        self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue()
         self.recording_start = dt_util.utcnow()
         self.db_url = uri
         self.db_max_retries = db_max_retries
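
Typing the queue as queue.SimpleQueue[RecorderTask] lets mypy check every put() and get() against the task ABC. Worth noting: queue.SimpleQueue only became subscriptable at runtime in Python 3.9, but because this module uses `from __future__ import annotations` (PEP 563) the subscript is never evaluated, so the annotation is also safe on older interpreters. A tiny sketch of the same pattern; the str item type and Worker class are just illustration:

    from __future__ import annotations  # keeps the subscripted annotation unevaluated

    import queue


    class Worker:
        def __init__(self) -> None:
            # mypy now flags e.g. self.queue.put(None) or put(42); only str fits.
            self.queue: queue.SimpleQueue[str] = queue.SimpleQueue()


    w = Worker()
    w.queue.put("ok")      # accepted by mypy
    # w.queue.put(None)    # rejected by mypy; at runtime a None would have
    #                      # silently ended the old falsy-sentinel loop
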
@@ -537,7 +633,7 @@
                     self.queue.get_nowait()
                 except queue.Empty:
                     break
-            self.queue.put(None)
+            self.queue.put(StopTask())
 
         self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_FINAL_WRITE, _empty_queue)
@@ -545,7 +641,7 @@
             """Shut down the Recorder."""
             if not hass_started.done():
                 hass_started.set_result(shutdown_task)
-            self.queue.put(None)
+            self.queue.put(StopTask())
             self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
             self.join()
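
Both shutdown paths previously enqueued None, and the consumer loop stopped on the first falsy item. With the queue now typed as SimpleQueue[RecorderTask], a bare None would no longer typecheck; an explicit StopTask keeps the queue homogeneous and makes shutdown an ordinary task. A minimal sketch of the sentinel-object idea, with illustrative names:

    from __future__ import annotations

    import queue
    from dataclasses import dataclass


    @dataclass
    class StopTask:
        """Explicit shutdown sentinel; run() flips the worker's stop flag."""

        def run(self, worker: Worker) -> None:
            worker.stop_requested = True


    class Worker:
        stop_requested = False

        def __init__(self) -> None:
            self.queue: queue.SimpleQueue[StopTask] = queue.SimpleQueue()

        def shutdown(self) -> None:
            # Previously: self.queue.put(None); a None no longer typechecks.
            self.queue.put(StopTask())
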
@@ -691,31 +787,28 @@
         # Use a session for the event read loop
         # with a commit every time the event time
         # has changed. This reduces the disk io.
-        while event := self.queue.get():
+        self.stop_requested = False
+        while not self.stop_requested:
+            task = self.queue.get()
             try:
-                self._process_one_event_or_recover(event)
+                self._process_one_task_or_recover(task)
             except Exception as err:  # pylint: disable=broad-except
-                _LOGGER.exception("Error while processing event %s: %s", event, err)
+                _LOGGER.exception("Error while processing event %s: %s", task, err)
 
         self._shutdown()
 
-    def _process_one_event_or_recover(self, event):
+    def _process_one_task_or_recover(self, task: RecorderTask):
         """Process an event, reconnect, or recover a malformed database."""
         try:
-            if self._process_one_task(event):
-                return
-            self._process_one_event(event)
-            return
+            return task.run(self)
         except exc.DatabaseError as err:
             if self._handle_database_error(err):
                 return
             _LOGGER.exception(
-                "Unhandled database error while processing event %s: %s", event, err
+                "Unhandled database error while processing task %s: %s", task, err
             )
         except SQLAlchemyError as err:
-            _LOGGER.exception(
-                "SQLAlchemyError error processing event %s: %s", event, err
-            )
+            _LOGGER.exception("SQLAlchemyError error processing task %s: %s", task, err)
 
         # Reset the session if an SQLAlchemyError (including DatabaseError)
         # happens to rollback and recover
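
The consumer loop no longer depends on a falsy sentinel (the old `while event := self.queue.get():`); it runs until some task sets stop_requested, and every task funnels through one recovery wrapper so a failing task is logged rather than killing the thread. A runnable end-to-end sketch of that loop shape; all names are illustrative and the error handling is reduced to a single broad except:

    from __future__ import annotations

    import abc
    import logging
    import queue
    from dataclasses import dataclass

    _LOGGER = logging.getLogger(__name__)


    class Task(abc.ABC):
        @abc.abstractmethod
        def run(self, worker: Worker) -> None:
            """Handle the task."""


    @dataclass
    class EchoTask(Task):
        text: str

        def run(self, worker: Worker) -> None:
            print(self.text)


    @dataclass
    class FailingTask(Task):
        def run(self, worker: Worker) -> None:
            raise RuntimeError("boom")


    @dataclass
    class StopTask(Task):
        def run(self, worker: Worker) -> None:
            worker.stop_requested = True


    class Worker:
        def __init__(self) -> None:
            self.queue: queue.SimpleQueue[Task] = queue.SimpleQueue()
            self.stop_requested = False

        def run(self) -> None:
            self.stop_requested = False
            while not self.stop_requested:
                task = self.queue.get()
                try:
                    task.run(self)  # polymorphic dispatch, no isinstance ladder
                except Exception as err:  # keep the loop alive, as the recorder does
                    _LOGGER.exception("Error while processing task %s: %s", task, err)


    if __name__ == "__main__":
        worker = Worker()
        worker.queue.put(EchoTask("hello"))
        worker.queue.put(FailingTask())  # logged, loop continues
        worker.queue.put(StopTask())     # ends the loop
        worker.run()
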
@@ -773,38 +866,6 @@
         self.migration_in_progress = False
         persistent_notification.dismiss(self.hass, "recorder_database_migration")
 
-    def _run_purge(self, purge_before, repack, apply_filter):
-        """Purge the database."""
-        if purge.purge_old_data(self, purge_before, repack, apply_filter):
-            # We always need to do the db cleanups after a purge
-            # is finished to ensure the WAL checkpoint and other
-            # tasks happen after a vacuum.
-            perodic_db_cleanups(self)
-            return
-        # Schedule a new purge task if this one didn't finish
-        self.queue.put(PurgeTask(purge_before, repack, apply_filter))
-
-    def _run_purge_entities(self, entity_filter):
-        """Purge entities from the database."""
-        if purge.purge_entity_data(self, entity_filter):
-            return
-        # Schedule a new purge task if this one didn't finish
-        self.queue.put(PurgeEntitiesTask(entity_filter))
-
-    def _run_statistics(self, start):
-        """Run statistics task."""
-        if statistics.compile_statistics(self, start):
-            return
-        # Schedule a new statistics task if this one didn't finish
-        self.queue.put(StatisticsTask(start))
-
-    def _run_external_statistics(self, metadata, stats):
-        """Run statistics task."""
-        if statistics.add_external_statistics(self, metadata, stats):
-            return
-        # Schedule a new statistics task if this one didn't finish
-        self.queue.put(ExternalStatisticsTask(metadata, stats))
-
     def _lock_database(self, task: DatabaseLockTask):
         @callback
         def _async_set_database_locked(task: DatabaseLockTask):
@@ -828,39 +889,6 @@
             self.queue.qsize(),
         )
 
-    def _process_one_task(self, event) -> bool:
-        """Process one event."""
-        if isinstance(event, PurgeTask):
-            self._run_purge(event.purge_before, event.repack, event.apply_filter)
-            return True
-        if isinstance(event, PurgeEntitiesTask):
-            self._run_purge_entities(event.entity_filter)
-            return True
-        if isinstance(event, PerodicCleanupTask):
-            perodic_db_cleanups(self)
-            return True
-        if isinstance(event, StatisticsTask):
-            self._run_statistics(event.start)
-            return True
-        if isinstance(event, ClearStatisticsTask):
-            statistics.clear_statistics(self, event.statistic_ids)
-            return True
-        if isinstance(event, UpdateStatisticsMetadataTask):
-            statistics.update_statistics_metadata(
-                self, event.statistic_id, event.unit_of_measurement
-            )
-            return True
-        if isinstance(event, ExternalStatisticsTask):
-            self._run_external_statistics(event.metadata, event.statistics)
-            return True
-        if isinstance(event, WaitTask):
-            self._queue_watch.set()
-            return True
-        if isinstance(event, DatabaseLockTask):
-            self._lock_database(event)
-            return True
-        return False
-
     def _process_one_event(self, event):
         if event.event_type == EVENT_TIME_CHANGED:
             self._keepalive_count += 1
@@ -1010,7 +1038,7 @@
     @callback
     def event_listener(self, event):
         """Listen for new events and put them in the process queue."""
-        self.queue.put(event)
+        self.queue.put(EventTask(event))
 
     def block_till_done(self):
         """Block till all events processed.


@@ -261,7 +261,7 @@ def test_saving_state_with_sqlalchemy_exception(hass, hass_recorder, caplog):
         hass.states.set(entity_id, "fail", attributes)
         wait_recording_done(hass)
 
-    assert "SQLAlchemyError error processing event" in caplog.text
+    assert "SQLAlchemyError error processing task" in caplog.text
 
     caplog.clear()
     hass.states.set(entity_id, state, attributes)
@@ -273,7 +273,7 @@ def test_saving_state_with_sqlalchemy_exception(hass, hass_recorder, caplog):
     assert "Error executing query" not in caplog.text
     assert "Error saving events" not in caplog.text
-    assert "SQLAlchemyError error processing event" not in caplog.text
+    assert "SQLAlchemyError error processing task" not in caplog.text
 
 
 async def test_force_shutdown_with_queue_of_writes_that_generate_exceptions(
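
The test changes track the reworded log message. For reference, a minimal hypothetical pytest sketch of the caplog pattern these assertions rely on; this is not the Home Assistant test harness:

    import logging

    _LOGGER = logging.getLogger(__name__)


    def test_task_error_is_logged(caplog):
        """caplog.text accumulates formatted records emitted during the test."""
        try:
            raise RuntimeError("boom")
        except RuntimeError as err:
            _LOGGER.exception("SQLAlchemyError error processing task %s: %s", "task", err)

        assert "SQLAlchemyError error processing task" in caplog.text
        assert "SQLAlchemyError error processing event" not in caplog.text
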