Teach recorder data migrator base class to update MigrationChanges (#125214)

* Teach recorder data migrator base class to update MigrationChanges

* Bump migration version

* Improve test coverage

* Update migration.py

* Revert migrator version bump

* Remove unneeded change
Erik Montnemery 2024-09-05 08:56:18 +02:00 committed by GitHub
parent 4c56cbe8c8
commit a8f2204f4f
2 changed files with 47 additions and 83 deletions


@@ -2201,8 +2201,8 @@ class CommitBeforeMigrationTask(MigrationTask):

 @dataclass(frozen=True, kw_only=True)
-class NeedsMigrateResult:
-    """Container for the return value of BaseRunTimeMigration.needs_migrate_impl."""
+class DataMigrationStatus:
+    """Container for data migrator status."""

     needs_migrate: bool
     migration_done: bool
@@ -2229,36 +2229,30 @@ class BaseRunTimeMigration(ABC):
         else:
             self.migration_done(instance, session)

+    @retryable_database_job("migrate data", method=True)
     def migrate_data(self, instance: Recorder) -> bool:
         """Migrate some data, returns True if migration is completed."""
-        if result := self.migrate_data_impl(instance):
+        status = self.migrate_data_impl(instance)
+        if status.migration_done:
             if self.index_to_drop is not None:
-                self._remove_index(instance, self.index_to_drop)
-            self.migration_done(instance, None)
-        return result
+                table, index = self.index_to_drop
+                _drop_index(instance.get_session, table, index)
+            with session_scope(session=instance.get_session()) as session:
+                self.migration_done(instance, session)
+                _mark_migration_done(session, self.__class__)
+        return not status.needs_migrate

-    @staticmethod
     @abstractmethod
-    def migrate_data_impl(instance: Recorder) -> bool:
-        """Migrate some data, returns True if migration is completed."""
+    def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
+        """Migrate some data, return if the migration needs to run and if it is done."""

-    @staticmethod
-    @database_job_retry_wrapper("remove index")
-    def _remove_index(instance: Recorder, index_to_drop: tuple[str, str]) -> None:
-        """Remove indices.
-
-        Called when migration is completed.
-        """
-        table, index = index_to_drop
-        _drop_index(instance.get_session, table, index)
-
-    def migration_done(self, instance: Recorder, session: Session | None) -> None:
+    def migration_done(self, instance: Recorder, session: Session) -> None:
         """Will be called after migrate returns True or if migration is not needed."""

     @abstractmethod
     def needs_migrate_impl(
         self, instance: Recorder, session: Session
-    ) -> NeedsMigrateResult:
+    ) -> DataMigrationStatus:
         """Return if the migration needs to run and if it is done."""

     def needs_migrate(self, instance: Recorder, session: Session) -> bool:
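
For readers skimming the diff, the control flow the base class now owns can be summarized with a small, self-contained sketch. This is not recorder code: DataMigrationStatus mirrors the dataclass above, while ToyMigration, RenameKeysMigration, and the dict-based store are hypothetical stand-ins for BaseRunTimeMigration, a concrete migrator, and the database. The point is the contract: migrate_data_impl only reports progress, and the base class alone decides when to drop the index, call migration_done(), and record the migration (in the real code via _mark_migration_done, which updates MigrationChanges).

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True, kw_only=True)
class DataMigrationStatus:
    """Mirror of the container introduced above."""

    needs_migrate: bool
    migration_done: bool


class ToyMigration(ABC):
    """Hypothetical stand-in for BaseRunTimeMigration (no Recorder, no database)."""

    migration_id = "toy_migration"

    def migrate_data(self, store: dict) -> bool:
        """Run one batch; return True when no further rescheduling is needed."""
        status = self.migrate_data_impl(store)
        if status.migration_done:
            # The real base class drops the migrator's index here, calls
            # migration_done(), and records the migration via
            # _mark_migration_done(); the toy just stores a marker.
            store.setdefault("migration_changes", {})[self.migration_id] = "done"
        return not status.needs_migrate

    @abstractmethod
    def migrate_data_impl(self, store: dict) -> DataMigrationStatus:
        """Migrate one batch and report whether more work remains."""


class RenameKeysMigration(ToyMigration):
    """Hypothetical migrator that renames keys two at a time."""

    migration_id = "rename_keys"

    def migrate_data_impl(self, store: dict) -> DataMigrationStatus:
        pending = [k for k in store["rows"] if k.startswith("old_")]
        for key in pending[:2]:  # bounded batch, like max_bind_vars in the real code
            store["rows"][key.replace("old_", "new_")] = store["rows"].pop(key)
        is_done = not pending  # same convention as `is_done = not states`
        return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)


if __name__ == "__main__":
    data = {"rows": {f"old_{i}": i for i in range(5)}}
    migrator = RenameKeysMigration()
    while not migrator.migrate_data(data):
        pass
    print(data["migration_changes"])  # {'rename_keys': 'done'}
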
@@ -2300,10 +2294,10 @@ class BaseRunTimeMigrationWithQuery(BaseRunTimeMigration):

     def needs_migrate_impl(
         self, instance: Recorder, session: Session
-    ) -> NeedsMigrateResult:
+    ) -> DataMigrationStatus:
         """Return if the migration needs to run."""
         needs_migrate = execute_stmt_lambda_element(session, self.needs_migrate_query())
-        return NeedsMigrateResult(
+        return DataMigrationStatus(
             needs_migrate=bool(needs_migrate), migration_done=not needs_migrate
         )
@@ -2315,9 +2309,7 @@ class StatesContextIDMigration(BaseRunTimeMigrationWithQuery):
     migration_id = "state_context_id_as_binary"
     index_to_drop = ("states", "ix_states_context_id")

-    @staticmethod
-    @retryable_database_job("migrate states context_ids to binary format")
-    def migrate_data_impl(instance: Recorder) -> bool:
+    def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
         """Migrate states context_ids to use binary format, return True if completed."""
         _to_bytes = _context_id_to_bytes
         session_maker = instance.get_session
@@ -2342,13 +2334,10 @@ class StatesContextIDMigration(BaseRunTimeMigrationWithQuery):
                         for state_id, last_updated_ts, context_id, context_user_id, context_parent_id in states
                     ],
                 )
-            # If there is more work to do return False
-            # so that we can be called again
-            if is_done := not states:
-                _mark_migration_done(session, StatesContextIDMigration)
+            is_done = not states

         _LOGGER.debug("Migrating states context_ids to binary format: done=%s", is_done)
-        return is_done
+        return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)

     def needs_migrate_query(self) -> StatementLambdaElement:
         """Return the query to check if the migration needs to run."""
@@ -2362,9 +2351,7 @@ class EventsContextIDMigration(BaseRunTimeMigrationWithQuery):
     migration_id = "event_context_id_as_binary"
     index_to_drop = ("events", "ix_events_context_id")

-    @staticmethod
-    @retryable_database_job("migrate events context_ids to binary format")
-    def migrate_data_impl(instance: Recorder) -> bool:
+    def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
         """Migrate events context_ids to use binary format, return True if completed."""
         _to_bytes = _context_id_to_bytes
         session_maker = instance.get_session
@@ -2389,13 +2376,10 @@ class EventsContextIDMigration(BaseRunTimeMigrationWithQuery):
                         for event_id, time_fired_ts, context_id, context_user_id, context_parent_id in events
                     ],
                 )
-            # If there is more work to do return False
-            # so that we can be called again
-            if is_done := not events:
-                _mark_migration_done(session, EventsContextIDMigration)
+            is_done = not events

         _LOGGER.debug("Migrating events context_ids to binary format: done=%s", is_done)
-        return is_done
+        return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)

     def needs_migrate_query(self) -> StatementLambdaElement:
         """Return the query to check if the migration needs to run."""
@@ -2412,9 +2396,7 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery):
     # no new pending event_types about to be added to
     # the db since this happens live

-    @staticmethod
-    @retryable_database_job("migrate events event_types to event_type_ids")
-    def migrate_data_impl(instance: Recorder) -> bool:
+    def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
         """Migrate event_type to event_type_ids, return True if completed."""
         session_maker = instance.get_session
         _LOGGER.debug("Migrating event_types")
@@ -2467,15 +2449,12 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery):
                     ],
                 )

-            # If there is more work to do return False
-            # so that we can be called again
-            if is_done := not events:
-                _mark_migration_done(session, EventTypeIDMigration)
+            is_done = not events

         _LOGGER.debug("Migrating event_types done=%s", is_done)
-        return is_done
+        return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)

-    def migration_done(self, instance: Recorder, session: Session | None) -> None:
+    def migration_done(self, instance: Recorder, session: Session) -> None:
         """Will be called after migrate returns True."""
         _LOGGER.debug("Activating event_types manager as all data is migrated")
         instance.event_type_manager.active = True
@@ -2495,9 +2474,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
     # no new pending states_meta about to be added to
    # the db since this happens live

-    @staticmethod
-    @retryable_database_job("migrate states entity_ids to states_meta")
-    def migrate_data_impl(instance: Recorder) -> bool:
+    def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
         """Migrate entity_ids to states_meta, return True if completed.

         We do this in two steps because we need the history queries to work
@@ -2560,15 +2537,12 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
                     ],
                 )

-            # If there is more work to do return False
-            # so that we can be called again
-            if is_done := not states:
-                _mark_migration_done(session, EntityIDMigration)
+            is_done = not states

         _LOGGER.debug("Migrating entity_ids done=%s", is_done)
-        return is_done
+        return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)

-    def migration_done(self, instance: Recorder, _session: Session | None) -> None:
+    def migration_done(self, instance: Recorder, session: Session) -> None:
         """Will be called after migrate returns True."""
         # The migration has finished, now we start the post migration
         # to remove the old entity_id data from the states table
@@ -2576,15 +2550,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
         # so we set active to True
         _LOGGER.debug("Activating states_meta manager as all data is migrated")
         instance.states_meta_manager.active = True
-        session_generator = (
-            contextlib.nullcontext(_session)
-            if _session
-            else session_scope(session=instance.get_session())
-        )
-        with (
-            contextlib.suppress(SQLAlchemyError),
-            session_generator as session,
-        ):
+        with contextlib.suppress(SQLAlchemyError):
             # If ix_states_entity_id_last_updated_ts still exists
             # on the states table it means the entity id migration
             # finished by the EntityIDPostMigrationTask did not
@@ -2609,9 +2575,7 @@ class EventIDPostMigration(BaseRunTimeMigration):
     task = MigrationTask
     migration_version = 2

-    @staticmethod
-    @retryable_database_job("cleanup_legacy_event_ids")
-    def migrate_data_impl(instance: Recorder) -> bool:
+    def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
         """Remove old event_id index from states, returns True if completed.

         We used to link states to events using the event_id column but we no
@@ -2651,9 +2615,8 @@ class EventIDPostMigration(BaseRunTimeMigration):
             if fk_remove_ok:
                 _drop_index(session_maker, "states", LEGACY_STATES_EVENT_ID_INDEX)
                 instance.use_legacy_events_index = False
-                _mark_migration_done(session, EventIDPostMigration)

-        return True
+        return DataMigrationStatus(needs_migrate=False, migration_done=fk_remove_ok)

     @staticmethod
     def _legacy_event_id_foreign_key_exists(instance: Recorder) -> bool:
@@ -2674,16 +2637,16 @@ class EventIDPostMigration(BaseRunTimeMigration):
     def needs_migrate_impl(
         self, instance: Recorder, session: Session
-    ) -> NeedsMigrateResult:
+    ) -> DataMigrationStatus:
         """Return if the migration needs to run."""
         if self.schema_version <= LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
-            return NeedsMigrateResult(needs_migrate=False, migration_done=False)
+            return DataMigrationStatus(needs_migrate=False, migration_done=False)
         if get_index_by_name(
             session, TABLE_STATES, LEGACY_STATES_EVENT_ID_INDEX
         ) is not None or self._legacy_event_id_foreign_key_exists(instance):
             instance.use_legacy_events_index = True
-            return NeedsMigrateResult(needs_migrate=True, migration_done=False)
-        return NeedsMigrateResult(needs_migrate=False, migration_done=True)
+            return DataMigrationStatus(needs_migrate=True, migration_done=False)
+        return DataMigrationStatus(needs_migrate=False, migration_done=True)


 @dataclass(slots=True)
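
For reference, the hunk above exercises all three status combinations a migrator can report from needs_migrate_impl. The constants below are an illustrative restatement only (they do not exist in the repository); the dataclass simply mirrors the one defined at the top of the diff.

from dataclasses import dataclass


@dataclass(frozen=True, kw_only=True)
class DataMigrationStatus:
    needs_migrate: bool
    migration_done: bool


# Schema still too old to judge: do nothing now and check again later.
CHECK_AGAIN_LATER = DataMigrationStatus(needs_migrate=False, migration_done=False)

# Legacy index or foreign key still present: queue the data migration.
RUN_MIGRATION = DataMigrationStatus(needs_migrate=True, migration_done=False)

# Nothing legacy left: record the migration as done without queueing work.
ALREADY_DONE = DataMigrationStatus(needs_migrate=False, migration_done=True)
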


@@ -645,23 +645,24 @@ def _is_retryable_error(instance: Recorder, err: OperationalError) -> bool:

 type _FuncType[_T, **_P, _R] = Callable[Concatenate[_T, _P], _R]
+type _FuncOrMethType[**_P, _R] = Callable[_P, _R]


-def retryable_database_job[_RecorderT: Recorder, **_P](
-    description: str,
-) -> Callable[[_FuncType[_RecorderT, _P, bool]], _FuncType[_RecorderT, _P, bool]]:
+def retryable_database_job[**_P](
+    description: str, method: bool = False
+) -> Callable[[_FuncOrMethType[_P, bool]], _FuncOrMethType[_P, bool]]:
     """Try to execute a database job.

     The job should return True if it finished, and False if it needs to be rescheduled.
     """
+    recorder_pos = 1 if method else 0

-    def decorator(
-        job: _FuncType[_RecorderT, _P, bool],
-    ) -> _FuncType[_RecorderT, _P, bool]:
+    def decorator(job: _FuncOrMethType[_P, bool]) -> _FuncOrMethType[_P, bool]:
         @functools.wraps(job)
-        def wrapper(instance: _RecorderT, *args: _P.args, **kwargs: _P.kwargs) -> bool:
+        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool:
+            instance: Recorder = args[recorder_pos]  # type: ignore[assignment]
             try:
-                return job(instance, *args, **kwargs)
+                return job(*args, **kwargs)
             except OperationalError as err:
                 if _is_retryable_error(instance, err):
                     assert isinstance(err.orig, BaseException)  # noqa: PT017
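
The method=True flag exists so the decorator can now sit on the base class's migrate_data method (first file) as well as on plain functions that take the recorder as their first positional argument. The following is a minimal, self-contained sketch of that dispatch, under stated assumptions: FakeInstance, retryable_job, and RuntimeError stand in for Recorder, retryable_database_job, and OperationalError, and the real wrapper's retry handling around _is_retryable_error is omitted.

import functools
from collections.abc import Callable


class FakeInstance:
    """Hypothetical stand-in for Recorder in this sketch."""

    name = "recorder"


def retryable_job[**P](
    description: str, method: bool = False
) -> Callable[[Callable[P, bool]], Callable[P, bool]]:
    """Toy counterpart of the decorator change above.

    With method=False the wrapped callable takes the instance as its first
    positional argument; with method=True it is a method, so the instance is
    the second positional argument (after self).
    """
    instance_pos = 1 if method else 0

    def decorator(job: Callable[P, bool]) -> Callable[P, bool]:
        @functools.wraps(job)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> bool:
            instance = args[instance_pos]
            try:
                return job(*args, **kwargs)
            except RuntimeError:  # stand-in for OperationalError
                print(f"{description}: retrying later on {instance.name}")
                return False

        return wrapper

    return decorator


class Migrator:
    """Hypothetical migrator using the method form of the decorator."""

    @retryable_job("migrate data", method=True)
    def migrate_data(self, instance: FakeInstance) -> bool:
        return True


@retryable_job("flaky cleanup")
def flaky_cleanup(instance: FakeInstance) -> bool:
    raise RuntimeError("database is locked")


print(Migrator().migrate_data(FakeInstance()))  # True
print(flaky_cleanup(FakeInstance()))            # prints a retry note, then False
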