diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 242e503611c..324bdd5ea13 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -2201,8 +2201,8 @@ class CommitBeforeMigrationTask(MigrationTask): @dataclass(frozen=True, kw_only=True) -class NeedsMigrateResult: - """Container for the return value of BaseRunTimeMigration.needs_migrate_impl.""" +class DataMigrationStatus: + """Container for data migrator status.""" needs_migrate: bool migration_done: bool @@ -2229,36 +2229,30 @@ class BaseRunTimeMigration(ABC): else: self.migration_done(instance, session) + @retryable_database_job("migrate data", method=True) def migrate_data(self, instance: Recorder) -> bool: """Migrate some data, returns True if migration is completed.""" - if result := self.migrate_data_impl(instance): + status = self.migrate_data_impl(instance) + if status.migration_done: if self.index_to_drop is not None: - self._remove_index(instance, self.index_to_drop) - self.migration_done(instance, None) - return result + table, index = self.index_to_drop + _drop_index(instance.get_session, table, index) + with session_scope(session=instance.get_session()) as session: + self.migration_done(instance, session) + _mark_migration_done(session, self.__class__) + return not status.needs_migrate - @staticmethod @abstractmethod - def migrate_data_impl(instance: Recorder) -> bool: - """Migrate some data, returns True if migration is completed.""" + def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: + """Migrate some data, return if the migration needs to run and if it is done.""" - @staticmethod - @database_job_retry_wrapper("remove index") - def _remove_index(instance: Recorder, index_to_drop: tuple[str, str]) -> None: - """Remove indices. - - Called when migration is completed. - """ - table, index = index_to_drop - _drop_index(instance.get_session, table, index) - - def migration_done(self, instance: Recorder, session: Session | None) -> None: + def migration_done(self, instance: Recorder, session: Session) -> None: """Will be called after migrate returns True or if migration is not needed.""" @abstractmethod def needs_migrate_impl( self, instance: Recorder, session: Session - ) -> NeedsMigrateResult: + ) -> DataMigrationStatus: """Return if the migration needs to run and if it is done.""" def needs_migrate(self, instance: Recorder, session: Session) -> bool: @@ -2300,10 +2294,10 @@ class BaseRunTimeMigrationWithQuery(BaseRunTimeMigration): def needs_migrate_impl( self, instance: Recorder, session: Session - ) -> NeedsMigrateResult: + ) -> DataMigrationStatus: """Return if the migration needs to run.""" needs_migrate = execute_stmt_lambda_element(session, self.needs_migrate_query()) - return NeedsMigrateResult( + return DataMigrationStatus( needs_migrate=bool(needs_migrate), migration_done=not needs_migrate ) @@ -2315,9 +2309,7 @@ class StatesContextIDMigration(BaseRunTimeMigrationWithQuery): migration_id = "state_context_id_as_binary" index_to_drop = ("states", "ix_states_context_id") - @staticmethod - @retryable_database_job("migrate states context_ids to binary format") - def migrate_data_impl(instance: Recorder) -> bool: + def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: """Migrate states context_ids to use binary format, return True if completed.""" _to_bytes = _context_id_to_bytes session_maker = instance.get_session @@ -2342,13 +2334,10 @@ class StatesContextIDMigration(BaseRunTimeMigrationWithQuery): for state_id, last_updated_ts, context_id, context_user_id, context_parent_id in states ], ) - # If there is more work to do return False - # so that we can be called again - if is_done := not states: - _mark_migration_done(session, StatesContextIDMigration) + is_done = not states _LOGGER.debug("Migrating states context_ids to binary format: done=%s", is_done) - return is_done + return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done) def needs_migrate_query(self) -> StatementLambdaElement: """Return the query to check if the migration needs to run.""" @@ -2362,9 +2351,7 @@ class EventsContextIDMigration(BaseRunTimeMigrationWithQuery): migration_id = "event_context_id_as_binary" index_to_drop = ("events", "ix_events_context_id") - @staticmethod - @retryable_database_job("migrate events context_ids to binary format") - def migrate_data_impl(instance: Recorder) -> bool: + def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: """Migrate events context_ids to use binary format, return True if completed.""" _to_bytes = _context_id_to_bytes session_maker = instance.get_session @@ -2389,13 +2376,10 @@ class EventsContextIDMigration(BaseRunTimeMigrationWithQuery): for event_id, time_fired_ts, context_id, context_user_id, context_parent_id in events ], ) - # If there is more work to do return False - # so that we can be called again - if is_done := not events: - _mark_migration_done(session, EventsContextIDMigration) + is_done = not events _LOGGER.debug("Migrating events context_ids to binary format: done=%s", is_done) - return is_done + return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done) def needs_migrate_query(self) -> StatementLambdaElement: """Return the query to check if the migration needs to run.""" @@ -2412,9 +2396,7 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery): # no new pending event_types about to be added to # the db since this happens live - @staticmethod - @retryable_database_job("migrate events event_types to event_type_ids") - def migrate_data_impl(instance: Recorder) -> bool: + def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: """Migrate event_type to event_type_ids, return True if completed.""" session_maker = instance.get_session _LOGGER.debug("Migrating event_types") @@ -2467,15 +2449,12 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery): ], ) - # If there is more work to do return False - # so that we can be called again - if is_done := not events: - _mark_migration_done(session, EventTypeIDMigration) + is_done = not events _LOGGER.debug("Migrating event_types done=%s", is_done) - return is_done + return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done) - def migration_done(self, instance: Recorder, session: Session | None) -> None: + def migration_done(self, instance: Recorder, session: Session) -> None: """Will be called after migrate returns True.""" _LOGGER.debug("Activating event_types manager as all data is migrated") instance.event_type_manager.active = True @@ -2495,9 +2474,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery): # no new pending states_meta about to be added to # the db since this happens live - @staticmethod - @retryable_database_job("migrate states entity_ids to states_meta") - def migrate_data_impl(instance: Recorder) -> bool: + def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: """Migrate entity_ids to states_meta, return True if completed. We do this in two steps because we need the history queries to work @@ -2560,15 +2537,12 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery): ], ) - # If there is more work to do return False - # so that we can be called again - if is_done := not states: - _mark_migration_done(session, EntityIDMigration) + is_done = not states _LOGGER.debug("Migrating entity_ids done=%s", is_done) - return is_done + return DataMigrationStatus(needs_migrate=not is_done, migration_done=is_done) - def migration_done(self, instance: Recorder, _session: Session | None) -> None: + def migration_done(self, instance: Recorder, session: Session) -> None: """Will be called after migrate returns True.""" # The migration has finished, now we start the post migration # to remove the old entity_id data from the states table @@ -2576,15 +2550,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery): # so we set active to True _LOGGER.debug("Activating states_meta manager as all data is migrated") instance.states_meta_manager.active = True - session_generator = ( - contextlib.nullcontext(_session) - if _session - else session_scope(session=instance.get_session()) - ) - with ( - contextlib.suppress(SQLAlchemyError), - session_generator as session, - ): + with contextlib.suppress(SQLAlchemyError): # If ix_states_entity_id_last_updated_ts still exists # on the states table it means the entity id migration # finished by the EntityIDPostMigrationTask did not @@ -2609,9 +2575,7 @@ class EventIDPostMigration(BaseRunTimeMigration): task = MigrationTask migration_version = 2 - @staticmethod - @retryable_database_job("cleanup_legacy_event_ids") - def migrate_data_impl(instance: Recorder) -> bool: + def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: """Remove old event_id index from states, returns True if completed. We used to link states to events using the event_id column but we no @@ -2651,9 +2615,8 @@ class EventIDPostMigration(BaseRunTimeMigration): if fk_remove_ok: _drop_index(session_maker, "states", LEGACY_STATES_EVENT_ID_INDEX) instance.use_legacy_events_index = False - _mark_migration_done(session, EventIDPostMigration) - return True + return DataMigrationStatus(needs_migrate=False, migration_done=fk_remove_ok) @staticmethod def _legacy_event_id_foreign_key_exists(instance: Recorder) -> bool: @@ -2674,16 +2637,16 @@ class EventIDPostMigration(BaseRunTimeMigration): def needs_migrate_impl( self, instance: Recorder, session: Session - ) -> NeedsMigrateResult: + ) -> DataMigrationStatus: """Return if the migration needs to run.""" if self.schema_version <= LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION: - return NeedsMigrateResult(needs_migrate=False, migration_done=False) + return DataMigrationStatus(needs_migrate=False, migration_done=False) if get_index_by_name( session, TABLE_STATES, LEGACY_STATES_EVENT_ID_INDEX ) is not None or self._legacy_event_id_foreign_key_exists(instance): instance.use_legacy_events_index = True - return NeedsMigrateResult(needs_migrate=True, migration_done=False) - return NeedsMigrateResult(needs_migrate=False, migration_done=True) + return DataMigrationStatus(needs_migrate=True, migration_done=False) + return DataMigrationStatus(needs_migrate=False, migration_done=True) @dataclass(slots=True) diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 4d494aed7d5..9f6cdccd79a 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -645,23 +645,24 @@ def _is_retryable_error(instance: Recorder, err: OperationalError) -> bool: type _FuncType[_T, **_P, _R] = Callable[Concatenate[_T, _P], _R] +type _FuncOrMethType[**_P, _R] = Callable[_P, _R] -def retryable_database_job[_RecorderT: Recorder, **_P]( - description: str, -) -> Callable[[_FuncType[_RecorderT, _P, bool]], _FuncType[_RecorderT, _P, bool]]: +def retryable_database_job[**_P]( + description: str, method: bool = False +) -> Callable[[_FuncOrMethType[_P, bool]], _FuncOrMethType[_P, bool]]: """Try to execute a database job. The job should return True if it finished, and False if it needs to be rescheduled. """ + recorder_pos = 1 if method else 0 - def decorator( - job: _FuncType[_RecorderT, _P, bool], - ) -> _FuncType[_RecorderT, _P, bool]: + def decorator(job: _FuncOrMethType[_P, bool]) -> _FuncOrMethType[_P, bool]: @functools.wraps(job) - def wrapper(instance: _RecorderT, *args: _P.args, **kwargs: _P.kwargs) -> bool: + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool: + instance: Recorder = args[recorder_pos] # type: ignore[assignment] try: - return job(instance, *args, **kwargs) + return job(*args, **kwargs) except OperationalError as err: if _is_retryable_error(instance, err): assert isinstance(err.orig, BaseException) # noqa: PT017