Ensure recorder still shuts down if the final commit fails (#87799)

This commit is contained in:
J. Nick Koston 2023-02-09 15:12:40 -06:00 committed by GitHub
parent f96de4ab45
commit 509de02044
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 14 deletions

View File

@ -989,8 +989,10 @@ class Recorder(threading.Thread):
def _handle_sqlite_corruption(self) -> None:
"""Handle the sqlite3 database being corrupt."""
self._close_event_session()
self._close_connection()
try:
self._close_event_session()
finally:
self._close_connection()
move_away_broken_database(dburl_to_path(self.db_url))
self.run_history.reset()
self._setup_recorder()
@ -1213,18 +1215,21 @@ class Recorder(threading.Thread):
"""End the recorder session."""
if self.event_session is None:
return
try:
if self.run_history.active:
self.run_history.end(self.event_session)
try:
self._commit_event_session_or_retry()
self.event_session.close()
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error saving the event session during shutdown: %s", err)
self.event_session.close()
self.run_history.clear()
def _shutdown(self) -> None:
"""Save end time for current run."""
self.hass.add_job(self._async_stop_listeners)
self._stop_executor()
self._end_session()
self._close_connection()
try:
self._end_session()
finally:
self._close_connection()

View File

@ -72,6 +72,11 @@ class RunHistory:
start=self.recording_start, created=dt_util.utcnow()
)
@property
def active(self) -> bool:
"""Return if a run is active."""
return self._current_run_info is not None
def get(self, start: datetime) -> RecorderRuns | None:
"""Return the recorder run that started before or at start.
@ -142,6 +147,5 @@ class RunHistory:
Must run in the recorder thread.
"""
assert self._current_run_info is not None
assert self._current_run_info.end is not None
self._current_run_info = None
if self._current_run_info:
self._current_run_info = None

View File

@ -1,12 +1,16 @@
"""Test run history."""
from datetime import timedelta
from unittest.mock import patch
from homeassistant.components import recorder
from homeassistant.components.recorder.db_schema import RecorderRuns
from homeassistant.components.recorder.models import process_timestamp
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
from tests.common import SetupRecorderInstanceT
async def test_run_history(recorder_mock, hass):
"""Test the run history gives the correct run."""
@ -47,12 +51,32 @@ async def test_run_history(recorder_mock, hass):
)
async def test_run_history_during_schema_migration(recorder_mock, hass):
"""Test the run history during schema migration."""
instance = recorder.get_instance(hass)
async def test_run_history_while_recorder_is_not_yet_started(
async_setup_recorder_instance: SetupRecorderInstanceT,
hass: HomeAssistant,
recorder_db_url: str,
) -> None:
"""Test the run history while recorder is not yet started.
This usually happens during schema migration because
we do not start right away.
"""
# Prevent the run history from starting to ensure
# we can test run_history.current.start returns the expected value
with patch(
"homeassistant.components.recorder.run_history.RunHistory.start",
):
instance = await async_setup_recorder_instance(hass)
run_history = instance.run_history
assert run_history.current.start == run_history.recording_start
with instance.get_session() as session:
run_history.start(session)
def _start_run_history():
with instance.get_session() as session:
run_history.start(session)
# Ideally we would run run_history.start in the recorder thread
# but since we mocked it out above, we run it directly here
# via the database executor to avoid blocking the event loop.
await instance.async_add_executor_job(_start_run_history)
assert run_history.current.start == run_history.recording_start
assert run_history.current.created >= run_history.recording_start