Avoid creating nested sessions in recorder migration (#122580)

This commit is contained in:
Erik Montnemery 2024-07-25 15:44:48 +02:00 committed by GitHub
parent f1b933ae0c
commit 0c7ab2062f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2010,7 +2010,7 @@ class MigrationTask(RecorderTask):
# Schedule a new migration task if this one didn't finish # Schedule a new migration task if this one didn't finish
instance.queue_task(MigrationTask(self.migrator)) instance.queue_task(MigrationTask(self.migrator))
else: else:
self.migrator.migration_done(instance) self.migrator.migration_done(instance, None)
@dataclass(slots=True) @dataclass(slots=True)
@ -2046,14 +2046,14 @@ class BaseRunTimeMigration(ABC):
if self.needs_migrate(instance, session): if self.needs_migrate(instance, session):
instance.queue_task(self.task(self)) instance.queue_task(self.task(self))
else: else:
self.migration_done(instance) self.migration_done(instance, session)
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def migrate_data(instance: Recorder) -> bool: def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed.""" """Migrate some data, returns True if migration is completed."""
def migration_done(self, instance: Recorder) -> None: def migration_done(self, instance: Recorder, session: Session | None) -> None:
"""Will be called after migrate returns True or if migration is not needed.""" """Will be called after migrate returns True or if migration is not needed."""
@abstractmethod @abstractmethod
@ -2274,7 +2274,7 @@ class EventTypeIDMigration(BaseRunTimeMigrationWithQuery):
_LOGGER.debug("Migrating event_types done=%s", is_done) _LOGGER.debug("Migrating event_types done=%s", is_done)
return is_done return is_done
def migration_done(self, instance: Recorder) -> None: def migration_done(self, instance: Recorder, session: Session | None) -> None:
"""Will be called after migrate returns True.""" """Will be called after migrate returns True."""
_LOGGER.debug("Activating event_types manager as all data is migrated") _LOGGER.debug("Activating event_types manager as all data is migrated")
instance.event_type_manager.active = True instance.event_type_manager.active = True
@ -2367,7 +2367,7 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
_LOGGER.debug("Migrating entity_ids done=%s", is_done) _LOGGER.debug("Migrating entity_ids done=%s", is_done)
return is_done return is_done
def migration_done(self, instance: Recorder) -> None: def migration_done(self, instance: Recorder, _session: Session | None) -> None:
"""Will be called after migrate returns True.""" """Will be called after migrate returns True."""
# The migration has finished, now we start the post migration # The migration has finished, now we start the post migration
# to remove the old entity_id data from the states table # to remove the old entity_id data from the states table
@ -2375,9 +2375,14 @@ class EntityIDMigration(BaseRunTimeMigrationWithQuery):
# so we set active to True # so we set active to True
_LOGGER.debug("Activating states_meta manager as all data is migrated") _LOGGER.debug("Activating states_meta manager as all data is migrated")
instance.states_meta_manager.active = True instance.states_meta_manager.active = True
session_generator = (
contextlib.nullcontext(_session)
if _session
else session_scope(session=instance.get_session())
)
with ( with (
contextlib.suppress(SQLAlchemyError), contextlib.suppress(SQLAlchemyError),
session_scope(session=instance.get_session()) as session, session_generator as session,
): ):
# If ix_states_entity_id_last_updated_ts still exists # If ix_states_entity_id_last_updated_ts still exists
# on the states table it means the entity id migration # on the states table it means the entity id migration