diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index a7e968fe544..f77305277c8 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -716,6 +716,15 @@ class Recorder(threading.Thread): self._event_session_has_pending_writes = True session.add(obj) + def _notify_migration_failed(self) -> None: + """Notify the user schema migration failed.""" + persistent_notification.create( + self.hass, + "The database migration failed, check [the logs](/config/logs).", + "Database Migration Failed", + "recorder_database_migration", + ) + def _run(self) -> None: """Start processing events to save.""" thread_id = threading.get_ident() @@ -741,26 +750,36 @@ class Recorder(threading.Thread): self.migration_is_live = migration.live_migration(schema_status) self.hass.add_job(self.async_connection_success) - database_was_ready = self.migration_is_live or schema_status.valid - - if database_was_ready: - # If the migrate is live or the schema is valid, we need to - # wait for startup to complete. If its not live, we need to continue - # on. - self._activate_and_set_db_ready() - - # We wait to start a live migration until startup has finished - # since it can be cpu intensive and we do not want it to compete - # with startup which is also cpu intensive - if self._wait_startup_or_shutdown() is SHUTDOWN_TASK: - # Shutdown happened before Home Assistant finished starting - self.migration_in_progress = False - # Make sure we cleanly close the run if - # we restart before startup finishes - return + # First do non-live migration steps, if needed if not schema_status.valid: - if self._migrate_schema_and_setup_run(schema_status): + result, schema_status = self._migrate_schema_offline(schema_status) + if not result: + self._notify_migration_failed() + self.migration_in_progress = False + return + self.schema_version = schema_status.current_version + # Non-live migration is now completed, remaining steps are live + self.migration_is_live = True + + # After non-live migration, activate the recorder + self._activate_and_set_db_ready(schema_status) + # We wait to start a live migration until startup has finished + # since it can be cpu intensive and we do not want it to compete + # with startup which is also cpu intensive + if self._wait_startup_or_shutdown() is SHUTDOWN_TASK: + # Shutdown happened before Home Assistant finished starting + self.migration_in_progress = False + # Make sure we cleanly close the run if + # we restart before startup finishes + return + + # Do live migration steps, if needed + if not schema_status.valid: + result, schema_status = self._migrate_schema_live_and_setup_run( + schema_status + ) + if result: self.schema_version = SCHEMA_VERSION if not self._event_listener: # If the schema migration takes so long that the end @@ -768,17 +787,9 @@ class Recorder(threading.Thread): # was True, we need to reinitialize the listener. self.hass.add_job(self.async_initialize) else: - persistent_notification.create( - self.hass, - "The database migration failed, check [the logs](/config/logs).", - "Database Migration Failed", - "recorder_database_migration", - ) + self._notify_migration_failed() return - if not database_was_ready: - self._activate_and_set_db_ready() - # Catch up with missed statistics self._schedule_compile_missing_statistics() _LOGGER.debug("Recorder processing the queue") @@ -786,7 +797,9 @@ class Recorder(threading.Thread): self.hass.add_job(self._async_set_recorder_ready_migration_done) self._run_event_loop() - def _activate_and_set_db_ready(self) -> None: + def _activate_and_set_db_ready( + self, schema_status: migration.SchemaValidationStatus + ) -> None: """Activate the table managers or schedule migrations and mark the db as ready.""" with session_scope(session=self.get_session()) as session: # Prime the statistics meta manager as soon as possible @@ -808,7 +821,7 @@ class Recorder(threading.Thread): EventTypeIDMigration, EntityIDMigration, ): - migrator = migrator_cls(schema_version, migration_changes) + migrator = migrator_cls(schema_status.start_version, migration_changes) migrator.do_migrate(self, session) if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION: @@ -947,9 +960,15 @@ class Recorder(threading.Thread): """Set the migration started event.""" self.async_migration_event.set() - def _migrate_schema_and_setup_run( + def _migrate_schema_offline( self, schema_status: migration.SchemaValidationStatus - ) -> bool: + ) -> tuple[bool, migration.SchemaValidationStatus]: + """Migrate schema to the latest version.""" + return self._migrate_schema(schema_status, False) + + def _migrate_schema_live_and_setup_run( + self, schema_status: migration.SchemaValidationStatus + ) -> tuple[bool, migration.SchemaValidationStatus]: """Migrate schema to the latest version.""" persistent_notification.create( self.hass, @@ -965,26 +984,40 @@ class Recorder(threading.Thread): "recorder_database_migration", ) self.hass.add_job(self._async_migration_started) - try: - assert self.engine is not None - migration.migrate_schema( + migration_result, schema_status = self._migrate_schema(schema_status, True) + if migration_result: + self._setup_run() + return migration_result, schema_status + finally: + self.migration_in_progress = False + persistent_notification.dismiss(self.hass, "recorder_database_migration") + + def _migrate_schema( + self, + schema_status: migration.SchemaValidationStatus, + live: bool, + ) -> tuple[bool, migration.SchemaValidationStatus]: + """Migrate schema to the latest version.""" + assert self.engine is not None + try: + if live: + migrator = migration.migrate_schema_live + else: + migrator = migration.migrate_schema_non_live + new_schema_status = migrator( self, self.hass, self.engine, self.get_session, schema_status ) except exc.DatabaseError as err: if self._handle_database_error(err): - return True + return (True, schema_status) _LOGGER.exception("Database error during schema migration") - return False + return (False, schema_status) except Exception: _LOGGER.exception("Error during schema migration") - return False + return (False, schema_status) else: - self._setup_run() - return True - finally: - self.migration_in_progress = False - persistent_notification.dismiss(self.hass, "recorder_database_migration") + return (True, new_schema_status) def _lock_database(self, task: DatabaseLockTask) -> None: @callback diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index d0beb4f9895..0af0788a42a 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -188,12 +188,13 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int | None: return None -@dataclass +@dataclass(frozen=True) class SchemaValidationStatus: """Store schema validation status.""" current_version: int schema_errors: set[str] + start_version: int valid: bool @@ -224,7 +225,9 @@ def validate_db_schema( valid = is_current and not schema_errors - return SchemaValidationStatus(current_version, schema_errors, valid) + return SchemaValidationStatus( + current_version, schema_errors, current_version, valid + ) def _find_schema_errors( @@ -260,35 +263,30 @@ def pre_migrate_schema(engine: Engine) -> None: ) -def migrate_schema( +def _migrate_schema( instance: Recorder, hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session], schema_status: SchemaValidationStatus, -) -> None: + end_version: int, +) -> SchemaValidationStatus: """Check if the schema needs to be upgraded.""" current_version = schema_status.current_version - if current_version != SCHEMA_VERSION: + start_version = schema_status.start_version + + if current_version < end_version: _LOGGER.warning( "Database is about to upgrade from schema version: %s to: %s", current_version, - SCHEMA_VERSION, + end_version, ) - db_ready = False - for version in range(current_version, SCHEMA_VERSION): - if ( - live_migration(dataclass_replace(schema_status, current_version=version)) - and not db_ready - ): - db_ready = True - instance.migration_is_live = True - hass.add_job(instance.async_set_db_ready) + schema_status = dataclass_replace(schema_status, current_version=end_version) + + for version in range(current_version, end_version): new_version = version + 1 _LOGGER.info("Upgrading recorder db schema to version %s", new_version) - _apply_update( - instance, hass, engine, session_maker, new_version, current_version - ) + _apply_update(instance, hass, engine, session_maker, new_version, start_version) with session_scope(session=session_maker()) as session: session.add(SchemaChanges(schema_version=new_version)) @@ -296,6 +294,37 @@ def migrate_schema( # so its clear that the upgrade is done _LOGGER.warning("Upgrade to version %s done", new_version) + return schema_status + + +def migrate_schema_non_live( + instance: Recorder, + hass: HomeAssistant, + engine: Engine, + session_maker: Callable[[], Session], + schema_status: SchemaValidationStatus, +) -> SchemaValidationStatus: + """Check if the schema needs to be upgraded.""" + end_version = LIVE_MIGRATION_MIN_SCHEMA_VERSION - 1 + return _migrate_schema( + instance, hass, engine, session_maker, schema_status, end_version + ) + + +def migrate_schema_live( + instance: Recorder, + hass: HomeAssistant, + engine: Engine, + session_maker: Callable[[], Session], + schema_status: SchemaValidationStatus, +) -> SchemaValidationStatus: + """Check if the schema needs to be upgraded.""" + end_version = SCHEMA_VERSION + schema_status = _migrate_schema( + instance, hass, engine, session_maker, schema_status, end_version + ) + + # Repairs are currently done during the live migration if schema_errors := schema_status.schema_errors: _LOGGER.warning( "Database is about to correct DB schema errors: %s", @@ -305,12 +334,15 @@ def migrate_schema( states_correct_db_schema(instance, schema_errors) events_correct_db_schema(instance, schema_errors) - if current_version != SCHEMA_VERSION: - instance.queue_task(PostSchemaMigrationTask(current_version, SCHEMA_VERSION)) + start_version = schema_status.start_version + if start_version != SCHEMA_VERSION: + instance.queue_task(PostSchemaMigrationTask(start_version, SCHEMA_VERSION)) # Make sure the post schema migration task is committed in case # the next task does not have commit_before = True instance.queue_task(CommitTask()) + return schema_status + def _create_index( session_maker: Callable[[], Session], table_name: str, index_name: str diff --git a/tests/components/recorder/conftest.py b/tests/components/recorder/conftest.py index fb58ad581d3..f562ba163ba 100644 --- a/tests/components/recorder/conftest.py +++ b/tests/components/recorder/conftest.py @@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator, Generator from dataclasses import dataclass +from functools import partial import threading from unittest.mock import Mock, patch @@ -69,15 +70,16 @@ async def instrument_migration( ) -> AsyncGenerator[InstrumentedMigration]: """Instrument recorder migration.""" - real_migrate_schema = recorder.migration.migrate_schema + real_migrate_schema_live = recorder.migration.migrate_schema_live + real_migrate_schema_non_live = recorder.migration.migrate_schema_non_live real_apply_update = recorder.migration._apply_update - def _instrument_migrate_schema(*args): + def _instrument_migrate_schema(real_func, *args): """Control migration progress and check results.""" instrumented_migration.migration_started.set() try: - real_migrate_schema(*args) + migration_result = real_func(*args) except Exception: instrumented_migration.migration_done.set() raise @@ -92,6 +94,7 @@ async def instrument_migration( ) instrumented_migration.migration_version = res.schema_version instrumented_migration.migration_done.set() + return migration_result def _instrument_apply_update(*args): """Control migration progress.""" @@ -100,8 +103,12 @@ async def instrument_migration( with ( patch( - "homeassistant.components.recorder.migration.migrate_schema", - wraps=_instrument_migrate_schema, + "homeassistant.components.recorder.migration.migrate_schema_live", + wraps=partial(_instrument_migrate_schema, real_migrate_schema_live), + ), + patch( + "homeassistant.components.recorder.migration.migrate_schema_non_live", + wraps=partial(_instrument_migrate_schema, real_migrate_schema_non_live), ), patch( "homeassistant.components.recorder.migration._apply_update", diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 5715e994d2e..3cd4c3ab4b6 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -2569,7 +2569,13 @@ async def test_clean_shutdown_when_recorder_thread_raises_during_validate_db_sch assert instance.engine is None -async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -> None: +@pytest.mark.parametrize( + ("func_to_patch", "expected_setup_result"), + [("migrate_schema_non_live", False), ("migrate_schema_live", False)], +) +async def test_clean_shutdown_when_schema_migration_fails( + hass: HomeAssistant, func_to_patch: str, expected_setup_result: bool +) -> None: """Test we still shutdown cleanly when schema migration fails.""" with ( patch.object( @@ -2580,13 +2586,13 @@ async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) - patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch.object( migration, - "migrate_schema", + func_to_patch, side_effect=Exception, ), ): if recorder.DOMAIN not in hass.data: recorder_helper.async_initialize_recorder(hass) - assert await async_setup_component( + setup_result = await async_setup_component( hass, recorder.DOMAIN, { @@ -2597,6 +2603,7 @@ async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) - } }, ) + assert setup_result == expected_setup_result await hass.async_block_till_done() instance = recorder.get_instance(hass) diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index f32f5c4aaaf..3bfbcad35fc 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -184,7 +184,7 @@ async def test_database_migration_encounters_corruption( side_effect=[False], ), patch( - "homeassistant.components.recorder.migration.migrate_schema", + "homeassistant.components.recorder.migration.migrate_schema_non_live", side_effect=sqlite3_exception, ), patch( @@ -201,13 +201,26 @@ async def test_database_migration_encounters_corruption( @pytest.mark.parametrize( - ("live_migration", "expected_setup_result"), [(True, True), (False, False)] + ( + "live_migration", + "func_to_patch", + "expected_setup_result", + "expected_pn_create", + "expected_pn_dismiss", + ), + [ + (True, "migrate_schema_live", True, 2, 1), + (False, "migrate_schema_non_live", False, 1, 0), + ], ) async def test_database_migration_encounters_corruption_not_sqlite( hass: HomeAssistant, async_setup_recorder_instance: RecorderInstanceGenerator, live_migration: bool, + func_to_patch: str, expected_setup_result: bool, + expected_pn_create: int, + expected_pn_dismiss: int, ) -> None: """Test we fail on database error when we cannot recover.""" assert recorder.util.async_migration_in_progress(hass) is False @@ -218,7 +231,7 @@ async def test_database_migration_encounters_corruption_not_sqlite( side_effect=[False], ), patch( - "homeassistant.components.recorder.migration.migrate_schema", + f"homeassistant.components.recorder.migration.{func_to_patch}", side_effect=DatabaseError("statement", {}, []), ), patch( @@ -248,8 +261,8 @@ async def test_database_migration_encounters_corruption_not_sqlite( assert recorder.util.async_migration_in_progress(hass) is False assert not move_away.called - assert len(mock_create.mock_calls) == 2 - assert len(mock_dismiss.mock_calls) == 1 + assert len(mock_create.mock_calls) == expected_pn_create + assert len(mock_dismiss.mock_calls) == expected_pn_dismiss async def test_events_during_migration_are_queued( diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index 5f3b1b35c78..1bf56372620 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -2466,7 +2466,7 @@ async def test_recorder_info_bad_recorder_config( client = await hass_ws_client() - with patch("homeassistant.components.recorder.migration.migrate_schema"): + with patch("homeassistant.components.recorder.migration._migrate_schema"): recorder_helper.async_initialize_recorder(hass) assert not await async_setup_component( hass, recorder.DOMAIN, {recorder.DOMAIN: config}