Refactor tracking of the recorder run history (#70456)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
J. Nick Koston 2022-04-26 09:59:43 -10:00 committed by GitHub
parent 130e7fe128
commit f073f17040
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 242 additions and 113 deletions

View File

@ -72,13 +72,13 @@ from .executor import DBInterruptibleThreadPoolExecutor
from .models import ( from .models import (
Base, Base,
Events, Events,
RecorderRuns,
StateAttributes, StateAttributes,
States, States,
StatisticsRuns, StatisticsRuns,
process_timestamp, process_timestamp,
) )
from .pool import POOL_SIZE, MutexPool, RecorderPool from .pool import POOL_SIZE, MutexPool, RecorderPool
from .run_history import RunHistory
from .util import ( from .util import (
dburl_to_path, dburl_to_path,
end_incomplete_runs, end_incomplete_runs,
@ -244,51 +244,6 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
return instance.entity_filter(entity_id) return instance.entity_filter(entity_id)
def run_information(
hass: HomeAssistant, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run.
There is also the run that covers point_in_time.
"""
if run_info := run_information_from_instance(hass, point_in_time):
return run_info
with session_scope(hass=hass) as session:
return run_information_with_session(session, point_in_time)
def run_information_from_instance(
hass: HomeAssistant, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the existing instance.
Does not query the database for older runs.
"""
ins = get_instance(hass)
if point_in_time is None or point_in_time > ins.recording_start:
return ins.run_info
return None
def run_information_with_session(
session: Session, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the database."""
recorder_runs = RecorderRuns
query = session.query(recorder_runs)
if point_in_time:
query = query.filter(
(recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)
)
if (res := query.first()) is not None:
session.expunge(res)
return cast(RecorderRuns, res)
return res
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the recorder.""" """Set up the recorder."""
hass.data[DOMAIN] = {} hass.data[DOMAIN] = {}
@ -438,9 +393,13 @@ class PurgeTask(RecorderTask):
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Purge the database.""" """Purge the database."""
assert instance.get_session is not None
if purge.purge_old_data( if purge.purge_old_data(
instance, self.purge_before, self.repack, self.apply_filter instance, self.purge_before, self.repack, self.apply_filter
): ):
with instance.get_session() as session:
instance.run_history.load_from_db(session)
# We always need to do the db cleanups after a purge # We always need to do the db cleanups after a purge
# is finished to ensure the WAL checkpoint and other # is finished to ensure the WAL checkpoint and other
# tasks happen after a vacuum. # tasks happen after a vacuum.
@ -652,7 +611,6 @@ class Recorder(threading.Thread):
self._hass_started: asyncio.Future[object] = asyncio.Future() self._hass_started: asyncio.Future[object] = asyncio.Future()
self.commit_interval = commit_interval self.commit_interval = commit_interval
self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue() self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue()
self.recording_start = dt_util.utcnow()
self.db_url = uri self.db_url = uri
self.db_max_retries = db_max_retries self.db_max_retries = db_max_retries
self.db_retry_wait = db_retry_wait self.db_retry_wait = db_retry_wait
@ -660,7 +618,7 @@ class Recorder(threading.Thread):
self.async_recorder_ready = asyncio.Event() self.async_recorder_ready = asyncio.Event()
self._queue_watch = threading.Event() self._queue_watch = threading.Event()
self.engine: Engine | None = None self.engine: Engine | None = None
self.run_info: RecorderRuns | None = None self.run_history = RunHistory()
self.entity_filter = entity_filter self.entity_filter = entity_filter
self.exclude_t = exclude_t self.exclude_t = exclude_t
@ -1302,6 +1260,7 @@ class Recorder(threading.Thread):
self._close_event_session() self._close_event_session()
self._close_connection() self._close_connection()
move_away_broken_database(dburl_to_path(self.db_url)) move_away_broken_database(dburl_to_path(self.db_url))
self.run_history.reset()
self._setup_recorder() self._setup_recorder()
self._setup_run() self._setup_run()
@ -1465,12 +1424,8 @@ class Recorder(threading.Thread):
"""Log the start of the current run and schedule any needed jobs.""" """Log the start of the current run and schedule any needed jobs."""
assert self.get_session is not None assert self.get_session is not None
with session_scope(session=self.get_session()) as session: with session_scope(session=self.get_session()) as session:
start = self.recording_start end_incomplete_runs(session, self.run_history.recording_start)
end_incomplete_runs(session, start) self.run_history.start(session)
self.run_info = RecorderRuns(start=start, created=dt_util.utcnow())
session.add(self.run_info)
session.flush()
session.expunge(self.run_info)
self._schedule_compile_missing_statistics(session) self._schedule_compile_missing_statistics(session)
self._open_event_session() self._open_event_session()
@ -1498,16 +1453,14 @@ class Recorder(threading.Thread):
"""End the recorder session.""" """End the recorder session."""
if self.event_session is None: if self.event_session is None:
return return
assert self.run_info is not None
try: try:
self.run_info.end = dt_util.utcnow() self.run_history.end(self.event_session)
self.event_session.add(self.run_info)
self._commit_event_session_or_retry() self._commit_event_session_or_retry()
self.event_session.close() self.event_session.close()
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error saving the event session during shutdown: %s", err) _LOGGER.exception("Error saving the event session during shutdown: %s", err)
self.run_info = None self.run_history.clear()
def _shutdown(self) -> None: def _shutdown(self) -> None:
"""Save end time for current run.""" """Save end time for current run."""

View File

@ -23,6 +23,7 @@ from .models import (
RecorderRuns, RecorderRuns,
StateAttributes, StateAttributes,
States, States,
process_timestamp,
process_timestamp_to_utc_isoformat, process_timestamp_to_utc_isoformat,
) )
from .util import execute, session_scope from .util import execute, session_scope
@ -478,11 +479,10 @@ def get_states(
no_attributes: bool = False, no_attributes: bool = False,
) -> list[State]: ) -> list[State]:
"""Return the states at a specific point in time.""" """Return the states at a specific point in time."""
if ( if run is None:
run is None run = recorder.get_instance(hass).run_history.get(utc_point_in_time)
and (run := (recorder.run_information_from_instance(hass, utc_point_in_time)))
is None if run is None or process_timestamp(run.start) > utc_point_in_time:
):
# History did not run before utc_point_in_time # History did not run before utc_point_in_time
return [] return []
@ -507,11 +507,10 @@ def _get_states_with_session(
hass, session, utc_point_in_time, entity_ids[0], no_attributes hass, session, utc_point_in_time, entity_ids[0], no_attributes
) )
if ( if run is None:
run is None run = recorder.get_instance(hass).run_history.get(utc_point_in_time)
and (run := (recorder.run_information_with_session(session, utc_point_in_time)))
is None if run is None or process_timestamp(run.start) > utc_point_in_time:
):
# History did not run before utc_point_in_time # History did not run before utc_point_in_time
return [] return []
@ -649,13 +648,11 @@ def _sorted_states_to_dict(
# Get the states at the start time # Get the states at the start time
timer_start = time.perf_counter() timer_start = time.perf_counter()
if include_start_time_state: if include_start_time_state:
run = recorder.run_information_from_instance(hass, start_time)
for state in _get_states_with_session( for state in _get_states_with_session(
hass, hass,
session, session,
start_time, start_time,
entity_ids, entity_ids,
run=run,
filters=filters, filters=filters,
no_attributes=no_attributes, no_attributes=no_attributes,
): ):

View File

@ -291,11 +291,10 @@ def _purge_old_recorder_runs(
) -> None: ) -> None:
"""Purge all old recorder runs.""" """Purge all old recorder runs."""
# Recorder runs is small, no need to batch run it # Recorder runs is small, no need to batch run it
assert instance.run_info is not None
deleted_rows = ( deleted_rows = (
session.query(RecorderRuns) session.query(RecorderRuns)
.filter(RecorderRuns.start < purge_before) .filter(RecorderRuns.start < purge_before)
.filter(RecorderRuns.run_id != instance.run_info.run_id) .filter(RecorderRuns.run_id != instance.run_history.current.run_id)
.delete(synchronize_session=False) .delete(synchronize_session=False)
) )
_LOGGER.debug("Deleted %s recorder_runs", deleted_rows) _LOGGER.debug("Deleted %s recorder_runs", deleted_rows)

View File

@ -0,0 +1,133 @@
"""Track recorder run history."""
from __future__ import annotations
import bisect
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy.orm.session import Session
import homeassistant.util.dt as dt_util
from .models import RecorderRuns, process_timestamp
def _find_recorder_run_for_start_time(
run_history: _RecorderRunsHistory, start: datetime
) -> RecorderRuns | None:
"""Find the recorder run for a start time in _RecorderRunsHistory."""
run_timestamps = run_history.run_timestamps
runs_by_timestamp = run_history.runs_by_timestamp
# bisect_left tells us were we would insert
# a value in the list of runs after the start timestamp.
#
# The run before that (idx-1) is when the run started
#
# If idx is 0, history never ran before the start timestamp
#
if idx := bisect.bisect_left(run_timestamps, start.timestamp()):
return runs_by_timestamp[run_timestamps[idx - 1]]
return None
@dataclass(frozen=True)
class _RecorderRunsHistory:
"""Bisectable history of RecorderRuns."""
run_timestamps: list[int]
runs_by_timestamp: dict[int, RecorderRuns]
class RunHistory:
"""Track recorder run history."""
def __init__(self) -> None:
"""Track recorder run history."""
self._recording_start = dt_util.utcnow()
self._current_run_info: RecorderRuns | None = None
self._run_history = _RecorderRunsHistory([], {})
@property
def recording_start(self) -> datetime:
"""Return the time the recorder started recording states."""
return self._recording_start
@property
def current(self) -> RecorderRuns:
"""Get the current run."""
assert self._current_run_info is not None
return self._current_run_info
def get(self, start: datetime) -> RecorderRuns | None:
"""Return the recorder run that started before or at start.
If the first run started after the start, return None
"""
if start >= self.recording_start:
return self.current
return _find_recorder_run_for_start_time(self._run_history, start)
def start(self, session: Session) -> None:
"""Start a new run.
Must run in the recorder thread.
"""
self._current_run_info = RecorderRuns(
start=self.recording_start, created=dt_util.utcnow()
)
session.add(self._current_run_info)
session.flush()
session.expunge(self._current_run_info)
self.load_from_db(session)
def reset(self) -> None:
"""Reset the run when the database is changed or fails.
Must run in the recorder thread.
"""
self._recording_start = dt_util.utcnow()
self._current_run_info = None
def end(self, session: Session) -> None:
"""End the current run.
Must run in the recorder thread.
"""
assert self._current_run_info is not None
self._current_run_info.end = dt_util.utcnow()
session.add(self._current_run_info)
def load_from_db(self, session: Session) -> None:
"""Update the run cache.
Must run in the recorder thread.
"""
run_timestamps: list[int] = []
runs_by_timestamp: dict[int, RecorderRuns] = {}
for run in session.query(RecorderRuns).order_by(RecorderRuns.start.asc()).all():
session.expunge(run)
if run_dt := process_timestamp(run.start):
timestamp = run_dt.timestamp()
run_timestamps.append(timestamp)
runs_by_timestamp[timestamp] = run
#
# self._run_history is accessed in get()
# which is allowed to be called from any thread
#
# We use a dataclass to ensure that when we update
# run_timestamps and runs_by_timestamp
# are never out of sync with each other.
#
self._run_history = _RecorderRunsHistory(run_timestamps, runs_by_timestamp)
def clear(self) -> None:
"""Clear the current run after ending it.
Must run in the recorder thread.
"""
assert self._current_run_info is not None
assert self._current_run_info.end is not None
self._current_run_info = None

View File

@ -1,10 +1,15 @@
"""Common test utils for working with recorder.""" """Common test utils for working with recorder."""
from datetime import timedelta from __future__ import annotations
from datetime import datetime, timedelta
from typing import cast
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from homeassistant import core as ha from homeassistant import core as ha
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -81,3 +86,21 @@ def create_engine_test(*args, **kwargs):
engine = create_engine(*args, **kwargs) engine = create_engine(*args, **kwargs)
models_schema_0.Base.metadata.create_all(engine) models_schema_0.Base.metadata.create_all(engine)
return engine return engine
def run_information_with_session(
session: Session, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the database."""
recorder_runs = RecorderRuns
query = session.query(recorder_runs)
if point_in_time:
query = query.filter(
(recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)
)
if (res := query.first()) is not None:
session.expunge(res)
return cast(RecorderRuns, res)
return res

View File

@ -1,4 +1,6 @@
"""The tests for the Recorder component.""" """The tests for the Recorder component."""
from __future__ import annotations
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -25,9 +27,6 @@ from homeassistant.components.recorder import (
SQLITE_URL_PREFIX, SQLITE_URL_PREFIX,
Recorder, Recorder,
get_instance, get_instance,
run_information,
run_information_from_instance,
run_information_with_session,
) )
from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.const import DATA_INSTANCE
from homeassistant.components.recorder.models import ( from homeassistant.components.recorder.models import (
@ -51,7 +50,12 @@ from homeassistant.core import Context, CoreState, HomeAssistant, callback
from homeassistant.setup import async_setup_component, setup_component from homeassistant.setup import async_setup_component, setup_component
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .common import async_wait_recording_done, corrupt_db_file, wait_recording_done from .common import (
async_wait_recording_done,
corrupt_db_file,
run_information_with_session,
wait_recording_done,
)
from tests.common import ( from tests.common import (
SetupRecorderInstanceT, SetupRecorderInstanceT,
@ -1008,37 +1012,6 @@ def test_saving_state_with_serializable_data(hass_recorder, caplog):
assert "State is not JSON serializable" in caplog.text assert "State is not JSON serializable" in caplog.text
def test_run_information(hass_recorder):
"""Ensure run_information returns expected data."""
before_start_recording = dt_util.utcnow()
hass = hass_recorder()
run_info = run_information_from_instance(hass)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
with session_scope(hass=hass) as session:
run_info = run_information_with_session(session)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
run_info = run_information(hass)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
hass.states.set("test.two", "on", {})
wait_recording_done(hass)
run_info = run_information(hass)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
run_info = run_information(hass, before_start_recording)
assert run_info is None
run_info = run_information(hass, dt_util.utcnow())
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
def test_has_services(hass_recorder): def test_has_services(hass_recorder):
"""Test the services exist.""" """Test the services exist."""
hass = hass_recorder() hass = hass_recorder()
@ -1208,6 +1181,8 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
await hass.async_block_till_done() await hass.async_block_till_done()
caplog.clear() caplog.clear()
original_start_time = get_instance(hass).run_history.recording_start
hass.states.async_set("test.lost", "on", {}) hass.states.async_set("test.lost", "on", {})
sqlite3_exception = DatabaseError("statement", {}, []) sqlite3_exception = DatabaseError("statement", {}, [])
@ -1252,6 +1227,9 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
assert state.entity_id == "test.two" assert state.entity_id == "test.two"
assert state.state == "on" assert state.state == "on"
new_start_time = get_instance(hass).run_history.recording_start
assert original_start_time < new_start_time
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done() await hass.async_block_till_done()
hass.stop() hass.stop()

View File

@ -20,9 +20,9 @@ from sqlalchemy.pool import StaticPool
from homeassistant.bootstrap import async_setup_component from homeassistant.bootstrap import async_setup_component
from homeassistant.components import persistent_notification as pn, recorder from homeassistant.components import persistent_notification as pn, recorder
from homeassistant.components.recorder import RecorderRuns, migration, models from homeassistant.components.recorder import migration, models
from homeassistant.components.recorder.const import DATA_INSTANCE from homeassistant.components.recorder.const import DATA_INSTANCE
from homeassistant.components.recorder.models import States from homeassistant.components.recorder.models import RecorderRuns, States
from homeassistant.components.recorder.util import session_scope from homeassistant.components.recorder.util import session_scope
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -267,7 +267,7 @@ async def test_schema_migrate(hass, start_version):
def _mock_setup_run(self): def _mock_setup_run(self):
self.run_info = RecorderRuns( self.run_info = RecorderRuns(
start=self.recording_start, created=dt_util.utcnow() start=self.run_history.recording_start, created=dt_util.utcnow()
) )
def _instrument_migration(*args): def _instrument_migration(*args):

View File

@ -0,0 +1,46 @@
"""Test run history."""
from datetime import timedelta
from homeassistant.components import recorder
from homeassistant.components.recorder.models import RecorderRuns, process_timestamp
from homeassistant.util import dt as dt_util
async def test_run_history(hass, recorder_mock):
"""Test the run history gives the correct run."""
instance = recorder.get_instance(hass)
now = dt_util.utcnow()
three_days_ago = now - timedelta(days=3)
two_days_ago = now - timedelta(days=2)
one_day_ago = now - timedelta(days=1)
with instance.get_session() as session:
session.add(RecorderRuns(start=three_days_ago, created=three_days_ago))
session.add(RecorderRuns(start=two_days_ago, created=two_days_ago))
session.add(RecorderRuns(start=one_day_ago, created=one_day_ago))
session.commit()
instance.run_history.load_from_db(session)
assert (
process_timestamp(
instance.run_history.get(three_days_ago + timedelta(microseconds=1)).start
)
== three_days_ago
)
assert (
process_timestamp(
instance.run_history.get(two_days_ago + timedelta(microseconds=1)).start
)
== two_days_ago
)
assert (
process_timestamp(
instance.run_history.get(one_day_ago + timedelta(microseconds=1)).start
)
== one_day_ago
)
assert (
process_timestamp(instance.run_history.get(now).start)
== instance.run_history.recording_start
)

View File

@ -9,7 +9,7 @@ from sqlalchemy import text
from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.elements import TextClause
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import run_information_with_session, util from homeassistant.components.recorder import util
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from homeassistant.components.recorder.models import RecorderRuns from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.components.recorder.util import ( from homeassistant.components.recorder.util import (
@ -21,7 +21,7 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .common import corrupt_db_file from .common import corrupt_db_file, run_information_with_session
from tests.common import SetupRecorderInstanceT, async_test_home_assistant from tests.common import SetupRecorderInstanceT, async_test_home_assistant