Refactor tracking of the recorder run history (#70456)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
J. Nick Koston 2022-04-26 09:59:43 -10:00 committed by GitHub
parent 130e7fe128
commit f073f17040
9 changed files with 242 additions and 113 deletions
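
In short, this change drops the module-level run lookup helpers (run_information, run_information_from_instance, run_information_with_session) and replaces them with a RunHistory object owned by the Recorder instance. A minimal caller-side sketch of the difference, assuming a started recorder; run_covering is a hypothetical name, while get_instance, run_history.get and process_timestamp all appear in the hunks below:

from __future__ import annotations

from datetime import datetime

from homeassistant.components import recorder
from homeassistant.components.recorder.models import RecorderRuns, process_timestamp
from homeassistant.core import HomeAssistant


def run_covering(hass: HomeAssistant, point_in_time: datetime) -> RecorderRuns | None:
    """Hypothetical helper: the run that was recording at point_in_time, if any."""
    # Previously a caller used recorder.run_information(hass, point_in_time), which
    # could fall back to a database query.  Now the in-memory history answers it.
    run = recorder.get_instance(hass).run_history.get(point_in_time)
    if run is None or process_timestamp(run.start) > point_in_time:
        # History never ran before point_in_time (the same early return as get_states).
        return None
    return run

The same pattern appears twice in the history.py hunks, replacing the old run_information_from_instance and run_information_with_session call sites.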

View File

@@ -72,13 +72,13 @@ from .executor import DBInterruptibleThreadPoolExecutor
from .models import (
Base,
Events,
RecorderRuns,
StateAttributes,
States,
StatisticsRuns,
process_timestamp,
)
from .pool import POOL_SIZE, MutexPool, RecorderPool
from .run_history import RunHistory
from .util import (
dburl_to_path,
end_incomplete_runs,
@@ -244,51 +244,6 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
return instance.entity_filter(entity_id)
def run_information(
hass: HomeAssistant, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run.
There is also the run that covers point_in_time.
"""
if run_info := run_information_from_instance(hass, point_in_time):
return run_info
with session_scope(hass=hass) as session:
return run_information_with_session(session, point_in_time)
def run_information_from_instance(
hass: HomeAssistant, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the existing instance.
Does not query the database for older runs.
"""
ins = get_instance(hass)
if point_in_time is None or point_in_time > ins.recording_start:
return ins.run_info
return None
def run_information_with_session(
session: Session, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the database."""
recorder_runs = RecorderRuns
query = session.query(recorder_runs)
if point_in_time:
query = query.filter(
(recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)
)
if (res := query.first()) is not None:
session.expunge(res)
return cast(RecorderRuns, res)
return res
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the recorder."""
hass.data[DOMAIN] = {}
@@ -438,9 +393,13 @@ class PurgeTask(RecorderTask):
def run(self, instance: Recorder) -> None:
"""Purge the database."""
assert instance.get_session is not None
if purge.purge_old_data(
instance, self.purge_before, self.repack, self.apply_filter
):
with instance.get_session() as session:
instance.run_history.load_from_db(session)
# We always need to do the db cleanups after a purge
# is finished to ensure the WAL checkpoint and other
# tasks happen after a vacuum.
@@ -652,7 +611,6 @@ class Recorder(threading.Thread):
self._hass_started: asyncio.Future[object] = asyncio.Future()
self.commit_interval = commit_interval
self.queue: queue.SimpleQueue[RecorderTask] = queue.SimpleQueue()
self.recording_start = dt_util.utcnow()
self.db_url = uri
self.db_max_retries = db_max_retries
self.db_retry_wait = db_retry_wait
@@ -660,7 +618,7 @@
self.async_recorder_ready = asyncio.Event()
self._queue_watch = threading.Event()
self.engine: Engine | None = None
self.run_info: RecorderRuns | None = None
self.run_history = RunHistory()
self.entity_filter = entity_filter
self.exclude_t = exclude_t
@@ -1302,6 +1260,7 @@ class Recorder(threading.Thread):
self._close_event_session()
self._close_connection()
move_away_broken_database(dburl_to_path(self.db_url))
self.run_history.reset()
self._setup_recorder()
self._setup_run()
@@ -1465,12 +1424,8 @@
"""Log the start of the current run and schedule any needed jobs."""
assert self.get_session is not None
with session_scope(session=self.get_session()) as session:
start = self.recording_start
end_incomplete_runs(session, start)
self.run_info = RecorderRuns(start=start, created=dt_util.utcnow())
session.add(self.run_info)
session.flush()
session.expunge(self.run_info)
end_incomplete_runs(session, self.run_history.recording_start)
self.run_history.start(session)
self._schedule_compile_missing_statistics(session)
self._open_event_session()
@@ -1498,16 +1453,14 @@
"""End the recorder session."""
if self.event_session is None:
return
assert self.run_info is not None
try:
self.run_info.end = dt_util.utcnow()
self.event_session.add(self.run_info)
self.run_history.end(self.event_session)
self._commit_event_session_or_retry()
self.event_session.close()
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error saving the event session during shutdown: %s", err)
self.run_info = None
self.run_history.clear()
def _shutdown(self) -> None:
"""Save end time for current run."""

View File

@@ -23,6 +23,7 @@ from .models import (
RecorderRuns,
StateAttributes,
States,
process_timestamp,
process_timestamp_to_utc_isoformat,
)
from .util import execute, session_scope
@@ -478,11 +479,10 @@ def get_states(
no_attributes: bool = False,
) -> list[State]:
"""Return the states at a specific point in time."""
if (
run is None
and (run := (recorder.run_information_from_instance(hass, utc_point_in_time)))
is None
):
if run is None:
run = recorder.get_instance(hass).run_history.get(utc_point_in_time)
if run is None or process_timestamp(run.start) > utc_point_in_time:
# History did not run before utc_point_in_time
return []
@@ -507,11 +507,10 @@ def _get_states_with_session(
hass, session, utc_point_in_time, entity_ids[0], no_attributes
)
if (
run is None
and (run := (recorder.run_information_with_session(session, utc_point_in_time)))
is None
):
if run is None:
run = recorder.get_instance(hass).run_history.get(utc_point_in_time)
if run is None or process_timestamp(run.start) > utc_point_in_time:
# History did not run before utc_point_in_time
return []
@@ -649,13 +648,11 @@ def _sorted_states_to_dict(
# Get the states at the start time
timer_start = time.perf_counter()
if include_start_time_state:
run = recorder.run_information_from_instance(hass, start_time)
for state in _get_states_with_session(
hass,
session,
start_time,
entity_ids,
run=run,
filters=filters,
no_attributes=no_attributes,
):

View File

@@ -291,11 +291,10 @@
) -> None:
"""Purge all old recorder runs."""
# The recorder_runs table is small, no need to run this in batches
assert instance.run_info is not None
deleted_rows = (
session.query(RecorderRuns)
.filter(RecorderRuns.start < purge_before)
.filter(RecorderRuns.run_id != instance.run_info.run_id)
.filter(RecorderRuns.run_id != instance.run_history.current.run_id)
.delete(synchronize_session=False)
)
_LOGGER.debug("Deleted %s recorder_runs", deleted_rows)

View File

@@ -0,0 +1,133 @@
"""Track recorder run history."""
from __future__ import annotations
import bisect
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy.orm.session import Session
import homeassistant.util.dt as dt_util
from .models import RecorderRuns, process_timestamp
def _find_recorder_run_for_start_time(
run_history: _RecorderRunsHistory, start: datetime
) -> RecorderRuns | None:
"""Find the recorder run for a start time in _RecorderRunsHistory."""
run_timestamps = run_history.run_timestamps
runs_by_timestamp = run_history.runs_by_timestamp
# bisect_left tells us where we would insert
# a value in the list of runs after the start timestamp.
#
# The run before that (idx-1) is when the run started
#
# If idx is 0, history never ran before the start timestamp
#
if idx := bisect.bisect_left(run_timestamps, start.timestamp()):
return runs_by_timestamp[run_timestamps[idx - 1]]
return None
@dataclass(frozen=True)
class _RecorderRunsHistory:
"""Bisectable history of RecorderRuns."""
run_timestamps: list[int]
runs_by_timestamp: dict[int, RecorderRuns]
class RunHistory:
"""Track recorder run history."""
def __init__(self) -> None:
"""Track recorder run history."""
self._recording_start = dt_util.utcnow()
self._current_run_info: RecorderRuns | None = None
self._run_history = _RecorderRunsHistory([], {})
@property
def recording_start(self) -> datetime:
"""Return the time the recorder started recording states."""
return self._recording_start
@property
def current(self) -> RecorderRuns:
"""Get the current run."""
assert self._current_run_info is not None
return self._current_run_info
def get(self, start: datetime) -> RecorderRuns | None:
"""Return the recorder run that started before or at start.
If the first run started after the start, return None
"""
if start >= self.recording_start:
return self.current
return _find_recorder_run_for_start_time(self._run_history, start)
def start(self, session: Session) -> None:
"""Start a new run.
Must run in the recorder thread.
"""
self._current_run_info = RecorderRuns(
start=self.recording_start, created=dt_util.utcnow()
)
session.add(self._current_run_info)
session.flush()
session.expunge(self._current_run_info)
self.load_from_db(session)
def reset(self) -> None:
"""Reset the run when the database is changed or fails.
Must run in the recorder thread.
"""
self._recording_start = dt_util.utcnow()
self._current_run_info = None
def end(self, session: Session) -> None:
"""End the current run.
Must run in the recorder thread.
"""
assert self._current_run_info is not None
self._current_run_info.end = dt_util.utcnow()
session.add(self._current_run_info)
def load_from_db(self, session: Session) -> None:
"""Update the run cache.
Must run in the recorder thread.
"""
run_timestamps: list[int] = []
runs_by_timestamp: dict[int, RecorderRuns] = {}
for run in session.query(RecorderRuns).order_by(RecorderRuns.start.asc()).all():
session.expunge(run)
if run_dt := process_timestamp(run.start):
timestamp = run_dt.timestamp()
run_timestamps.append(timestamp)
runs_by_timestamp[timestamp] = run
#
# self._run_history is accessed in get()
# which is allowed to be called from any thread
#
# We use a dataclass so that when we update it,
# run_timestamps and runs_by_timestamp
# can never be out of sync with each other.
#
self._run_history = _RecorderRunsHistory(run_timestamps, runs_by_timestamp)
def clear(self) -> None:
"""Clear the current run after ending it.
Must run in the recorder thread.
"""
assert self._current_run_info is not None
assert self._current_run_info.end is not None
self._current_run_info = None
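
The lookup in _find_recorder_run_for_start_time leans on bisect over the sorted start timestamps. A standalone illustration with made-up run data (not code from the module) showing why idx - 1 is the covering run and why idx == 0 means history had not started yet:

from __future__ import annotations

import bisect
from datetime import datetime, timedelta, timezone

# Fake run history: three runs that started three, two and one days ago.
now = datetime(2022, 4, 26, tzinfo=timezone.utc)
starts = [now - timedelta(days=3), now - timedelta(days=2), now - timedelta(days=1)]
run_timestamps = [start.timestamp() for start in starts]   # sorted ascending
runs_by_timestamp = {ts: f"run-{i}" for i, ts in enumerate(run_timestamps)}


def find_run(point_in_time: datetime) -> str | None:
    """Return the fake run that was recording at point_in_time, or None."""
    # bisect_left counts how many runs started strictly before point_in_time.
    # Zero means no run had started yet; otherwise idx - 1 is the newest run
    # that started before point_in_time, i.e. the one that covers it.
    idx = bisect.bisect_left(run_timestamps, point_in_time.timestamp())
    if not idx:
        return None
    return runs_by_timestamp[run_timestamps[idx - 1]]


assert find_run(now - timedelta(days=4)) is None        # before history began
assert find_run(now - timedelta(hours=36)) == "run-1"   # covered by the run from two days ago

Because _RecorderRunsHistory is a frozen dataclass, load_from_db publishes a new list and dict with a single attribute assignment, so get() can be called from any thread and always sees a matching pair.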

View File

@@ -1,10 +1,15 @@
"""Common test utils for working with recorder."""
from datetime import timedelta
from __future__ import annotations
from datetime import datetime, timedelta
from typing import cast
from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from homeassistant import core as ha
from homeassistant.components import recorder
from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
@@ -81,3 +86,21 @@ def create_engine_test(*args, **kwargs):
engine = create_engine(*args, **kwargs)
models_schema_0.Base.metadata.create_all(engine)
return engine
def run_information_with_session(
session: Session, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the database."""
recorder_runs = RecorderRuns
query = session.query(recorder_runs)
if point_in_time:
query = query.filter(
(recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)
)
if (res := query.first()) is not None:
session.expunge(res)
return cast(RecorderRuns, res)
return res
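
run_information_with_session now lives only in the test helpers. A sketch of how the remaining tests use it, modeled on the removed test_run_information; test_current_run_is_open is a hypothetical name, while hass_recorder, session_scope and closed_incorrect all appear elsewhere in this commit's test changes:

from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.components.recorder.util import session_scope

from .common import run_information_with_session


def test_current_run_is_open(hass_recorder):
    """Sketch: the current run can still be read straight from the database."""
    hass = hass_recorder()
    with session_scope(hass=hass) as session:
        run_info = run_information_with_session(session)
    assert isinstance(run_info, RecorderRuns)
    assert run_info.closed_incorrect is False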

View File

@@ -1,4 +1,6 @@
"""The tests for the Recorder component."""
from __future__ import annotations
# pylint: disable=protected-access
import asyncio
from datetime import datetime, timedelta
@@ -25,9 +27,6 @@ from homeassistant.components.recorder import (
SQLITE_URL_PREFIX,
Recorder,
get_instance,
run_information,
run_information_from_instance,
run_information_with_session,
)
from homeassistant.components.recorder.const import DATA_INSTANCE
from homeassistant.components.recorder.models import (
@@ -51,7 +50,12 @@ from homeassistant.core import Context, CoreState, HomeAssistant, callback
from homeassistant.setup import async_setup_component, setup_component
from homeassistant.util import dt as dt_util
from .common import async_wait_recording_done, corrupt_db_file, wait_recording_done
from .common import (
async_wait_recording_done,
corrupt_db_file,
run_information_with_session,
wait_recording_done,
)
from tests.common import (
SetupRecorderInstanceT,
@@ -1008,37 +1012,6 @@ def test_saving_state_with_serializable_data(hass_recorder, caplog):
assert "State is not JSON serializable" in caplog.text
def test_run_information(hass_recorder):
"""Ensure run_information returns expected data."""
before_start_recording = dt_util.utcnow()
hass = hass_recorder()
run_info = run_information_from_instance(hass)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
with session_scope(hass=hass) as session:
run_info = run_information_with_session(session)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
run_info = run_information(hass)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
hass.states.set("test.two", "on", {})
wait_recording_done(hass)
run_info = run_information(hass)
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
run_info = run_information(hass, before_start_recording)
assert run_info is None
run_info = run_information(hass, dt_util.utcnow())
assert isinstance(run_info, RecorderRuns)
assert run_info.closed_incorrect is False
def test_has_services(hass_recorder):
"""Test the services exist."""
hass = hass_recorder()
@@ -1208,6 +1181,8 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
await hass.async_block_till_done()
caplog.clear()
original_start_time = get_instance(hass).run_history.recording_start
hass.states.async_set("test.lost", "on", {})
sqlite3_exception = DatabaseError("statement", {}, [])
@@ -1252,6 +1227,9 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
assert state.entity_id == "test.two"
assert state.state == "on"
new_start_time = get_instance(hass).run_history.recording_start
assert original_start_time < new_start_time
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
hass.stop()

View File

@@ -20,9 +20,9 @@ from sqlalchemy.pool import StaticPool
from homeassistant.bootstrap import async_setup_component
from homeassistant.components import persistent_notification as pn, recorder
from homeassistant.components.recorder import RecorderRuns, migration, models
from homeassistant.components.recorder import migration, models
from homeassistant.components.recorder.const import DATA_INSTANCE
from homeassistant.components.recorder.models import States
from homeassistant.components.recorder.models import RecorderRuns, States
from homeassistant.components.recorder.util import session_scope
import homeassistant.util.dt as dt_util
@@ -267,7 +267,7 @@ async def test_schema_migrate(hass, start_version):
def _mock_setup_run(self):
self.run_info = RecorderRuns(
start=self.recording_start, created=dt_util.utcnow()
start=self.run_history.recording_start, created=dt_util.utcnow()
)
def _instrument_migration(*args):

View File

@@ -0,0 +1,46 @@
"""Test run history."""
from datetime import timedelta
from homeassistant.components import recorder
from homeassistant.components.recorder.models import RecorderRuns, process_timestamp
from homeassistant.util import dt as dt_util
async def test_run_history(hass, recorder_mock):
"""Test the run history gives the correct run."""
instance = recorder.get_instance(hass)
now = dt_util.utcnow()
three_days_ago = now - timedelta(days=3)
two_days_ago = now - timedelta(days=2)
one_day_ago = now - timedelta(days=1)
with instance.get_session() as session:
session.add(RecorderRuns(start=three_days_ago, created=three_days_ago))
session.add(RecorderRuns(start=two_days_ago, created=two_days_ago))
session.add(RecorderRuns(start=one_day_ago, created=one_day_ago))
session.commit()
instance.run_history.load_from_db(session)
assert (
process_timestamp(
instance.run_history.get(three_days_ago + timedelta(microseconds=1)).start
)
== three_days_ago
)
assert (
process_timestamp(
instance.run_history.get(two_days_ago + timedelta(microseconds=1)).start
)
== two_days_ago
)
assert (
process_timestamp(
instance.run_history.get(one_day_ago + timedelta(microseconds=1)).start
)
== one_day_ago
)
assert (
process_timestamp(instance.run_history.get(now).start)
== instance.run_history.recording_start
)

View File

@@ -9,7 +9,7 @@ from sqlalchemy import text
from sqlalchemy.sql.elements import TextClause
from homeassistant.components import recorder
from homeassistant.components.recorder import run_information_with_session, util
from homeassistant.components.recorder import util
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.components.recorder.util import (
@@ -21,7 +21,7 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
from .common import corrupt_db_file
from .common import corrupt_db_file, run_information_with_session
from tests.common import SetupRecorderInstanceT, async_test_home_assistant