Teach recorder data migrator base class to update MigrationChanges (#125214)

* Teach recorder data migrator base class to update MigrationChanges

* Bump migration version

* Improve test coverage

* Update migration.py

* Revert migrator version bump

* Remove unneeded change
Erik Montnemery 2024-09-05 08:56:18 +02:00 committed by GitHub
parent 4c56cbe8c8
commit a8f2204f4f
2 changed files with 47 additions and 83 deletions

homeassistant/components/recorder/migration.py

@@ -2201,8 +2201,8 @@ class CommitBeforeMigrationTask(MigrationTask):
@dataclass(frozen=True, kw_only=True)
class NeedsMigrateResult:
"""Container for the return value of BaseRunTimeMigration.needs_migrate_impl."""
class DataMigrationStatus:
"""Container for data migrator status."""
needs_migrate: bool
migration_done: bool
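
For orientation, a minimal standalone sketch of the renamed container and the three field combinations this diff actually produces; the class and field names come from the hunk above, while the variable names and comments are illustrative only.

    from dataclasses import dataclass


    @dataclass(frozen=True, kw_only=True)
    class DataMigrationStatus:
        """Container for data migrator status."""

        needs_migrate: bool
        migration_done: bool


    # Combinations returned elsewhere in this diff:
    # a batch was processed and rows remain -> keep rescheduling the migrator
    in_progress = DataMigrationStatus(needs_migrate=True, migration_done=False)
    # nothing left to convert -> the base class drops indices and records the migration
    finished = DataMigrationStatus(needs_migrate=False, migration_done=True)
    # the migration cannot run yet (e.g. schema too old) -> neither needed nor done
    not_applicable = DataMigrationStatus(needs_migrate=False, migration_done=False)
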
@@ -2229,36 +2229,30 @@ class BaseRunTimeMigration(ABC):
else:
self.migration_done(instance, session)
@retryable_database_job("migrate data", method=True)
def migrate_data(self, instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
if result := self.migrate_data_impl(instance):
status = self.migrate_data_impl(instance)
if status.migration_done:
if self.index_to_drop is not None:
self._remove_index(instance, self.index_to_drop)
self.migration_done(instance, None)
return result
table, index = self.index_to_drop
_drop_index(instance.get_session, table, index)
with session_scope(session=instance.get_session()) as session:
self.migration_done(instance, session)
_mark_migration_done(session, self.__class__)
return not status.needs_migrate
@staticmethod
@abstractmethod
def migrate_data_impl(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
"""Migrate some data, return if the migration needs to run and if it is done."""
@staticmethod
@database_job_retry_wrapper("remove index")
def _remove_index(instance: Recorder, index_to_drop: tuple[str, str]) -> None:
"""Remove indices.
Called when migration is completed.
"""
table, index = index_to_drop
_drop_index(instance.get_session, table, index)
def migration_done(self, instance: Recorder, session: Session | None) -> None:
def migration_done(self, instance: Recorder, session: Session) -> None:
"""Will be called after migrate returns True or if migration is not needed."""
@abstractmethod
def needs_migrate_impl(
self, instance: Recorder, session: Session
) -> NeedsMigrateResult:
) -> DataMigrationStatus:
"""Return if the migration needs to run and if it is done."""
def needs_migrate(self, instance: Recorder, session: Session) -> bool:
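
To summarize the reworked base class, here is a rough self-contained sketch of the control flow migrate_data now implements once a migrator reports it is done; Recorder, session_scope, _drop_index and _mark_migration_done are stand-ins for the real recorder helpers, not the actual implementations.

    from __future__ import annotations

    from abc import ABC, abstractmethod
    from contextlib import contextmanager
    from dataclasses import dataclass


    @dataclass(frozen=True, kw_only=True)
    class DataMigrationStatus:
        needs_migrate: bool
        migration_done: bool


    class Recorder:  # stand-in for the real Recorder
        def get_session(self):
            return "session"  # stand-in for a SQLAlchemy session


    @contextmanager
    def session_scope(*, session):  # stand-in for the recorder session helper
        yield session


    def _drop_index(session_maker, table, index):  # stand-in
        print(f"dropping index {index} on {table}")


    def _mark_migration_done(session, migrator_cls):  # stand-in: updates MigrationChanges
        print(f"marking {migrator_cls.__name__} as done")


    class BaseRunTimeMigration(ABC):
        index_to_drop: tuple[str, str] | None = None

        def migrate_data(self, instance: Recorder) -> bool:
            """Run one batch; return True when the task no longer needs rescheduling."""
            status = self.migrate_data_impl(instance)
            if status.migration_done:
                if self.index_to_drop is not None:
                    table, index = self.index_to_drop
                    _drop_index(instance.get_session, table, index)
                with session_scope(session=instance.get_session()) as session:
                    self.migration_done(instance, session)
                    _mark_migration_done(session, self.__class__)
            return not status.needs_migrate

        @abstractmethod
        def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
            """Migrate one batch of data and report progress."""

        def migration_done(self, instance: Recorder, session) -> None:
            """Hook for subclasses; now always receives a live session."""

The net effect is that _mark_migration_done moves out of every subclass's batch loop into this single place, so the per-table migrators below only have to report their status.
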
@@ -2300,10 +2294,10 @@ class BaseRunTimeMigrationWithQuery(BaseRunTimeMigration):
def needs_migrate_impl(
self, instance: Recorder, session: Session
) -> NeedsMigrateResult:
) -> DataMigrationStatus:
"""Return if the migration needs to run."""
needs_migrate = execute_stmt_lambda_element(session, self.needs_migrate_query())
return NeedsMigrateResult(
return DataMigrationStatus(
needs_migrate=bool(needs_migrate), migration_done=not needs_migrate
)
@@ -2315,9 +2309,7 @@ class StatesContextIDMigration(BaseRunTimeMigrationWithQuery):
migration_id = "state_context_id_as_binary"
index_to_drop = ("states", "ix_states_context_id")
@staticmethod
@retryable_database_job("migrate states context_ids to binary format")
def migrate_data_impl(instance: Recorder) -> bool:
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
"""Migrate states context_ids to use binary format, return True if completed."""
_to_bytes = _context_id_to_bytes
session_maker = instance.get_session
@@ -2342,13 +2334,10 @@ class StatesContextIDMigration(BaseRunTimeMigrationWithQuery):
for state_id, last_updated_ts, context_id, context_user_id, context_parent_id in states
],
)
# If there is more work to do return False
# so that we can be called again
if is_done := not states:
_mark_migration_done(session, StatesContextIDMigration)
is_done = not states
_LOGGER.debug("Migrating states context_ids to binary format: done=%s", is_done)
return is_done
return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)
def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
@@ -2362,9 +2351,7 @@ class EventsContextIDMigration(BaseRunTimeMigrationWithQuery):
migration_id = "event_context_id_as_binary"
index_to_drop = ("events", "ix_events_context_id")
@staticmethod
@retryable_database_job("migrate events context_ids to binary format")
def migrate_data_impl(instance: Recorder) -> bool:
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
"""Migrate events context_ids to use binary format, return True if completed."""
_to_bytes = _context_id_to_bytes
session_maker = instance.get_session
@@ -2389,13 +2376,10 @@ class EventsContextIDMigration(BaseRunTimeMigrationWithQuery):
for event_id, time_fired_ts, context_id, context_user_id, context_parent_id in events
],
)
# If there is more work to do return False
# so that we can be called again
if is_done := not events:
_mark_migration_done(session, EventsContextIDMigration)
is_done = not events
_LOGGER.debug("Migrating events context_ids to binary format: done=%s", is_done)
return is_done
return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)
def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
@@ -2412,9 +2396,7 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery):
# no new pending event_types about to be added to
# the db since this happens live
@staticmethod
@retryable_database_job("migrate events event_types to event_type_ids")
def migrate_data_impl(instance: Recorder) -> bool:
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
"""Migrate event_type to event_type_ids, return True if completed."""
session_maker = instance.get_session
_LOGGER.debug("Migrating event_types")
@@ -2467,15 +2449,12 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery):
],
)
# If there is more work to do return False
# so that we can be called again
if is_done := not events:
_mark_migration_done(session, EventTypeIDMigration)
is_done = not events
_LOGGER.debug("Migrating event_types done=%s", is_done)
return is_done
return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)
def migration_done(self, instance: Recorder, session: Session | None) -> None:
def migration_done(self, instance: Recorder, session: Session) -> None:
"""Will be called after migrate returns True."""
_LOGGER.debug("Activating event_types manager as all data is migrated")
instance.event_type_manager.active = True
@@ -2495,9 +2474,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
# no new pending states_meta about to be added to
# the db since this happens live
@staticmethod
@retryable_database_job("migrate states entity_ids to states_meta")
def migrate_data_impl(instance: Recorder) -> bool:
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
"""Migrate entity_ids to states_meta, return True if completed.
We do this in two steps because we need the history queries to work
@@ -2560,15 +2537,12 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
],
)
# If there is more work to do return False
# so that we can be called again
if is_done := not states:
_mark_migration_done(session, EntityIDMigration)
is_done = not states
_LOGGER.debug("Migrating entity_ids done=%s", is_done)
return is_done
return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done)
def migration_done(self, instance: Recorder, _session: Session | None) -> None:
def migration_done(self, instance: Recorder, session: Session) -> None:
"""Will be called after migrate returns True."""
# The migration has finished, now we start the post migration
# to remove the old entity_id data from the states table
@@ -2576,15 +2550,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
# so we set active to True
_LOGGER.debug("Activating states_meta manager as all data is migrated")
instance.states_meta_manager.active = True
session_generator = (
contextlib.nullcontext(_session)
if _session
else session_scope(session=instance.get_session())
)
with (
contextlib.suppress(SQLAlchemyError),
session_generator as session,
):
with contextlib.suppress(SQLAlchemyError):
# If ix_states_entity_id_last_updated_ts still exists
# on the states table it means the entity id migration
# finished by the EntityIDPostMigrationTask did not
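
For context on the lines removed in the hunk above: the old code used contextlib.nullcontext to reuse a caller-supplied session and only opened its own session_scope when none was passed; since migration_done is now always invoked with a live session by the base class, that selection is gone. A small generic sketch of the removed pattern, with illustrative names only:

    import contextlib
    from contextlib import contextmanager


    @contextmanager
    def session_scope():
        """Stand-in for opening (and later closing) a brand-new session."""
        yield "new session"


    def do_database_work(existing_session=None):
        # Reuse the caller's session if one was given, otherwise manage our own.
        session_generator = (
            contextlib.nullcontext(existing_session)
            if existing_session
            else session_scope()
        )
        with session_generator as session:
            print("working with", session)


    do_database_work()                  # opens its own session
    do_database_work("caller session")  # reuses the session it was handed
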
@@ -2609,9 +2575,7 @@ class EventIDPostMigration(BaseRunTimeMigration):
task = MigrationTask
migration_version = 2
@staticmethod
@retryable_database_job("cleanup_legacy_event_ids")
def migrate_data_impl(instance: Recorder) -> bool:
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
"""Remove old event_id index from states, returns True if completed.
We used to link states to events using the event_id column but we no
@@ -2651,9 +2615,8 @@ class EventIDPostMigration(BaseRunTimeMigration):
if fk_remove_ok:
_drop_index(session_maker, "states", LEGACY_STATES_EVENT_ID_INDEX)
instance.use_legacy_events_index = False
_mark_migration_done(session, EventIDPostMigration)
return True
return DataMigrationStatus(needs_migrate=False, migration_done=fk_remove_ok)
@staticmethod
def _legacy_event_id_foreign_key_exists(instance: Recorder) -> bool:
@@ -2674,16 +2637,16 @@ class EventIDPostMigration(BaseRunTimeMigration):
def needs_migrate_impl(
self, instance: Recorder, session: Session
) -> NeedsMigrateResult:
) -> DataMigrationStatus:
"""Return if the migration needs to run."""
if self.schema_version <= LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
return NeedsMigrateResult(needs_migrate=False, migration_done=False)
return DataMigrationStatus(needs_migrate=False, migration_done=False)
if get_index_by_name(
session, TABLE_STATES, LEGACY_STATES_EVENT_ID_INDEX
) is not None or self._legacy_event_id_foreign_key_exists(instance):
instance.use_legacy_events_index = True
return NeedsMigrateResult(needs_migrate=True, migration_done=False)
return NeedsMigrateResult(needs_migrate=False, migration_done=True)
return DataMigrationStatus(needs_migrate=True, migration_done=False)
return DataMigrationStatus(needs_migrate=False, migration_done=True)
@dataclass(slots=True)

homeassistant/components/recorder/util.py

@@ -645,23 +645,24 @@ def _is_retryable_error(instance: Recorder, err: OperationalError) -> bool:
type _FuncType[_T, **_P, _R] = Callable[Concatenate[_T, _P], _R]
type _FuncOrMethType[**_P, _R] = Callable[_P, _R]
def retryable_database_job[_RecorderT: Recorder, **_P](
description: str,
) -> Callable[[_FuncType[_RecorderT, _P, bool]], _FuncType[_RecorderT, _P, bool]]:
def retryable_database_job[**_P](
description: str, method: bool = False
) -> Callable[[_FuncOrMethType[_P, bool]], _FuncOrMethType[_P, bool]]:
"""Try to execute a database job.
The job should return True if it finished, and False if it needs to be rescheduled.
"""
recorder_pos = 1 if method else 0
def decorator(
job: _FuncType[_RecorderT, _P, bool],
) -> _FuncType[_RecorderT, _P, bool]:
def decorator(job: _FuncOrMethType[_P, bool]) -> _FuncOrMethType[_P, bool]:
@functools.wraps(job)
def wrapper(instance: _RecorderT, *args: _P.args, **kwargs: _P.kwargs) -> bool:
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool:
instance: Recorder = args[recorder_pos] # type: ignore[assignment]
try:
return job(instance, *args, **kwargs)
return job(*args, **kwargs)
except OperationalError as err:
if _is_retryable_error(instance, err):
assert isinstance(err.orig, BaseException) # noqa: PT017
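
The decorator change above lets the same retry wrapper serve both module-level functions (recorder instance as the first argument) and methods (recorder instance as the second argument, after self). A stripped-down sketch of that positional-argument trick, with the retry and error handling reduced to a print; Worker, do_work and the Recorder stub are invented names for illustration.

    import functools
    from collections.abc import Callable


    class Recorder:  # stand-in for the real Recorder
        name = "recorder"


    def retryable_database_job[**_P](
        description: str, method: bool = False
    ) -> Callable[[Callable[_P, bool]], Callable[_P, bool]]:
        """Simplified sketch: locate the Recorder positionally so retries can consult it."""
        recorder_pos = 1 if method else 0

        def decorator(job: Callable[_P, bool]) -> Callable[_P, bool]:
            @functools.wraps(job)
            def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool:
                instance: Recorder = args[recorder_pos]  # type: ignore[assignment]
                print(f"{description}: running against {instance.name}")
                return job(*args, **kwargs)

            return wrapper

        return decorator


    @retryable_database_job("migrate data as a function")
    def do_work(instance: Recorder) -> bool:
        return True


    class Worker:
        @retryable_database_job("migrate data as a method", method=True)
        def do_work(self, instance: Recorder) -> bool:
            return True


    do_work(Recorder())          # recorder_pos == 0
    Worker().do_work(Recorder()) # recorder_pos == 1, since args[0] is self

Passing method=True is what allows the decorator to sit directly on BaseRunTimeMigration.migrate_data, where the recorder instance is the second positional argument.
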