Small cleanups for recorder (#68551)

J. Nick Koston 2022-03-23 12:12:37 -10:00 committed by GitHub
parent c44d7205cf
commit 8c10963bc0
3 changed files with 73 additions and 55 deletions


@@ -16,6 +16,7 @@ from typing import Any, TypeVar
 from lru import LRU  # pylint: disable=no-name-in-module
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
+from sqlalchemy.engine import Engine
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import scoped_session, sessionmaker
 from sqlalchemy.orm.session import Session
@@ -33,7 +34,14 @@ from homeassistant.const import (
     EVENT_TIME_CHANGED,
     MATCH_ALL,
 )
-from homeassistant.core import CoreState, HomeAssistant, ServiceCall, callback
+from homeassistant.core import (
+    CALLBACK_TYPE,
+    CoreState,
+    Event,
+    HomeAssistant,
+    ServiceCall,
+    callback,
+)
 import homeassistant.helpers.config_validation as cv
 from homeassistant.helpers.entityfilter import (
     INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
@@ -306,7 +314,7 @@ async def _process_recorder_platform(hass, domain, platform):
 @callback
-def _async_register_services(hass, instance):
+def _async_register_services(hass: HomeAssistant, instance: Recorder) -> None:
     """Register recorder services."""

     async def async_handle_purge_service(service: ServiceCall) -> None:
@@ -524,9 +532,9 @@ class StopTask(RecorderTask):
 @dataclass
 class EventTask(RecorderTask):
-    """An object to insert into the recorder queue to stop the event handler."""
+    """An event to be processed."""

-    event: bool
+    event: Event
     commit_before = False

     def run(self, instance: Recorder) -> None:
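
The fix above replaces a copy-pasted docstring and a wrong `event: bool` annotation. For context, the recorder's write path is built on small dataclass tasks pulled off a queue by the recorder thread; the sketch below is a minimal, hypothetical reduction of that pattern (the real `RecorderTask` subclasses carry more state), not the actual implementation.

```python
from __future__ import annotations

import queue
from dataclasses import dataclass


class Recorder:
    """Toy consumer standing in for the real recorder thread."""

    def __init__(self) -> None:
        self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue()

    def process_event(self, event: object) -> None:
        print(f"recording {event!r}")

    def run_once(self) -> None:
        # The real recorder loops forever; commit_before tells it to flush
        # the open session before running the task.
        task = self.queue.get()
        task.run(self)


@dataclass
class RecorderTask:
    """Base class for items placed on the recorder queue (simplified)."""

    commit_before = True

    def run(self, instance: Recorder) -> None:
        raise NotImplementedError


@dataclass
class EventTask(RecorderTask):
    """An event to be processed."""

    event: object  # the real field is typed homeassistant.core.Event
    commit_before = False

    def run(self, instance: Recorder) -> None:
        instance.process_event(self.event)


recorder = Recorder()
recorder.queue.put(EventTask("state_changed"))
recorder.run_once()
```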
@@ -567,7 +575,7 @@ class Recorder(threading.Thread):
         self.async_db_ready: asyncio.Future = asyncio.Future()
         self.async_recorder_ready = asyncio.Event()
         self._queue_watch = threading.Event()
-        self.engine: Any = None
+        self.engine: Engine | None = None
         self.run_info: Any = None
         self.entity_filter = entity_filter
@@ -580,13 +588,13 @@ class Recorder(threading.Thread):
         self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE)
         self._pending_state_attributes: dict[str, StateAttributes] = {}
         self._pending_expunge: list[States] = []
-        self.event_session = None
-        self.get_session = None
-        self._completed_first_database_setup = None
-        self._event_listener = None
+        self.event_session: Session | None = None
+        self.get_session: Callable[[], Session] | None = None
+        self._completed_first_database_setup: bool | None = None
+        self._event_listener: CALLBACK_TYPE | None = None
         self.async_migration_event = asyncio.Event()
         self.migration_in_progress = False
-        self._queue_watcher = None
+        self._queue_watcher: CALLBACK_TYPE | None = None
         self._db_supports_row_number = True
         self._database_lock_task: DatabaseLockTask | None = None
         self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
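
`CALLBACK_TYPE` is Home Assistant's alias for `Callable[[], None]`: the zero-argument "unsubscribe" function that `bus.async_listen` returns. Annotating `_event_listener` and `_queue_watcher` this way documents the store-then-call-to-detach pattern. A small self-contained sketch of that pattern (the `subscribe` callable here is a stand-in for the real event bus API):

```python
from __future__ import annotations

from collections.abc import Callable

# Mirrors homeassistant.core.CALLBACK_TYPE
CALLBACK_TYPE = Callable[[], None]


class QueueWatcher:
    """Keeps an optional unsubscribe handle, like Recorder._event_listener."""

    def __init__(self) -> None:
        self._event_listener: CALLBACK_TYPE | None = None

    def start(self, subscribe: Callable[[], CALLBACK_TYPE]) -> None:
        # subscribe() attaches a listener and returns a remover.
        self._event_listener = subscribe()

    def stop(self) -> None:
        # Same shape as _async_stop_queue_watcher_and_event_listener below.
        if self._event_listener:
            self._event_listener()
            self._event_listener = None
```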
@@ -651,7 +659,7 @@ class Recorder(threading.Thread):
         self._async_stop_queue_watcher_and_event_listener()

     @callback
-    def _async_stop_queue_watcher_and_event_listener(self):
+    def _async_stop_queue_watcher_and_event_listener(self) -> None:
         """Stop watching the queue and listening for events."""
         if self._queue_watcher:
             self._queue_watcher()
@@ -661,7 +669,7 @@ class Recorder(threading.Thread):
         self._event_listener = None

     @callback
-    def _async_event_filter(self, event) -> bool:
+    def _async_event_filter(self, event: Event) -> bool:
         """Filter events."""
         if event.event_type in self.exclude_t:
             return False
@@ -702,7 +710,9 @@ class Recorder(threading.Thread):
         self.queue.put(StatisticsTask(start))

     @callback
-    def async_register(self, shutdown_task, hass_started):
+    def async_register(
+        self, shutdown_task: object, hass_started: concurrent.futures.Future
+    ) -> None:
         """Post connection initialize."""

         def _empty_queue(event):
@@ -746,7 +756,7 @@ class Recorder(threading.Thread):
         self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, async_hass_started)

     @callback
-    def async_connection_failed(self):
+    def async_connection_failed(self) -> None:
         """Connect failed tasks."""
         self.async_db_ready.set_result(False)
         persistent_notification.async_create(
@@ -757,19 +767,19 @@ class Recorder(threading.Thread):
         self._async_stop_queue_watcher_and_event_listener()

     @callback
-    def async_connection_success(self):
+    def async_connection_success(self) -> None:
         """Connect success tasks."""
         self.async_db_ready.set_result(True)
         self.async_start_executor()

     @callback
-    def _async_recorder_ready(self):
+    def _async_recorder_ready(self) -> None:
         """Finish start and mark recorder ready."""
         self._async_setup_periodic_tasks()
         self.async_recorder_ready.set()

     @callback
-    def async_nightly_tasks(self, now):
+    def async_nightly_tasks(self, now: datetime) -> None:
         """Trigger the purge."""
         if self.auto_purge:
             # Purge will schedule the perodic cleanups
@@ -781,7 +791,7 @@ class Recorder(threading.Thread):
         self.queue.put(PerodicCleanupTask())

     @callback
-    def async_periodic_statistics(self, now):
+    def async_periodic_statistics(self, now: datetime) -> None:
         """Trigger the hourly statistics run."""
         start = statistics.get_start_time()
         self.queue.put(StatisticsTask(start))
@@ -807,7 +817,7 @@ class Recorder(threading.Thread):
         self.queue.put(ExternalStatisticsTask(metadata, stats))

     @callback
-    def _async_setup_periodic_tasks(self):
+    def _async_setup_periodic_tasks(self) -> None:
         """Prepare periodic tasks."""
         if self.hass.is_stopping or not self.get_session:
             # Home Assistant is shutting down
@@ -823,10 +833,10 @@ class Recorder(threading.Thread):
             self.hass, self.async_periodic_statistics, minute=range(0, 60, 5), second=10
         )

-    def run(self):
+    def run(self) -> None:
         """Start processing events to save."""
         shutdown_task = object()
-        hass_started = concurrent.futures.Future()
+        hass_started: concurrent.futures.Future = concurrent.futures.Future()

         self.hass.add_job(self.async_register, shutdown_task, hass_started)
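
Typing `hass_started` as a bare `concurrent.futures.Future` makes the cross-thread handshake in `run()` explicit: the recorder thread blocks on the future, and the event loop resolves it once Home Assistant has started (or with the shutdown sentinel). A minimal sketch of that handshake, with simplified names:

```python
import asyncio
import concurrent.futures
import threading


def recorder_thread(hass_started: concurrent.futures.Future) -> None:
    # Runs in the worker thread: blocks until the event loop resolves the future.
    print("recorder thread waiting for startup...")
    hass_started.result()
    print("startup signalled; begin processing")


async def main() -> None:
    hass_started: concurrent.futures.Future = concurrent.futures.Future()
    thread = threading.Thread(target=recorder_thread, args=(hass_started,))
    thread.start()
    await asyncio.sleep(0.1)  # pretend startup work happens here
    hass_started.set_result(None)
    await asyncio.to_thread(thread.join)


asyncio.run(main())
```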
@@ -875,7 +885,7 @@ class Recorder(threading.Thread):
         self.hass.add_job(self._async_recorder_ready)
         self._run_event_loop()

-    def _run_event_loop(self):
+    def _run_event_loop(self) -> None:
         """Run the event loop for the recorder."""
         # Use a session for the event read loop
         # with a commit every time the event time
@@ -890,7 +900,7 @@ class Recorder(threading.Thread):
         self._shutdown()

-    def _process_one_task_or_recover(self, task: RecorderTask):
+    def _process_one_task_or_recover(self, task: RecorderTask) -> None:
         """Process an event, reconnect, or recover a malformed database."""
         try:
             # If its not an event, commit everything
@@ -931,11 +941,11 @@ class Recorder(threading.Thread):
         return None

     @callback
-    def _async_migration_started(self):
+    def _async_migration_started(self) -> None:
         """Set the migration started event."""
         self.async_migration_event.set()

-    def _migrate_schema_and_setup_run(self, current_version) -> bool:
+    def _migrate_schema_and_setup_run(self, current_version: int) -> bool:
         """Migrate schema to the latest version."""
         persistent_notification.create(
             self.hass,
@@ -962,7 +972,7 @@ class Recorder(threading.Thread):
         self.migration_in_progress = False
         persistent_notification.dismiss(self.hass, "recorder_database_migration")

-    def _lock_database(self, task: DatabaseLockTask):
+    def _lock_database(self, task: DatabaseLockTask) -> None:
         @callback
         def _async_set_database_locked(task: DatabaseLockTask):
             task.database_locked.set()
@@ -985,7 +995,7 @@ class Recorder(threading.Thread):
             self.queue.qsize(),
         )

-    def _process_one_event(self, event):
+    def _process_one_event(self, event: Event) -> None:
         if event.event_type == EVENT_TIME_CHANGED:
             self._keepalive_count += 1
             if self._keepalive_count >= KEEPALIVE_TIME:
@@ -1000,6 +1010,7 @@ class Recorder(threading.Thread):
         if not self.enabled:
             return

+        assert self.event_session is not None
         try:
             if event.event_type == EVENT_STATE_CHANGED:
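
The added `assert self.event_session is not None` exists for the type checker rather than for runtime behavior: once `event_session` is declared `Session | None`, mypy cannot know that this method only runs after the session is opened, so the assert narrows the type to `Session` for the rest of the body. A minimal illustration of the technique:

```python
class Holder:
    """Attribute starts as None and is populated during setup."""

    def __init__(self) -> None:
        self.value: str | None = None

    def setup(self) -> None:
        self.value = "ready"

    def use(self) -> None:
        # Without the assert, mypy rejects self.value.upper() because
        # the declared type is "str | None".
        assert self.value is not None
        print(self.value.upper())


holder = Holder()
holder.setup()
holder.use()
```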
@@ -1071,7 +1082,7 @@ class Recorder(threading.Thread):
         if not self.commit_interval:
             self._commit_event_session_or_retry()

-    def _handle_database_error(self, err):
+    def _handle_database_error(self, err: Exception) -> bool:
         """Handle a database error that may result in moving away the corrupt db."""
         if isinstance(err.__cause__, sqlite3.DatabaseError):
             _LOGGER.exception(
@@ -1081,7 +1092,7 @@ class Recorder(threading.Thread):
             return True
         return False

-    def _commit_event_session_or_retry(self):
+    def _commit_event_session_or_retry(self) -> None:
         """Commit the event session if there is work to do."""
         if not self.event_session or (
             not self.event_session.new and not self.event_session.dirty
@@ -1105,7 +1116,8 @@ class Recorder(threading.Thread):
                 tries += 1
                 time.sleep(self.db_retry_wait)

-    def _commit_event_session(self):
+    def _commit_event_session(self) -> None:
+        assert self.event_session is not None
         self._commits_without_expire += 1

         if self._pending_expunge:
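
For orientation, the `tries`/`time.sleep(self.db_retry_wait)` context above comes from `_commit_event_session_or_retry`, which retries transient commit failures a bounded number of times. This is a simplified sketch of that shape, not the recorder's exact logic; the default counts and waits here are hypothetical:

```python
import time

from sqlalchemy.exc import SQLAlchemyError


def commit_or_retry(session, max_tries: int = 10, retry_wait: float = 3.0) -> None:
    """Commit the session, retrying database errors up to max_tries times."""
    tries = 1
    while tries <= max_tries:
        try:
            session.commit()
            return
        except SQLAlchemyError:
            if tries == max_tries:
                raise  # give up after the final attempt
            tries += 1
            time.sleep(retry_wait)
```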
@@ -1120,7 +1132,7 @@ class Recorder(threading.Thread):
             # We just committed the state attributes to the database
             # and we now know the attributes_ids. We can save
-            # a many selects for matching attributes by loading them
+            # many selects for matching attributes by loading them
             # into the LRU cache now.
             for state_attr in self._pending_state_attributes.values():
                 self._state_attributes_ids[
@@ -1135,7 +1147,7 @@ class Recorder(threading.Thread):
             self._commits_without_expire = 0
             self.event_session.expire_all()

-    def _handle_sqlite_corruption(self):
+    def _handle_sqlite_corruption(self) -> None:
         """Handle the sqlite3 database being corrupt."""
         self._close_event_session()
         self._close_connection()
@@ -1143,7 +1155,7 @@ class Recorder(threading.Thread):
         self._setup_recorder()
         self._setup_run()

-    def _close_event_session(self):
+    def _close_event_session(self) -> None:
         """Close the event session."""
         self._old_states = {}
         self._state_attributes_ids = {}
@@ -1160,27 +1172,29 @@ class Recorder(threading.Thread):
                 "Error while rolling back and closing the event session: %s", err
             )

-    def _reopen_event_session(self):
+    def _reopen_event_session(self) -> None:
         """Rollback the event session and reopen it after a failure."""
         self._close_event_session()
         self._open_event_session()

-    def _open_event_session(self):
+    def _open_event_session(self) -> None:
         """Open the event session."""
+        assert self.get_session is not None
         self.event_session = self.get_session()
         self.event_session.expire_on_commit = False

-    def _send_keep_alive(self):
+    def _send_keep_alive(self) -> None:
         """Send a keep alive to keep the db connection open."""
+        assert self.event_session is not None
         _LOGGER.debug("Sending keepalive")
         self.event_session.connection().scalar(select([1]))

     @callback
-    def event_listener(self, event):
+    def event_listener(self, event: Event) -> None:
         """Listen for new events and put them in the process queue."""
         self.queue.put(EventTask(event))

-    def block_till_done(self):
+    def block_till_done(self) -> None:
         """Block till all events processed.

         This is only called in tests.
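
`_send_keep_alive` issues a trivial `SELECT 1` so long-lived connections are not dropped as idle; `select([1])` is the legacy SQLAlchemy 1.x list form. A standalone version of the same ping, assuming SQLAlchemy 1.4+ where columns are passed positionally:

```python
from sqlalchemy import create_engine, select

engine = create_engine("sqlite://")
with engine.connect() as connection:
    # Cheap no-op round trip that keeps the connection alive.
    assert connection.scalar(select(1)) == 1
```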
@@ -1244,9 +1258,9 @@ class Recorder(threading.Thread):
         return success

-    def _setup_connection(self):
+    def _setup_connection(self) -> None:
         """Ensure database is ready to fly."""
-        kwargs = {}
+        kwargs: dict[str, Any] = {}
         self._completed_first_database_setup = False

         def setup_recorder_connection(dbapi_connection, connection_record):
@@ -1280,20 +1294,22 @@ class Recorder(threading.Thread):
         _LOGGER.debug("Connected to recorder database")

     @property
-    def _using_file_sqlite(self):
+    def _using_file_sqlite(self) -> bool:
         """Short version to check if we are using sqlite3 as a file."""
         return self.db_url != SQLITE_URL_PREFIX and self.db_url.startswith(
             SQLITE_URL_PREFIX
         )

-    def _close_connection(self):
+    def _close_connection(self) -> None:
         """Close the connection."""
+        assert self.engine is not None
         self.engine.dispose()
         self.engine = None
         self.get_session = None

-    def _setup_run(self):
+    def _setup_run(self) -> None:
         """Log the start of the current run and schedule any needed jobs."""
+        assert self.get_session is not None
         with session_scope(session=self.get_session()) as session:
             start = self.recording_start
             end_incomplete_runs(session, start)
@@ -1324,7 +1340,7 @@ class Recorder(threading.Thread):
             self.queue.put(StatisticsTask(start))
             start = end

-    def _end_session(self):
+    def _end_session(self) -> None:
         """End the recorder session."""
         if self.event_session is None:
             return
@@ -1338,7 +1354,7 @@ class Recorder(threading.Thread):
         self.run_info = None

-    def _shutdown(self):
+    def _shutdown(self) -> None:
         """Save end time for current run."""
         self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
         self._stop_executor()
@@ -1346,6 +1362,6 @@ class Recorder(threading.Thread):
         self._close_connection()

     @property
-    def recording(self):
+    def recording(self) -> bool:
         """Return if the recorder is recording."""
         return self._event_listener is not None


@@ -12,15 +12,17 @@ _LOGGER = logging.getLogger(__name__)
 def repack_database(instance: Recorder) -> None:
     """Repack based on engine type."""
+    assert instance.engine is not None
+    dialect_name = instance.engine.dialect.name

     # Execute sqlite command to free up space on disk
-    if instance.engine.dialect.name == "sqlite":
+    if dialect_name == "sqlite":
         _LOGGER.debug("Vacuuming SQL DB to free space")
         instance.engine.execute("VACUUM")
         return

     # Execute postgresql vacuum command to free up space on disk
-    if instance.engine.dialect.name == "postgresql":
+    if dialect_name == "postgresql":
         _LOGGER.debug("Vacuuming SQL DB to free space")
         with instance.engine.connect().execution_options(
             isolation_level="AUTOCOMMIT"
@@ -29,7 +31,7 @@ def repack_database(instance: Recorder) -> None:
         return

     # Optimize mysql / mariadb tables to free up space on disk
-    if instance.engine.dialect.name == "mysql":
+    if dialect_name == "mysql":
         _LOGGER.debug("Optimizing SQL DB to free space")
         instance.engine.execute("OPTIMIZE TABLE states, events, recorder_runs")
         return
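
Hoisting `instance.engine.dialect.name` into `dialect_name` removes three repeated attribute chains and pairs naturally with the new `assert` narrowing `engine` from `Engine | None`. A self-contained sketch of the same dispatch-by-dialect idea; it uses `text()` with AUTOCOMMIT (VACUUM cannot run inside a transaction), which differs slightly from the 1.x-style `engine.execute` shown above:

```python
from sqlalchemy import create_engine, text


def repack(engine) -> None:
    """Run a space-reclaiming statement appropriate for the dialect (sketch)."""
    statements = {
        "sqlite": "VACUUM",
        "postgresql": "VACUUM",
        "mysql": "OPTIMIZE TABLE states, events, recorder_runs",
    }
    dialect_name = engine.dialect.name
    if (statement := statements.get(dialect_name)) is None:
        return
    with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
        conn.execute(text(statement))


repack(create_engine("sqlite://"))
```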


@@ -1247,19 +1247,19 @@ def _filter_unique_constraint_integrity_error(
     if not isinstance(err, StatementError):
         return False

+    assert instance.engine is not None
+    dialect_name = instance.engine.dialect.name
+
     ignore = False
-    if (
-        instance.engine.dialect.name == "sqlite"
-        and "UNIQUE constraint failed" in str(err)
-    ):
+    if dialect_name == "sqlite" and "UNIQUE constraint failed" in str(err):
         ignore = True
     if (
-        instance.engine.dialect.name == "postgresql"
+        dialect_name == "postgresql"
         and hasattr(err.orig, "pgcode")
         and err.orig.pgcode == "23505"
     ):
         ignore = True
-    if instance.engine.dialect.name == "mysql" and hasattr(err.orig, "args"):
+    if dialect_name == "mysql" and hasattr(err.orig, "args"):
         with contextlib.suppress(TypeError):
             if err.orig.args[0] == 1062:
                 ignore = True
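
The same `dialect_name` hoist applies here. Each backend signals a duplicate key differently: SQLite via message text, PostgreSQL via `pgcode` 23505, and MySQL/MariaDB via error code 1062 on the driver exception. A condensed, hypothetical sketch of that classification, reading the driver exception from `err.orig` as the original does:

```python
import contextlib


def is_unique_violation(dialect_name: str, err: Exception) -> bool:
    """Return True if err looks like a duplicate-key error on this dialect."""
    orig = getattr(err, "orig", None)  # underlying DB-API exception, if any
    if dialect_name == "sqlite":
        return "UNIQUE constraint failed" in str(err)
    if dialect_name == "postgresql":
        return getattr(orig, "pgcode", None) == "23505"
    if dialect_name == "mysql":
        with contextlib.suppress(TypeError, IndexError, AttributeError):
            return orig.args[0] == 1062
    return False
```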