Improve error handling when creating a new SQLite database (#122406)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Erik Montnemery 2024-07-22 21:16:11 +02:00 committed by GitHub
parent 20fc5233a1
commit 3dc36cf068
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 14 deletions

View File

@ -769,9 +769,7 @@ class Recorder(threading.Thread):
# Do live migration steps and repairs, if needed # Do live migration steps and repairs, if needed
if schema_status.migration_needed or schema_status.schema_errors: if schema_status.migration_needed or schema_status.schema_errors:
result, schema_status = self._migrate_schema_live_and_setup_run( result, schema_status = self._migrate_schema_live(schema_status)
schema_status
)
if result: if result:
self.schema_version = SCHEMA_VERSION self.schema_version = SCHEMA_VERSION
if not self._event_listener: if not self._event_listener:
@ -789,6 +787,7 @@ class Recorder(threading.Thread):
if self.migration_in_progress: if self.migration_in_progress:
self.migration_in_progress = False self.migration_in_progress = False
self._dismiss_migration_in_progress() self._dismiss_migration_in_progress()
self._setup_run()
# Catch up with missed statistics # Catch up with missed statistics
self._schedule_compile_missing_statistics() self._schedule_compile_missing_statistics()
@ -907,7 +906,7 @@ class Recorder(threading.Thread):
self._commit_event_session_or_retry() self._commit_event_session_or_retry()
task.run(self) task.run(self)
except exc.DatabaseError as err: except exc.DatabaseError as err:
if self._handle_database_error(err): if self._handle_database_error(err, setup_run=True):
return return
_LOGGER.exception("Unhandled database error while processing task %s", task) _LOGGER.exception("Unhandled database error while processing task %s", task)
except SQLAlchemyError: except SQLAlchemyError:
@ -953,7 +952,7 @@ class Recorder(threading.Thread):
"""Migrate schema to the latest version.""" """Migrate schema to the latest version."""
return self._migrate_schema(schema_status, False) return self._migrate_schema(schema_status, False)
def _migrate_schema_live_and_setup_run( def _migrate_schema_live(
self, schema_status: migration.SchemaValidationStatus self, schema_status: migration.SchemaValidationStatus
) -> tuple[bool, migration.SchemaValidationStatus]: ) -> tuple[bool, migration.SchemaValidationStatus]:
"""Migrate schema to the latest version.""" """Migrate schema to the latest version."""
@ -971,10 +970,7 @@ class Recorder(threading.Thread):
"recorder_database_migration", "recorder_database_migration",
) )
self.hass.add_job(self._async_migration_started) self.hass.add_job(self._async_migration_started)
migration_result, schema_status = self._migrate_schema(schema_status, True) return self._migrate_schema(schema_status, True)
if migration_result:
self._setup_run()
return migration_result, schema_status
def _migrate_schema( def _migrate_schema(
self, self,
@ -992,7 +988,7 @@ class Recorder(threading.Thread):
self, self.hass, self.engine, self.get_session, schema_status self, self.hass, self.engine, self.get_session, schema_status
) )
except exc.DatabaseError as err: except exc.DatabaseError as err:
if self._handle_database_error(err): if self._handle_database_error(err, setup_run=False):
# If _handle_database_error returns True, we have a new database # If _handle_database_error returns True, we have a new database
# which does not need migration or repair. # which does not need migration or repair.
new_schema_status = migration.SchemaValidationStatus( new_schema_status = migration.SchemaValidationStatus(
@ -1179,7 +1175,7 @@ class Recorder(threading.Thread):
self._add_to_session(session, dbstate) self._add_to_session(session, dbstate)
def _handle_database_error(self, err: Exception) -> bool: def _handle_database_error(self, err: Exception, *, setup_run: bool) -> bool:
"""Handle a database error that may result in moving away the corrupt db.""" """Handle a database error that may result in moving away the corrupt db."""
if ( if (
(cause := err.__cause__) (cause := err.__cause__)
@ -1193,7 +1189,7 @@ class Recorder(threading.Thread):
_LOGGER.exception( _LOGGER.exception(
"Unrecoverable sqlite3 database corruption detected: %s", err "Unrecoverable sqlite3 database corruption detected: %s", err
) )
self._handle_sqlite_corruption() self._handle_sqlite_corruption(setup_run)
return True return True
return False return False
@ -1260,7 +1256,7 @@ class Recorder(threading.Thread):
self._commits_without_expire = 0 self._commits_without_expire = 0
session.expire_all() session.expire_all()
def _handle_sqlite_corruption(self) -> None: def _handle_sqlite_corruption(self, setup_run: bool) -> None:
"""Handle the sqlite3 database being corrupt.""" """Handle the sqlite3 database being corrupt."""
try: try:
self._close_event_session() self._close_event_session()
@ -1269,7 +1265,8 @@ class Recorder(threading.Thread):
move_away_broken_database(dburl_to_path(self.db_url)) move_away_broken_database(dburl_to_path(self.db_url))
self.recorder_runs_manager.reset() self.recorder_runs_manager.reset()
self._setup_recorder() self._setup_recorder()
self._setup_run() if setup_run:
self._setup_run()
def _close_event_session(self) -> None: def _close_event_session(self) -> None:
"""Close the event session.""" """Close the event session."""

View File

@ -190,6 +190,11 @@ async def test_live_database_migration_encounters_corruption(
patch( patch(
"homeassistant.components.recorder.core.move_away_broken_database" "homeassistant.components.recorder.core.move_away_broken_database"
) as move_away, ) as move_away,
patch(
"homeassistant.components.recorder.core.Recorder._setup_run",
autospec=True,
wraps=recorder.Recorder._setup_run,
) as setup_run,
): ):
await async_setup_recorder_instance(hass) await async_setup_recorder_instance(hass)
hass.states.async_set("my.entity", "on", {}) hass.states.async_set("my.entity", "on", {})
@ -198,6 +203,7 @@ async def test_live_database_migration_encounters_corruption(
assert recorder.util.async_migration_in_progress(hass) is False assert recorder.util.async_migration_in_progress(hass) is False
move_away.assert_called_once() move_away.assert_called_once()
setup_run.assert_called_once()
@pytest.mark.skip_on_db_engine(["mysql", "postgresql"]) @pytest.mark.skip_on_db_engine(["mysql", "postgresql"])
@ -235,6 +241,11 @@ async def test_non_live_database_migration_encounters_corruption(
patch( patch(
"homeassistant.components.recorder.core.move_away_broken_database" "homeassistant.components.recorder.core.move_away_broken_database"
) as move_away, ) as move_away,
patch(
"homeassistant.components.recorder.core.Recorder._setup_run",
autospec=True,
wraps=recorder.Recorder._setup_run,
) as setup_run,
): ):
await async_setup_recorder_instance(hass) await async_setup_recorder_instance(hass)
hass.states.async_set("my.entity", "on", {}) hass.states.async_set("my.entity", "on", {})
@ -244,6 +255,7 @@ async def test_non_live_database_migration_encounters_corruption(
assert recorder.util.async_migration_in_progress(hass) is False assert recorder.util.async_migration_in_progress(hass) is False
move_away.assert_called_once() move_away.assert_called_once()
migrate_schema_live.assert_not_called() migrate_schema_live.assert_not_called()
setup_run.assert_called_once()
@pytest.mark.parametrize( @pytest.mark.parametrize(