Refactor recorder queue handling (#61161)

* Refactor recorder queue handling

* Address pylint's concerns

* Implement workaround for mypy bug

* Address review comments
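In short, the refactor replaces the recorder's NamedTuple task objects, which were dispatched through a chain of isinstance() checks in _process_one_task, with dataclasses deriving from a RecorderTask ABC that carry their own run() method. A minimal runnable sketch of that before/after pattern, using hypothetical stand-in names (Task, PrintTask, Worker) rather than the recorder's real classes:

import abc
from dataclasses import dataclass


class Task(abc.ABC):
    """ABC for tasks, mirroring the RecorderTask pattern in this commit."""

    @abc.abstractmethod
    def run(self, worker: "Worker") -> None:
        """Handle the task."""


@dataclass
class PrintTask(Task):
    """A trivial task that carries both its payload and its behavior."""

    message: str

    def run(self, worker: "Worker") -> None:
        print(self.message)


class Worker:
    """Stand-in for the Recorder thread's processing loop."""

    def process(self, task: Task) -> None:
        # Polymorphic dispatch: no isinstance() chain, and adding a
        # new task type no longer requires touching the worker loop.
        task.run(self)


Worker().process(PrintTask("hello"))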
Erik Montnemery 2021-12-08 16:54:26 +01:00 committed by GitHub
parent bbe6d3c9ae
commit f30eb05870
2 changed files with 123 additions and 95 deletions


@@ -1,6 +1,7 @@
 """Support for recording details."""
 from __future__ import annotations

+import abc
 import asyncio
 from collections.abc import Callable, Iterable
 import concurrent.futures
@@ -11,7 +12,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any, NamedTuple
+from typing import Any

 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
 from sqlalchemy.exc import SQLAlchemyError
@@ -327,66 +328,161 @@ def _async_register_services(hass, instance):
     )


-class ClearStatisticsTask(NamedTuple):
+class RecorderTask(abc.ABC):
+    """ABC for recorder tasks."""
+
+    @abc.abstractmethod
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+
+
+@dataclass
+class ClearStatisticsTask(RecorderTask):
     """Object to store statistics_ids which for which to remove statistics."""

     statistic_ids: list[str]

+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        statistics.clear_statistics(instance, self.statistic_ids)
+

-class UpdateStatisticsMetadataTask(NamedTuple):
+@dataclass
+class UpdateStatisticsMetadataTask(RecorderTask):
     """Object to store statistics_id and unit for update of statistics metadata."""

     statistic_id: str
     unit_of_measurement: str | None

+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        statistics.update_statistics_metadata(
+            instance, self.statistic_id, self.unit_of_measurement
+        )
+

-class PurgeTask(NamedTuple):
+@dataclass
+class PurgeTask(RecorderTask):
     """Object to store information about purge task."""

     purge_before: datetime
     repack: bool
     apply_filter: bool

+    def run(self, instance: Recorder) -> None:
+        """Purge the database."""
+        if purge.purge_old_data(
+            instance, self.purge_before, self.repack, self.apply_filter
+        ):
+            # We always need to do the db cleanups after a purge
+            # is finished to ensure the WAL checkpoint and other
+            # tasks happen after a vacuum.
+            perodic_db_cleanups(instance)
+            return
+        # Schedule a new purge task if this one didn't finish
+        instance.queue.put(PurgeTask(self.purge_before, self.repack, self.apply_filter))
+

-class PurgeEntitiesTask(NamedTuple):
+@dataclass
+class PurgeEntitiesTask(RecorderTask):
     """Object to store entity information about purge task."""

     entity_filter: Callable[[str], bool]

+    def run(self, instance: Recorder) -> None:
+        """Purge entities from the database."""
+        if purge.purge_entity_data(instance, self.entity_filter):
+            return
+        # Schedule a new purge task if this one didn't finish
+        instance.queue.put(PurgeEntitiesTask(self.entity_filter))
+

-class PerodicCleanupTask:
+@dataclass
+class PerodicCleanupTask(RecorderTask):
     """An object to insert into the recorder to trigger cleanup tasks when auto purge is disabled."""

+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        perodic_db_cleanups(instance)
+

-class StatisticsTask(NamedTuple):
+@dataclass
+class StatisticsTask(RecorderTask):
     """An object to insert into the recorder queue to run a statistics task."""

     start: datetime

+    def run(self, instance: Recorder) -> None:
+        """Run statistics task."""
+        if statistics.compile_statistics(instance, self.start):
+            return
+        # Schedule a new statistics task if this one didn't finish
+        instance.queue.put(StatisticsTask(self.start))
+

-class ExternalStatisticsTask(NamedTuple):
+@dataclass
+class ExternalStatisticsTask(RecorderTask):
     """An object to insert into the recorder queue to run an external statistics task."""

     metadata: dict
     statistics: Iterable[dict]

+    def run(self, instance: Recorder) -> None:
+        """Run statistics task."""
+        if statistics.add_external_statistics(instance, self.metadata, self.statistics):
+            return
+        # Schedule a new statistics task if this one didn't finish
+        instance.queue.put(ExternalStatisticsTask(self.metadata, self.statistics))
+

-class WaitTask:
-    """An object to insert into the recorder queue to tell it set the _queue_watch event."""
+@dataclass
+class WaitTask(RecorderTask):
+    """An object to insert into the recorder queue to tell it set the _queue_watch event."""
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance._queue_watch.set()  # pylint: disable=[protected-access]
+

 @dataclass
-class DatabaseLockTask:
+class DatabaseLockTask(RecorderTask):
     """An object to insert into the recorder queue to prevent writes to the database."""

     database_locked: asyncio.Event
     database_unlock: threading.Event
     queue_overflow: bool

+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance._lock_database(self)  # pylint: disable=[protected-access]
+
+
+@dataclass
+class StopTask(RecorderTask):
+    """An object to insert into the recorder queue to stop the event handler."""
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        instance.stop_requested = True
+
+
+@dataclass
+class EventTask(RecorderTask):
+    """An object to insert into the recorder queue to stop the event handler."""
+
+    event: bool
+
+    def run(self, instance: Recorder) -> None:
+        """Handle the task."""
+        # pylint: disable-next=[protected-access]
+        instance._process_one_event(self.event)
+

 class Recorder(threading.Thread):
     """A threaded recorder class."""

+    stop_requested: bool
+
     def __init__(
         self,
         hass: HomeAssistant,
@@ -406,7 +502,7 @@ class Recorder(threading.Thread):
         self.auto_purge = auto_purge
         self.keep_days = keep_days
         self.commit_interval = commit_interval
-        self.queue: Any = queue.SimpleQueue()
+        self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue()
         self.recording_start = dt_util.utcnow()
         self.db_url = uri
         self.db_max_retries = db_max_retries
@@ -537,7 +633,7 @@
                     self.queue.get_nowait()
                 except queue.Empty:
                     break
-            self.queue.put(None)
+            self.queue.put(StopTask())

         self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_FINAL_WRITE, _empty_queue)
@@ -545,7 +641,7 @@
             """Shut down the Recorder."""
             if not hass_started.done():
                 hass_started.set_result(shutdown_task)
-            self.queue.put(None)
+            self.queue.put(StopTask())
             self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
             self.join()
@@ -691,31 +787,28 @@
         # Use a session for the event read loop
         # with a commit every time the event time
        # has changed. This reduces the disk io.
-        while event := self.queue.get():
+        self.stop_requested = False
+        while not self.stop_requested:
+            task = self.queue.get()
             try:
-                self._process_one_event_or_recover(event)
+                self._process_one_task_or_recover(task)
             except Exception as err:  # pylint: disable=broad-except
-                _LOGGER.exception("Error while processing event %s: %s", event, err)
+                _LOGGER.exception("Error while processing event %s: %s", task, err)

         self._shutdown()

-    def _process_one_event_or_recover(self, event):
+    def _process_one_task_or_recover(self, task: RecorderTask):
         """Process an event, reconnect, or recover a malformed database."""
         try:
-            if self._process_one_task(event):
-                return
-            self._process_one_event(event)
-            return
+            return task.run(self)
         except exc.DatabaseError as err:
             if self._handle_database_error(err):
                 return
             _LOGGER.exception(
-                "Unhandled database error while processing event %s: %s", event, err
+                "Unhandled database error while processing task %s: %s", task, err
             )
         except SQLAlchemyError as err:
-            _LOGGER.exception(
-                "SQLAlchemyError error processing event %s: %s", event, err
-            )
+            _LOGGER.exception("SQLAlchemyError error processing task %s: %s", task, err)

         # Reset the session if an SQLAlchemyError (including DatabaseError)
         # happens to rollback and recover
@@ -773,38 +866,6 @@
         self.migration_in_progress = False
         persistent_notification.dismiss(self.hass, "recorder_database_migration")

-    def _run_purge(self, purge_before, repack, apply_filter):
-        """Purge the database."""
-        if purge.purge_old_data(self, purge_before, repack, apply_filter):
-            # We always need to do the db cleanups after a purge
-            # is finished to ensure the WAL checkpoint and other
-            # tasks happen after a vacuum.
-            perodic_db_cleanups(self)
-            return
-        # Schedule a new purge task if this one didn't finish
-        self.queue.put(PurgeTask(purge_before, repack, apply_filter))
-
-    def _run_purge_entities(self, entity_filter):
-        """Purge entities from the database."""
-        if purge.purge_entity_data(self, entity_filter):
-            return
-        # Schedule a new purge task if this one didn't finish
-        self.queue.put(PurgeEntitiesTask(entity_filter))
-
-    def _run_statistics(self, start):
-        """Run statistics task."""
-        if statistics.compile_statistics(self, start):
-            return
-        # Schedule a new statistics task if this one didn't finish
-        self.queue.put(StatisticsTask(start))
-
-    def _run_external_statistics(self, metadata, stats):
-        """Run statistics task."""
-        if statistics.add_external_statistics(self, metadata, stats):
-            return
-        # Schedule a new statistics task if this one didn't finish
-        self.queue.put(ExternalStatisticsTask(metadata, stats))
-
     def _lock_database(self, task: DatabaseLockTask):
         @callback
         def _async_set_database_locked(task: DatabaseLockTask):
@@ -828,39 +889,6 @@
             self.queue.qsize(),
         )

-    def _process_one_task(self, event) -> bool:
-        """Process one event."""
-        if isinstance(event, PurgeTask):
-            self._run_purge(event.purge_before, event.repack, event.apply_filter)
-            return True
-        if isinstance(event, PurgeEntitiesTask):
-            self._run_purge_entities(event.entity_filter)
-            return True
-        if isinstance(event, PerodicCleanupTask):
-            perodic_db_cleanups(self)
-            return True
-        if isinstance(event, StatisticsTask):
-            self._run_statistics(event.start)
-            return True
-        if isinstance(event, ClearStatisticsTask):
-            statistics.clear_statistics(self, event.statistic_ids)
-            return True
-        if isinstance(event, UpdateStatisticsMetadataTask):
-            statistics.update_statistics_metadata(
-                self, event.statistic_id, event.unit_of_measurement
-            )
-            return True
-        if isinstance(event, ExternalStatisticsTask):
-            self._run_external_statistics(event.metadata, event.statistics)
-            return True
-        if isinstance(event, WaitTask):
-            self._queue_watch.set()
-            return True
-        if isinstance(event, DatabaseLockTask):
-            self._lock_database(event)
-            return True
-        return False
-
     def _process_one_event(self, event):
         if event.event_type == EVENT_TIME_CHANGED:
             self._keepalive_count += 1
@@ -1010,7 +1038,7 @@
     @callback
     def event_listener(self, event):
         """Listen for new events and put them in the process queue."""
-        self.queue.put(event)
+        self.queue.put(EventTask(event))

     def block_till_done(self):
         """Block till all events processed.

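The queue changes above drop the None sentinel (the old loop, while event := self.queue.get(), stopped on any falsy value) in favor of a StopTask that flips a stop_requested flag checked by the loop, and the queue itself becomes a typed queue.SimpleQueue[RecorderTask]. A runnable sketch of that shutdown pattern under simplified, hypothetical names (Task, StopTask, Worker are stand-ins, not the recorder's API); the class-level stop_requested annotation mirrors the one added to Recorder:

from __future__ import annotations

import abc
import queue
from dataclasses import dataclass


class Task(abc.ABC):
    """ABC for queue tasks."""

    @abc.abstractmethod
    def run(self, worker: Worker) -> None:
        """Handle the task."""


@dataclass
class StopTask(Task):
    """Ends the worker loop without putting a falsy sentinel in the queue."""

    def run(self, worker: Worker) -> None:
        # Flip the flag; the loop exits after this task returns.
        worker.stop_requested = True


class Worker:
    """Stand-in for the Recorder thread."""

    stop_requested: bool  # declared up front so type checkers know the attribute

    def __init__(self) -> None:
        # The queue only ever holds Task objects now, so it can be typed.
        self.queue: queue.SimpleQueue[Task] = queue.SimpleQueue()

    def run(self) -> None:
        self.stop_requested = False
        while not self.stop_requested:
            self.queue.get().run(self)


worker = Worker()
worker.queue.put(StopTask())
worker.run()  # returns immediately after processing StopTask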
(second changed file: the recorder tests)

@@ -261,7 +261,7 @@ def test_saving_state_with_sqlalchemy_exception(hass, hass_recorder, caplog):
     hass.states.set(entity_id, "fail", attributes)
     wait_recording_done(hass)

-    assert "SQLAlchemyError error processing event" in caplog.text
+    assert "SQLAlchemyError error processing task" in caplog.text

     caplog.clear()
     hass.states.set(entity_id, state, attributes)
@@ -273,7 +273,7 @@ def test_saving_state_with_sqlalchemy_exception(hass, hass_recorder, caplog):
     assert "Error executing query" not in caplog.text
     assert "Error saving events" not in caplog.text
-    assert "SQLAlchemyError error processing event" not in caplog.text
+    assert "SQLAlchemyError error processing task" not in caplog.text


 async def test_force_shutdown_with_queue_of_writes_that_generate_exceptions(