Refactor recorder schema migration (#122372)

* Refactor recorder schema migration

* Simplify

* Remove unused imports

* Refactor _migrate_schema according to review comments

* Add comment
Erik Montnemery 2024-07-22 16:53:54 +02:00 committed by GitHub
parent c73e7ae178
commit e8b88557ee
6 changed files with 168 additions and 76 deletions
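
For orientation: the refactor splits schema migration into two phases. Steps that cannot run live are applied before the recorder is activated, and the remaining steps run live after Home Assistant has finished starting. Below is a minimal, self-contained sketch of that ordering; the version numbers and helper names are illustrative stand-ins, not the real recorder API.

# Illustrative sketch of the two-phase migration flow (not the actual recorder code).
from dataclasses import dataclass, replace

LIVE_MIGRATION_MIN_SCHEMA_VERSION = 28  # assumed threshold, for illustration only
SCHEMA_VERSION = 44                     # assumed target version, for illustration only


@dataclass(frozen=True)
class SchemaStatus:
    current_version: int
    start_version: int
    valid: bool


def apply_updates(status: SchemaStatus, end_version: int) -> SchemaStatus:
    """Apply schema updates up to end_version and return an updated status."""
    for version in range(status.current_version, end_version):
        print(f"upgrading schema to version {version + 1}")
    return replace(status, current_version=end_version)


def run(status: SchemaStatus) -> None:
    """Simulate the new ordering in Recorder._run."""
    if not status.valid:
        # Phase 1: non-live steps, applied before the recorder is marked ready.
        status = apply_updates(status, LIVE_MIGRATION_MIN_SCHEMA_VERSION - 1)

    print("recorder activated and marked ready")

    if status.current_version != SCHEMA_VERSION:
        # Phase 2: live steps, run after startup while events are being recorded.
        status = apply_updates(status, SCHEMA_VERSION)

    print(f"migration finished; started from version {status.start_version}")


run(SchemaStatus(current_version=25, start_version=25, valid=False))

The real implementation threads a SchemaValidationStatus through both phases, as the diffs below show.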

View File

@@ -716,6 +716,15 @@ class Recorder(threading.Thread):
         self._event_session_has_pending_writes = True
         session.add(obj)
 
+    def _notify_migration_failed(self) -> None:
+        """Notify the user schema migration failed."""
+        persistent_notification.create(
+            self.hass,
+            "The database migration failed, check [the logs](/config/logs).",
+            "Database Migration Failed",
+            "recorder_database_migration",
+        )
+
     def _run(self) -> None:
         """Start processing events to save."""
         thread_id = threading.get_ident()
@@ -741,14 +750,20 @@ class Recorder(threading.Thread):
             self.migration_is_live = migration.live_migration(schema_status)
 
         self.hass.add_job(self.async_connection_success)
 
-        database_was_ready = self.migration_is_live or schema_status.valid
-
-        if database_was_ready:
-            # If the migrate is live or the schema is valid, we need to
-            # wait for startup to complete. If its not live, we need to continue
-            # on.
-            self._activate_and_set_db_ready()
+        # First do non-live migration steps, if needed
+        if not schema_status.valid:
+            result, schema_status = self._migrate_schema_offline(schema_status)
+            if not result:
+                self._notify_migration_failed()
+                self.migration_in_progress = False
+                return
+            self.schema_version = schema_status.current_version
+
+        # Non-live migration is now completed, remaining steps are live
+        self.migration_is_live = True
+
+        # After non-live migration, activate the recorder
+        self._activate_and_set_db_ready(schema_status)
 
         # We wait to start a live migration until startup has finished
         # since it can be cpu intensive and we do not want it to compete
         # with startup which is also cpu intensive
@@ -759,8 +774,12 @@ class Recorder(threading.Thread):
             # we restart before startup finishes
             return
 
+        # Do live migration steps, if needed
         if not schema_status.valid:
-            if self._migrate_schema_and_setup_run(schema_status):
+            result, schema_status = self._migrate_schema_live_and_setup_run(
+                schema_status
+            )
+            if result:
                 self.schema_version = SCHEMA_VERSION
                 if not self._event_listener:
                     # If the schema migration takes so long that the end
@@ -768,17 +787,9 @@ class Recorder(threading.Thread):
                     # was True, we need to reinitialize the listener.
                     self.hass.add_job(self.async_initialize)
             else:
-                persistent_notification.create(
-                    self.hass,
-                    "The database migration failed, check [the logs](/config/logs).",
-                    "Database Migration Failed",
-                    "recorder_database_migration",
-                )
+                self._notify_migration_failed()
                 return
 
-        if not database_was_ready:
-            self._activate_and_set_db_ready()
-
         # Catch up with missed statistics
         self._schedule_compile_missing_statistics()
         _LOGGER.debug("Recorder processing the queue")
@@ -786,7 +797,9 @@ class Recorder(threading.Thread):
         self.hass.add_job(self._async_set_recorder_ready_migration_done)
         self._run_event_loop()
 
-    def _activate_and_set_db_ready(self) -> None:
+    def _activate_and_set_db_ready(
+        self, schema_status: migration.SchemaValidationStatus
+    ) -> None:
         """Activate the table managers or schedule migrations and mark the db as ready."""
         with session_scope(session=self.get_session()) as session:
             # Prime the statistics meta manager as soon as possible
@@ -808,7 +821,7 @@ class Recorder(threading.Thread):
             EventTypeIDMigration,
             EntityIDMigration,
         ):
-            migrator = migrator_cls(schema_version, migration_changes)
+            migrator = migrator_cls(schema_status.start_version, migration_changes)
             migrator.do_migrate(self, session)
 
         if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
@@ -947,9 +960,15 @@ class Recorder(threading.Thread):
         """Set the migration started event."""
         self.async_migration_event.set()
 
-    def _migrate_schema_and_setup_run(
+    def _migrate_schema_offline(
         self, schema_status: migration.SchemaValidationStatus
-    ) -> bool:
+    ) -> tuple[bool, migration.SchemaValidationStatus]:
+        """Migrate schema to the latest version."""
+        return self._migrate_schema(schema_status, False)
+
+    def _migrate_schema_live_and_setup_run(
+        self, schema_status: migration.SchemaValidationStatus
+    ) -> tuple[bool, migration.SchemaValidationStatus]:
         """Migrate schema to the latest version."""
         persistent_notification.create(
             self.hass,
@@ -965,26 +984,40 @@ class Recorder(threading.Thread):
             "recorder_database_migration",
         )
         self.hass.add_job(self._async_migration_started)
 
         try:
+            migration_result, schema_status = self._migrate_schema(schema_status, True)
+            if migration_result:
+                self._setup_run()
+            return migration_result, schema_status
+        finally:
+            self.migration_in_progress = False
+            persistent_notification.dismiss(self.hass, "recorder_database_migration")
+
+    def _migrate_schema(
+        self,
+        schema_status: migration.SchemaValidationStatus,
+        live: bool,
+    ) -> tuple[bool, migration.SchemaValidationStatus]:
+        """Migrate schema to the latest version."""
         assert self.engine is not None
-            migration.migrate_schema(
+        try:
+            if live:
+                migrator = migration.migrate_schema_live
+            else:
+                migrator = migration.migrate_schema_non_live
+            new_schema_status = migrator(
                 self, self.hass, self.engine, self.get_session, schema_status
             )
         except exc.DatabaseError as err:
             if self._handle_database_error(err):
-                return True
+                return (True, schema_status)
             _LOGGER.exception("Database error during schema migration")
-            return False
+            return (False, schema_status)
         except Exception:
             _LOGGER.exception("Error during schema migration")
-            return False
+            return (False, schema_status)
         else:
-            self._setup_run()
-            return True
-        finally:
-            self.migration_in_progress = False
-            persistent_notification.dismiss(self.hass, "recorder_database_migration")
+            return (True, new_schema_status)
 
     def _lock_database(self, task: DatabaseLockTask) -> None:
         @callback

View File

@@ -188,12 +188,13 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
         return None
 
 
-@dataclass
+@dataclass(frozen=True)
 class SchemaValidationStatus:
     """Store schema validation status."""
 
     current_version: int
     schema_errors: set[str]
+    start_version: int
     valid: bool
@@ -224,7 +225,9 @@
     valid = is_current and not schema_errors
 
-    return SchemaValidationStatus(current_version, schema_errors, valid)
+    return SchemaValidationStatus(
+        current_version, schema_errors, current_version, valid
+    )
 
 
 def _find_schema_errors(
@@ -260,35 +263,30 @@ def pre_migrate_schema(engine: Engine) -> None:
     )
 
 
-def migrate_schema(
+def _migrate_schema(
     instance: Recorder,
     hass: HomeAssistant,
     engine: Engine,
     session_maker: Callable[[], Session],
     schema_status: SchemaValidationStatus,
-) -> None:
+    end_version: int,
+) -> SchemaValidationStatus:
     """Check if the schema needs to be upgraded."""
     current_version = schema_status.current_version
-    if current_version != SCHEMA_VERSION:
+    start_version = schema_status.start_version
+
+    if current_version < end_version:
         _LOGGER.warning(
             "Database is about to upgrade from schema version: %s to: %s",
             current_version,
-            SCHEMA_VERSION,
+            end_version,
         )
-        db_ready = False
-        for version in range(current_version, SCHEMA_VERSION):
-            if (
-                live_migration(dataclass_replace(schema_status, current_version=version))
-                and not db_ready
-            ):
-                db_ready = True
-                instance.migration_is_live = True
-                hass.add_job(instance.async_set_db_ready)
+        schema_status = dataclass_replace(schema_status, current_version=end_version)
+
+    for version in range(current_version, end_version):
         new_version = version + 1
         _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
-        _apply_update(
-            instance, hass, engine, session_maker, new_version, current_version
-        )
+        _apply_update(instance, hass, engine, session_maker, new_version, start_version)
         with session_scope(session=session_maker()) as session:
             session.add(SchemaChanges(schema_version=new_version))
@@ -296,6 +294,37 @@
             # so its clear that the upgrade is done
             _LOGGER.warning("Upgrade to version %s done", new_version)
 
+    return schema_status
+
+
+def migrate_schema_non_live(
+    instance: Recorder,
+    hass: HomeAssistant,
+    engine: Engine,
+    session_maker: Callable[[], Session],
+    schema_status: SchemaValidationStatus,
+) -> SchemaValidationStatus:
+    """Check if the schema needs to be upgraded."""
+    end_version = LIVE_MIGRATION_MIN_SCHEMA_VERSION - 1
+    return _migrate_schema(
+        instance, hass, engine, session_maker, schema_status, end_version
+    )
+
+
+def migrate_schema_live(
+    instance: Recorder,
+    hass: HomeAssistant,
+    engine: Engine,
+    session_maker: Callable[[], Session],
+    schema_status: SchemaValidationStatus,
+) -> SchemaValidationStatus:
+    """Check if the schema needs to be upgraded."""
+    end_version = SCHEMA_VERSION
+    schema_status = _migrate_schema(
+        instance, hass, engine, session_maker, schema_status, end_version
+    )
+
+    # Repairs are currently done during the live migration
     if schema_errors := schema_status.schema_errors:
         _LOGGER.warning(
             "Database is about to correct DB schema errors: %s",
@@ -305,12 +334,15 @@
         states_correct_db_schema(instance, schema_errors)
         events_correct_db_schema(instance, schema_errors)
 
-    if current_version != SCHEMA_VERSION:
-        instance.queue_task(PostSchemaMigrationTask(current_version, SCHEMA_VERSION))
+    start_version = schema_status.start_version
+    if start_version != SCHEMA_VERSION:
+        instance.queue_task(PostSchemaMigrationTask(start_version, SCHEMA_VERSION))
         # Make sure the post schema migration task is committed in case
         # the next task does not have commit_before = True
         instance.queue_task(CommitTask())
 
+    return schema_status
+
 
 def _create_index(
     session_maker: Callable[[], Session], table_name: str, index_name: str
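
A consequence of freezing SchemaValidationStatus (first hunk of this file) is that migration progress can no longer be tracked by mutating the status in place; each step instead returns a replaced copy via dataclass_replace, while start_version keeps recording the version the migration began from. A small, self-contained sketch of that pattern, using a simplified stand-in class:

# Sketch of the frozen-status pattern; a simplified stand-in, not the real class.
from dataclasses import dataclass, replace as dataclass_replace


@dataclass(frozen=True)
class SchemaValidationStatus:
    current_version: int
    schema_errors: set[str]
    start_version: int
    valid: bool


status = SchemaValidationStatus(
    current_version=30, schema_errors=set(), start_version=30, valid=False
)
# status.current_version = 31  # would raise dataclasses.FrozenInstanceError
status = dataclass_replace(status, current_version=31)
assert status.current_version == 31
assert status.start_version == 30  # the version the migration started from is preserved

This is also why _apply_update now receives start_version rather than current_version: the status's current_version is bumped to end_version up front, so only start_version still identifies where the upgrade began.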

View File

@@ -2,6 +2,7 @@
 
 from collections.abc import AsyncGenerator, Generator
 from dataclasses import dataclass
+from functools import partial
 import threading
 from unittest.mock import Mock, patch
@@ -69,15 +70,16 @@
 ) -> AsyncGenerator[InstrumentedMigration]:
     """Instrument recorder migration."""
 
-    real_migrate_schema = recorder.migration.migrate_schema
+    real_migrate_schema_live = recorder.migration.migrate_schema_live
+    real_migrate_schema_non_live = recorder.migration.migrate_schema_non_live
     real_apply_update = recorder.migration._apply_update
 
-    def _instrument_migrate_schema(*args):
+    def _instrument_migrate_schema(real_func, *args):
         """Control migration progress and check results."""
         instrumented_migration.migration_started.set()
 
         try:
-            real_migrate_schema(*args)
+            migration_result = real_func(*args)
         except Exception:
             instrumented_migration.migration_done.set()
             raise
@@ -92,6 +94,7 @@
         )
         instrumented_migration.migration_version = res.schema_version
         instrumented_migration.migration_done.set()
+        return migration_result
 
     def _instrument_apply_update(*args):
         """Control migration progress."""
@@ -100,8 +103,12 @@
     with (
         patch(
-            "homeassistant.components.recorder.migration.migrate_schema",
-            wraps=_instrument_migrate_schema,
+            "homeassistant.components.recorder.migration.migrate_schema_live",
+            wraps=partial(_instrument_migrate_schema, real_migrate_schema_live),
+        ),
+        patch(
+            "homeassistant.components.recorder.migration.migrate_schema_non_live",
+            wraps=partial(_instrument_migrate_schema, real_migrate_schema_non_live),
         ),
         patch(
             "homeassistant.components.recorder.migration._apply_update",

View File

@@ -2569,7 +2569,13 @@ async def test_clean_shutdown_when_recorder_thread_raises_during_validate_db_sch
     assert instance.engine is None
 
 
-async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -> None:
+@pytest.mark.parametrize(
+    ("func_to_patch", "expected_setup_result"),
+    [("migrate_schema_non_live", False), ("migrate_schema_live", False)],
+)
+async def test_clean_shutdown_when_schema_migration_fails(
+    hass: HomeAssistant, func_to_patch: str, expected_setup_result: bool
+) -> None:
     """Test we still shutdown cleanly when schema migration fails."""
     with (
         patch.object(
@@ -2580,13 +2586,13 @@ async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -
         patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True),
         patch.object(
             migration,
-            "migrate_schema",
+            func_to_patch,
             side_effect=Exception,
         ),
     ):
         if recorder.DOMAIN not in hass.data:
             recorder_helper.async_initialize_recorder(hass)
-        assert await async_setup_component(
+        setup_result = await async_setup_component(
             hass,
             recorder.DOMAIN,
             {
@@ -2597,6 +2603,7 @@ async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -
                 }
             },
         )
+        assert setup_result == expected_setup_result
         await hass.async_block_till_done()
 
         instance = recorder.get_instance(hass)

View File

@@ -184,7 +184,7 @@ async def test_database_migration_encounters_corruption(
             side_effect=[False],
         ),
         patch(
-            "homeassistant.components.recorder.migration.migrate_schema",
+            "homeassistant.components.recorder.migration.migrate_schema_non_live",
             side_effect=sqlite3_exception,
         ),
         patch(
@@ -201,13 +201,26 @@
 @pytest.mark.parametrize(
-    ("live_migration", "expected_setup_result"), [(True, True), (False, False)]
+    (
+        "live_migration",
+        "func_to_patch",
+        "expected_setup_result",
+        "expected_pn_create",
+        "expected_pn_dismiss",
+    ),
+    [
+        (True, "migrate_schema_live", True, 2, 1),
+        (False, "migrate_schema_non_live", False, 1, 0),
+    ],
 )
 async def test_database_migration_encounters_corruption_not_sqlite(
     hass: HomeAssistant,
     async_setup_recorder_instance: RecorderInstanceGenerator,
     live_migration: bool,
+    func_to_patch: str,
     expected_setup_result: bool,
+    expected_pn_create: int,
+    expected_pn_dismiss: int,
 ) -> None:
     """Test we fail on database error when we cannot recover."""
     assert recorder.util.async_migration_in_progress(hass) is False
@@ -218,7 +231,7 @@ async def test_database_migration_encounters_corruption_not_sqlite(
             side_effect=[False],
         ),
         patch(
-            "homeassistant.components.recorder.migration.migrate_schema",
+            f"homeassistant.components.recorder.migration.{func_to_patch}",
             side_effect=DatabaseError("statement", {}, []),
         ),
         patch(
@@ -248,8 +261,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(
     assert recorder.util.async_migration_in_progress(hass) is False
     assert not move_away.called
-    assert len(mock_create.mock_calls) == 2
-    assert len(mock_dismiss.mock_calls) == 1
+    assert len(mock_create.mock_calls) == expected_pn_create
+    assert len(mock_dismiss.mock_calls) == expected_pn_dismiss
 
 
 async def test_events_during_migration_are_queued(

View File

@@ -2466,7 +2466,7 @@ async def test_recorder_info_bad_recorder_config(
     client = await hass_ws_client()
 
-    with patch("homeassistant.components.recorder.migration.migrate_schema"):
+    with patch("homeassistant.components.recorder.migration._migrate_schema"):
         recorder_helper.async_initialize_recorder(hass)
         assert not await async_setup_component(
             hass, recorder.DOMAIN, {recorder.DOMAIN: config}