Ensure recorder always attempts clean shutdown if recorder thread raises (#91261)

* Ensure recorder run shutdown if the run loop raises

If anything goes wrong with the recorder we should
still try to shutdown cleanly

* tweak

* tests

* tests

* handle migraiton failure

* tweak comment

* naming

* order

* order

* order

* reword

* adjust test

* fixes

* threading

* failure case

* fix test

* have to wait for stop because the task blocks on thread join
This commit is contained in:
J. Nick Koston 2023-04-14 15:03:24 -10:00 committed by GitHub
parent 56cc6633f5
commit 1379ad60c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 17 deletions

View File

@ -444,10 +444,17 @@ class Recorder(threading.Thread):
async_at_started(self.hass, self._async_hass_started) async_at_started(self.hass, self._async_hass_started)
@callback @callback
def async_connection_failed(self) -> None: def _async_startup_failed(self) -> None:
"""Connect failed tasks.""" """Report startup failure."""
self.async_db_connected.set_result(False) # If a live migration failed, we were able to connect (async_db_connected
self.async_db_ready.set_result(False) # marked True), the database was marked ready (async_db_ready marked
# True), the data in the queue cannot be written to the database because
# the schema not in the correct format so we must stop listeners and report
# failure.
if not self.async_db_connected.done():
self.async_db_connected.set_result(False)
if not self.async_db_ready.done():
self.async_db_ready.set_result(False)
persistent_notification.async_create( persistent_notification.async_create(
self.hass, self.hass,
"The recorder could not start, check [the logs](/config/logs)", "The recorder could not start, check [the logs](/config/logs)",
@ -645,19 +652,26 @@ class Recorder(threading.Thread):
return SHUTDOWN_TASK return SHUTDOWN_TASK
def run(self) -> None: def run(self) -> None:
"""Run the recorder thread."""
try:
self._run()
finally:
# Ensure shutdown happens cleanly if
# anything goes wrong in the run loop
self._shutdown()
def _run(self) -> None:
"""Start processing events to save.""" """Start processing events to save."""
self.thread_id = threading.get_ident() self.thread_id = threading.get_ident()
setup_result = self._setup_recorder() setup_result = self._setup_recorder()
if not setup_result: if not setup_result:
# Give up if we could not connect # Give up if we could not connect
self.hass.add_job(self.async_connection_failed)
return return
schema_status = migration.validate_db_schema(self.hass, self, self.get_session) schema_status = migration.validate_db_schema(self.hass, self, self.get_session)
if schema_status is None: if schema_status is None:
# Give up if we could not validate the schema # Give up if we could not validate the schema
self.hass.add_job(self.async_connection_failed)
return return
self.schema_version = schema_status.current_version self.schema_version = schema_status.current_version
@ -684,7 +698,6 @@ class Recorder(threading.Thread):
self.migration_in_progress = False self.migration_in_progress = False
# Make sure we cleanly close the run if # Make sure we cleanly close the run if
# we restart before startup finishes # we restart before startup finishes
self._shutdown()
return return
if not schema_status.valid: if not schema_status.valid:
@ -702,8 +715,6 @@ class Recorder(threading.Thread):
"Database Migration Failed", "Database Migration Failed",
"recorder_database_migration", "recorder_database_migration",
) )
self.hass.add_job(self.async_set_db_ready)
self._shutdown()
return return
if not database_was_ready: if not database_was_ready:
@ -715,7 +726,6 @@ class Recorder(threading.Thread):
self._adjust_lru_size() self._adjust_lru_size()
self.hass.add_job(self._async_set_recorder_ready_migration_done) self.hass.add_job(self._async_set_recorder_ready_migration_done)
self._run_event_loop() self._run_event_loop()
self._shutdown()
def _activate_and_set_db_ready(self) -> None: def _activate_and_set_db_ready(self) -> None:
"""Activate the table managers or schedule migrations and mark the db as ready.""" """Activate the table managers or schedule migrations and mark the db as ready."""
@ -1355,9 +1365,9 @@ class Recorder(threading.Thread):
def _close_connection(self) -> None: def _close_connection(self) -> None:
"""Close the connection.""" """Close the connection."""
assert self.engine is not None if self.engine:
self.engine.dispose() self.engine.dispose()
self.engine = None self.engine = None
self._get_session = None self._get_session = None
def _setup_run(self) -> None: def _setup_run(self) -> None:
@ -1389,9 +1399,19 @@ class Recorder(threading.Thread):
def _shutdown(self) -> None: def _shutdown(self) -> None:
"""Save end time for current run.""" """Save end time for current run."""
_LOGGER.debug("Shutting down recorder") _LOGGER.debug("Shutting down recorder")
self.hass.add_job(self._async_stop_listeners) if not self.schema_version or self.schema_version != SCHEMA_VERSION:
self._stop_executor() # If the schema version is not set, we never had a working
# connection to the database or the schema never reached a
# good state.
#
# In either case, we want to mark startup as failed.
#
self.hass.add_job(self._async_startup_failed)
else:
self.hass.add_job(self._async_stop_listeners)
try: try:
self._end_session() self._end_session()
finally: finally:
self._stop_executor()
self._close_connection() self._close_connection()

View File

@ -338,7 +338,6 @@ def test_state_changes_during_period_descending(
> hist_states[1].last_changed > hist_states[1].last_changed
> hist_states[2].last_changed > hist_states[2].last_changed
) )
hist = history.state_changes_during_period( hist = history.state_changes_during_period(
hass, hass,
start_time, # Pick a point where we will generate a start time state start_time, # Pick a point where we will generate a start time state

View File

@ -8,7 +8,7 @@ from pathlib import Path
import sqlite3 import sqlite3
import threading import threading
from typing import cast from typing import cast
from unittest.mock import Mock, patch from unittest.mock import MagicMock, Mock, patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
@ -27,6 +27,7 @@ from homeassistant.components.recorder import (
SQLITE_URL_PREFIX, SQLITE_URL_PREFIX,
Recorder, Recorder,
get_instance, get_instance,
migration,
pool, pool,
statistics, statistics,
) )
@ -2239,3 +2240,90 @@ async def test_lru_increases_with_many_entities(
== mock_entity_count * 2 == mock_entity_count * 2
) )
assert recorder_mock.states_meta_manager._id_map.get_size() == mock_entity_count * 2 assert recorder_mock.states_meta_manager._id_map.get_size() == mock_entity_count * 2
async def test_clean_shutdown_when_recorder_thread_raises_during_initialize_database(
hass: HomeAssistant,
) -> None:
"""Test we still shutdown cleanly when the recorder thread raises during initialize_database."""
with patch.object(migration, "initialize_database", side_effect=Exception), patch(
"homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert not await async_setup_component(
hass,
recorder.DOMAIN,
{
recorder.DOMAIN: {
CONF_DB_URL: "sqlite://",
CONF_DB_RETRY_WAIT: 0,
CONF_DB_MAX_RETRIES: 1,
}
},
)
await hass.async_block_till_done()
instance = recorder.get_instance(hass)
await hass.async_stop()
assert instance.engine is None
async def test_clean_shutdown_when_recorder_thread_raises_during_validate_db_schema(
hass: HomeAssistant,
) -> None:
"""Test we still shutdown cleanly when the recorder thread raises during validate_db_schema."""
with patch.object(migration, "validate_db_schema", side_effect=Exception), patch(
"homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert not await async_setup_component(
hass,
recorder.DOMAIN,
{
recorder.DOMAIN: {
CONF_DB_URL: "sqlite://",
CONF_DB_RETRY_WAIT: 0,
CONF_DB_MAX_RETRIES: 1,
}
},
)
await hass.async_block_till_done()
instance = recorder.get_instance(hass)
await hass.async_stop()
assert instance.engine is None
async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -> None:
"""Test we still shutdown cleanly when schema migration fails."""
with patch.object(
migration,
"validate_db_schema",
return_value=MagicMock(valid=False, current_version=1),
), patch(
"homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True
), patch.object(
migration,
"migrate_schema",
side_effect=Exception,
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert await async_setup_component(
hass,
recorder.DOMAIN,
{
recorder.DOMAIN: {
CONF_DB_URL: "sqlite://",
CONF_DB_RETRY_WAIT: 0,
CONF_DB_MAX_RETRIES: 1,
}
},
)
await hass.async_block_till_done()
instance = recorder.get_instance(hass)
await hass.async_stop()
assert instance.engine is None