Refactor recorder schema migration (#122372)

* Refactor recorder schema migration

* Simplify

* Remove unused imports

* Refactor _migrate_schema according to review comments

* Add comment
This commit is contained in:
Erik Montnemery 2024-07-22 16:53:54 +02:00 committed by GitHub
parent c73e7ae178
commit e8b88557ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 168 additions and 76 deletions

View File

@ -716,6 +716,15 @@ class Recorder(threading.Thread):
self._event_session_has_pending_writes = True
session.add(obj)
def _notify_migration_failed(self) -> None:
    """Create a persistent notification telling the user the migration failed.

    The notification points the user to the logs for details.
    """
    message = "The database migration failed, check [the logs](/config/logs)."
    title = "Database Migration Failed"
    notification_id = "recorder_database_migration"
    persistent_notification.create(self.hass, message, title, notification_id)
def _run(self) -> None:
"""Start processing events to save."""
thread_id = threading.get_ident()
@ -741,14 +750,20 @@ class Recorder(threading.Thread):
self.migration_is_live = migration.live_migration(schema_status)
self.hass.add_job(self.async_connection_success)
database_was_ready = self.migration_is_live or schema_status.valid
if database_was_ready:
# If the migrate is live or the schema is valid, we need to
# wait for startup to complete. If its not live, we need to continue
# on.
self._activate_and_set_db_ready()
# First do non-live migration steps, if needed
if not schema_status.valid:
result, schema_status = self._migrate_schema_offline(schema_status)
if not result:
self._notify_migration_failed()
self.migration_in_progress = False
return
self.schema_version = schema_status.current_version
# Non-live migration is now completed, remaining steps are live
self.migration_is_live = True
# After non-live migration, activate the recorder
self._activate_and_set_db_ready(schema_status)
# We wait to start a live migration until startup has finished
# since it can be cpu intensive and we do not want it to compete
# with startup which is also cpu intensive
@ -759,8 +774,12 @@ class Recorder(threading.Thread):
# we restart before startup finishes
return
# Do live migration steps, if needed
if not schema_status.valid:
if self._migrate_schema_and_setup_run(schema_status):
result, schema_status = self._migrate_schema_live_and_setup_run(
schema_status
)
if result:
self.schema_version = SCHEMA_VERSION
if not self._event_listener:
# If the schema migration takes so long that the end
@ -768,17 +787,9 @@ class Recorder(threading.Thread):
# was True, we need to reinitialize the listener.
self.hass.add_job(self.async_initialize)
else:
persistent_notification.create(
self.hass,
"The database migration failed, check [the logs](/config/logs).",
"Database Migration Failed",
"recorder_database_migration",
)
self._notify_migration_failed()
return
if not database_was_ready:
self._activate_and_set_db_ready()
# Catch up with missed statistics
self._schedule_compile_missing_statistics()
_LOGGER.debug("Recorder processing the queue")
@ -786,7 +797,9 @@ class Recorder(threading.Thread):
self.hass.add_job(self._async_set_recorder_ready_migration_done)
self._run_event_loop()
def _activate_and_set_db_ready(self) -> None:
def _activate_and_set_db_ready(
self, schema_status: migration.SchemaValidationStatus
) -> None:
"""Activate the table managers or schedule migrations and mark the db as ready."""
with session_scope(session=self.get_session()) as session:
# Prime the statistics meta manager as soon as possible
@ -808,7 +821,7 @@ class Recorder(threading.Thread):
EventTypeIDMigration,
EntityIDMigration,
):
migrator = migrator_cls(schema_version, migration_changes)
migrator = migrator_cls(schema_status.start_version, migration_changes)
migrator.do_migrate(self, session)
if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
@ -947,9 +960,15 @@ class Recorder(threading.Thread):
"""Set the migration started event."""
self.async_migration_event.set()
def _migrate_schema_and_setup_run(
def _migrate_schema_offline(
self, schema_status: migration.SchemaValidationStatus
) -> bool:
) -> tuple[bool, migration.SchemaValidationStatus]:
"""Migrate schema to the latest version."""
return self._migrate_schema(schema_status, False)
def _migrate_schema_live_and_setup_run(
self, schema_status: migration.SchemaValidationStatus
) -> tuple[bool, migration.SchemaValidationStatus]:
"""Migrate schema to the latest version."""
persistent_notification.create(
self.hass,
@ -965,26 +984,40 @@ class Recorder(threading.Thread):
"recorder_database_migration",
)
self.hass.add_job(self._async_migration_started)
try:
migration_result, schema_status = self._migrate_schema(schema_status, True)
if migration_result:
self._setup_run()
return migration_result, schema_status
finally:
self.migration_in_progress = False
persistent_notification.dismiss(self.hass, "recorder_database_migration")
def _migrate_schema(
self,
schema_status: migration.SchemaValidationStatus,
live: bool,
) -> tuple[bool, migration.SchemaValidationStatus]:
"""Migrate schema to the latest version."""
assert self.engine is not None
migration.migrate_schema(
try:
if live:
migrator = migration.migrate_schema_live
else:
migrator = migration.migrate_schema_non_live
new_schema_status = migrator(
self, self.hass, self.engine, self.get_session, schema_status
)
except exc.DatabaseError as err:
if self._handle_database_error(err):
return True
return (True, schema_status)
_LOGGER.exception("Database error during schema migration")
return False
return (False, schema_status)
except Exception:
_LOGGER.exception("Error during schema migration")
return False
return (False, schema_status)
else:
self._setup_run()
return True
finally:
self.migration_in_progress = False
persistent_notification.dismiss(self.hass, "recorder_database_migration")
return (True, new_schema_status)
def _lock_database(self, task: DatabaseLockTask) -> None:
@callback

View File

@ -188,12 +188,13 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
return None
@dataclass
@dataclass(frozen=True)
class SchemaValidationStatus:
"""Store schema validation status."""
current_version: int
schema_errors: set[str]
start_version: int
valid: bool
@ -224,7 +225,9 @@ def validate_db_schema(
valid = is_current and not schema_errors
return SchemaValidationStatus(current_version, schema_errors, valid)
return SchemaValidationStatus(
current_version, schema_errors, current_version, valid
)
def _find_schema_errors(
@ -260,35 +263,30 @@ def pre_migrate_schema(engine: Engine) -> None:
)
def migrate_schema(
def _migrate_schema(
instance: Recorder,
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
schema_status: SchemaValidationStatus,
) -> None:
end_version: int,
) -> SchemaValidationStatus:
"""Check if the schema needs to be upgraded."""
current_version = schema_status.current_version
if current_version != SCHEMA_VERSION:
start_version = schema_status.start_version
if current_version < end_version:
_LOGGER.warning(
"Database is about to upgrade from schema version: %s to: %s",
current_version,
SCHEMA_VERSION,
end_version,
)
db_ready = False
for version in range(current_version, SCHEMA_VERSION):
if (
live_migration(dataclass_replace(schema_status, current_version=version))
and not db_ready
):
db_ready = True
instance.migration_is_live = True
hass.add_job(instance.async_set_db_ready)
schema_status = dataclass_replace(schema_status, current_version=end_version)
for version in range(current_version, end_version):
new_version = version + 1
_LOGGER.info("Upgrading recorder db schema to version %s", new_version)
_apply_update(
instance, hass, engine, session_maker, new_version, current_version
)
_apply_update(instance, hass, engine, session_maker, new_version, start_version)
with session_scope(session=session_maker()) as session:
session.add(SchemaChanges(schema_version=new_version))
@ -296,6 +294,37 @@ def migrate_schema(
# so it's clear that the upgrade is done
_LOGGER.warning("Upgrade to version %s done", new_version)
return schema_status
def migrate_schema_non_live(
    instance: Recorder,
    hass: HomeAssistant,
    engine: Engine,
    session_maker: Callable[[], Session],
    schema_status: SchemaValidationStatus,
) -> SchemaValidationStatus:
    """Run the non-live part of the schema migration.

    Delegates to _migrate_schema with the target version set to the last
    version below LIVE_MIGRATION_MIN_SCHEMA_VERSION and returns the schema
    status produced by that call.
    """
    # Stop right before the first schema version handled by the live migration
    last_non_live_version = LIVE_MIGRATION_MIN_SCHEMA_VERSION - 1
    return _migrate_schema(
        instance,
        hass,
        engine,
        session_maker,
        schema_status,
        last_non_live_version,
    )
def migrate_schema_live(
instance: Recorder,
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
schema_status: SchemaValidationStatus,
) -> SchemaValidationStatus:
"""Check if the schema needs to be upgraded."""
end_version = SCHEMA_VERSION
schema_status = _migrate_schema(
instance, hass, engine, session_maker, schema_status, end_version
)
# Repairs are currently done during the live migration
if schema_errors := schema_status.schema_errors:
_LOGGER.warning(
"Database is about to correct DB schema errors: %s",
@ -305,12 +334,15 @@ def migrate_schema(
states_correct_db_schema(instance, schema_errors)
events_correct_db_schema(instance, schema_errors)
if current_version != SCHEMA_VERSION:
instance.queue_task(PostSchemaMigrationTask(current_version, SCHEMA_VERSION))
start_version = schema_status.start_version
if start_version != SCHEMA_VERSION:
instance.queue_task(PostSchemaMigrationTask(start_version, SCHEMA_VERSION))
# Make sure the post schema migration task is committed in case
# the next task does not have commit_before = True
instance.queue_task(CommitTask())
return schema_status
def _create_index(
session_maker: Callable[[], Session], table_name: str, index_name: str

View File

@ -2,6 +2,7 @@
from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass
from functools import partial
import threading
from unittest.mock import Mock, patch
@ -69,15 +70,16 @@ async def instrument_migration(
) -> AsyncGenerator[InstrumentedMigration]:
"""Instrument recorder migration."""
real_migrate_schema = recorder.migration.migrate_schema
real_migrate_schema_live = recorder.migration.migrate_schema_live
real_migrate_schema_non_live = recorder.migration.migrate_schema_non_live
real_apply_update = recorder.migration._apply_update
def _instrument_migrate_schema(*args):
def _instrument_migrate_schema(real_func, *args):
"""Control migration progress and check results."""
instrumented_migration.migration_started.set()
try:
real_migrate_schema(*args)
migration_result = real_func(*args)
except Exception:
instrumented_migration.migration_done.set()
raise
@ -92,6 +94,7 @@ async def instrument_migration(
)
instrumented_migration.migration_version = res.schema_version
instrumented_migration.migration_done.set()
return migration_result
def _instrument_apply_update(*args):
"""Control migration progress."""
@ -100,8 +103,12 @@ async def instrument_migration(
with (
patch(
"homeassistant.components.recorder.migration.migrate_schema",
wraps=_instrument_migrate_schema,
"homeassistant.components.recorder.migration.migrate_schema_live",
wraps=partial(_instrument_migrate_schema, real_migrate_schema_live),
),
patch(
"homeassistant.components.recorder.migration.migrate_schema_non_live",
wraps=partial(_instrument_migrate_schema, real_migrate_schema_non_live),
),
patch(
"homeassistant.components.recorder.migration._apply_update",

View File

@ -2569,7 +2569,13 @@ async def test_clean_shutdown_when_recorder_thread_raises_during_validate_db_sch
assert instance.engine is None
async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -> None:
@pytest.mark.parametrize(
("func_to_patch", "expected_setup_result"),
[("migrate_schema_non_live", False), ("migrate_schema_live", False)],
)
async def test_clean_shutdown_when_schema_migration_fails(
hass: HomeAssistant, func_to_patch: str, expected_setup_result: bool
) -> None:
"""Test we still shutdown cleanly when schema migration fails."""
with (
patch.object(
@ -2580,13 +2586,13 @@ async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -
patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True),
patch.object(
migration,
"migrate_schema",
func_to_patch,
side_effect=Exception,
),
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert await async_setup_component(
setup_result = await async_setup_component(
hass,
recorder.DOMAIN,
{
@ -2597,6 +2603,7 @@ async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -
}
},
)
assert setup_result == expected_setup_result
await hass.async_block_till_done()
instance = recorder.get_instance(hass)

View File

@ -184,7 +184,7 @@ async def test_database_migration_encounters_corruption(
side_effect=[False],
),
patch(
"homeassistant.components.recorder.migration.migrate_schema",
"homeassistant.components.recorder.migration.migrate_schema_non_live",
side_effect=sqlite3_exception,
),
patch(
@ -201,13 +201,26 @@ async def test_database_migration_encounters_corruption(
@pytest.mark.parametrize(
("live_migration", "expected_setup_result"), [(True, True), (False, False)]
(
"live_migration",
"func_to_patch",
"expected_setup_result",
"expected_pn_create",
"expected_pn_dismiss",
),
[
(True, "migrate_schema_live", True, 2, 1),
(False, "migrate_schema_non_live", False, 1, 0),
],
)
async def test_database_migration_encounters_corruption_not_sqlite(
hass: HomeAssistant,
async_setup_recorder_instance: RecorderInstanceGenerator,
live_migration: bool,
func_to_patch: str,
expected_setup_result: bool,
expected_pn_create: int,
expected_pn_dismiss: int,
) -> None:
"""Test we fail on database error when we cannot recover."""
assert recorder.util.async_migration_in_progress(hass) is False
@ -218,7 +231,7 @@ async def test_database_migration_encounters_corruption_not_sqlite(
side_effect=[False],
),
patch(
"homeassistant.components.recorder.migration.migrate_schema",
f"homeassistant.components.recorder.migration.{func_to_patch}",
side_effect=DatabaseError("statement", {}, []),
),
patch(
@ -248,8 +261,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(
assert recorder.util.async_migration_in_progress(hass) is False
assert not move_away.called
assert len(mock_create.mock_calls) == 2
assert len(mock_dismiss.mock_calls) == 1
assert len(mock_create.mock_calls) == expected_pn_create
assert len(mock_dismiss.mock_calls) == expected_pn_dismiss
async def test_events_during_migration_are_queued(

View File

@ -2466,7 +2466,7 @@ async def test_recorder_info_bad_recorder_config(
client = await hass_ws_client()
with patch("homeassistant.components.recorder.migration.migrate_schema"):
with patch("homeassistant.components.recorder.migration._migrate_schema"):
recorder_helper.async_initialize_recorder(hass)
assert not await async_setup_component(
hass, recorder.DOMAIN, {recorder.DOMAIN: config}