diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 3ef9b65e259..6f438106ab6 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -2010,7 +2010,7 @@ class MigrationTask(RecorderTask): # Schedule a new migration task if this one didn't finish instance.queue_task(MigrationTask(self.migrator)) else: - self.migrator.migration_done(instance) + self.migrator.migration_done(instance, None) @dataclass(slots=True) @@ -2046,14 +2046,14 @@ class BaseRunTimeMigration(ABC): if self.needs_migrate(instance, session): instance.queue_task(self.task(self)) else: - self.migration_done(instance) + self.migration_done(instance, session) @staticmethod @abstractmethod def migrate_data(instance: Recorder) -> bool: """Migrate some data, returns True if migration is completed.""" - def migration_done(self, instance: Recorder) -> None: + def migration_done(self, instance: Recorder, session: Session | None) -> None: """Will be called after migrate returns True or if migration is not needed.""" @abstractmethod @@ -2274,7 +2274,7 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery): _LOGGER.debug("Migrating event_types done=%s", is_done) return is_done - def migration_done(self, instance: Recorder) -> None: + def migration_done(self, instance: Recorder, session: Session | None) -> None: """Will be called after migrate returns True.""" _LOGGER.debug("Activating event_types manager as all data is migrated") instance.event_type_manager.active = True @@ -2367,7 +2367,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery): _LOGGER.debug("Migrating entity_ids done=%s", is_done) return is_done - def migration_done(self, instance: Recorder) -> None: + def migration_done(self, instance: Recorder, _session: Session | None) -> None: """Will be called after migrate returns True.""" # The migration has finished, now we start the post migration # to remove the old entity_id data from the states table @@ -2375,9 +2375,14 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery): # so we set active to True _LOGGER.debug("Activating states_meta manager as all data is migrated") instance.states_meta_manager.active = True + session_generator = ( + contextlib.nullcontext(_session) + if _session + else session_scope(session=instance.get_session()) + ) with ( contextlib.suppress(SQLAlchemyError), - session_scope(session=instance.get_session()) as session, + session_generator as session, ): # If ix_states_entity_id_last_updated_ts still exists # on the states table it means the entity id migration