Add recorder data migrator class to clean up states table (#122069)

This commit is contained in:
Erik Montnemery 2024-07-22 20:04:01 +02:00 committed by GitHub
parent 4c853803f1
commit 20fc5233a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 100 additions and 87 deletions

View File

@ -16,14 +16,7 @@ import time
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
import psutil_home_assistant as ha_psutil import psutil_home_assistant as ha_psutil
from sqlalchemy import ( from sqlalchemy import create_engine, event as sqlalchemy_event, exc, select, update
create_engine,
event as sqlalchemy_event,
exc,
inspect,
select,
update,
)
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.engine.interfaces import DBAPIConnection
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -62,7 +55,6 @@ from .const import (
DOMAIN, DOMAIN,
KEEPALIVE_TIME, KEEPALIVE_TIME,
LAST_REPORTED_SCHEMA_VERSION, LAST_REPORTED_SCHEMA_VERSION,
LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION,
MARIADB_PYMYSQL_URL_PREFIX, MARIADB_PYMYSQL_URL_PREFIX,
MARIADB_URL_PREFIX, MARIADB_URL_PREFIX,
MAX_QUEUE_BACKLOG_MIN_VALUE, MAX_QUEUE_BACKLOG_MIN_VALUE,
@ -75,9 +67,7 @@ from .const import (
SupportedDialect, SupportedDialect,
) )
from .db_schema import ( from .db_schema import (
LEGACY_STATES_EVENT_ID_INDEX,
SCHEMA_VERSION, SCHEMA_VERSION,
TABLE_STATES,
Base, Base,
EventData, EventData,
Events, Events,
@ -91,6 +81,7 @@ from .db_schema import (
from .executor import DBInterruptibleThreadPoolExecutor from .executor import DBInterruptibleThreadPoolExecutor
from .migration import ( from .migration import (
EntityIDMigration, EntityIDMigration,
EventIDPostMigration,
EventsContextIDMigration, EventsContextIDMigration,
EventTypeIDMigration, EventTypeIDMigration,
StatesContextIDMigration, StatesContextIDMigration,
@ -113,7 +104,6 @@ from .tasks import (
CommitTask, CommitTask,
CompileMissingStatisticsTask, CompileMissingStatisticsTask,
DatabaseLockTask, DatabaseLockTask,
EventIdMigrationTask,
ImportStatisticsTask, ImportStatisticsTask,
KeepAliveTask, KeepAliveTask,
PerodicCleanupTask, PerodicCleanupTask,
@ -132,7 +122,6 @@ from .util import (
dburl_to_path, dburl_to_path,
end_incomplete_runs, end_incomplete_runs,
execute_stmt_lambda_element, execute_stmt_lambda_element,
get_index_by_name,
is_second_sunday, is_second_sunday,
move_away_broken_database, move_away_broken_database,
session_scope, session_scope,
@ -831,24 +820,11 @@ class Recorder(threading.Thread):
EventsContextIDMigration, EventsContextIDMigration,
EventTypeIDMigration, EventTypeIDMigration,
EntityIDMigration, EntityIDMigration,
EventIDPostMigration,
): ):
migrator = migrator_cls(schema_status.start_version, migration_changes) migrator = migrator_cls(schema_status.start_version, migration_changes)
migrator.do_migrate(self, session) migrator.do_migrate(self, session)
if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
with contextlib.suppress(SQLAlchemyError):
# If the index of event_ids on the states table is still present
# or the event_id foreign key still exists we need to queue a
# task to remove it.
if (
get_index_by_name(
session, TABLE_STATES, LEGACY_STATES_EVENT_ID_INDEX
)
or self._legacy_event_id_foreign_key_exists()
):
self.queue_task(EventIdMigrationTask())
self.use_legacy_events_index = True
# We must only set the db ready after we have set the table managers # We must only set the db ready after we have set the table managers
# to active if there is no data to migrate. # to active if there is no data to migrate.
# #
@ -1327,29 +1303,10 @@ class Recorder(threading.Thread):
"""Run post schema migration tasks.""" """Run post schema migration tasks."""
migration.post_schema_migration(self, old_version, new_version) migration.post_schema_migration(self, old_version, new_version)
def _legacy_event_id_foreign_key_exists(self) -> bool:
"""Check if the legacy event_id foreign key exists."""
engine = self.engine
assert engine is not None
return bool(
next(
(
fk
for fk in inspect(engine).get_foreign_keys(TABLE_STATES)
if fk["constrained_columns"] == ["event_id"]
),
None,
)
)
def _post_migrate_entity_ids(self) -> bool: def _post_migrate_entity_ids(self) -> bool:
"""Post migrate entity_ids if needed.""" """Post migrate entity_ids if needed."""
return migration.post_migrate_entity_ids(self) return migration.post_migrate_entity_ids(self)
def _cleanup_legacy_states_event_ids(self) -> bool:
"""Cleanup legacy event_ids if needed."""
return migration.cleanup_legacy_states_event_ids(self)
def _send_keep_alive(self) -> None: def _send_keep_alive(self) -> None:
"""Send a keep alive to keep the db connection open.""" """Send a keep alive to keep the db connection open."""
assert self.event_session is not None assert self.event_session is not None

View File

@ -52,6 +52,7 @@ from .auto_repairs.statistics.schema import (
from .const import ( from .const import (
CONTEXT_ID_AS_BINARY_SCHEMA_VERSION, CONTEXT_ID_AS_BINARY_SCHEMA_VERSION,
EVENT_TYPE_IDS_SCHEMA_VERSION, EVENT_TYPE_IDS_SCHEMA_VERSION,
LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION,
STATES_META_SCHEMA_VERSION, STATES_META_SCHEMA_VERSION,
SupportedDialect, SupportedDialect,
) )
@ -1949,6 +1950,7 @@ def cleanup_legacy_states_event_ids(instance: Recorder) -> bool:
) )
_drop_index(session_maker, "states", LEGACY_STATES_EVENT_ID_INDEX) _drop_index(session_maker, "states", LEGACY_STATES_EVENT_ID_INDEX)
instance.use_legacy_events_index = False instance.use_legacy_events_index = False
_mark_migration_done(session, EventIDPostMigration)
return True return True
@ -2018,6 +2020,14 @@ class CommitBeforeMigrationTask(MigrationTask):
commit_before = True commit_before = True
@dataclass(frozen=True, kw_only=True)
class NeedsMigrateResult:
"""Container for the return value of BaseRunTimeMigration.needs_migrate_impl."""
# The two flags are independent: an implementation may report that nothing
# needs to run right now without also declaring the migration finished.
#
# needs_migrate: value that BaseRunTimeMigration.needs_migrate returns to its
# caller; do_migrate queues the migration task when it is True.
needs_migrate: bool
# migration_done: when True, BaseRunTimeMigration.needs_migrate records the
# migration as done via _mark_migration_done before returning.
migration_done: bool
class BaseRunTimeMigration(ABC): class BaseRunTimeMigration(ABC):
"""Base class for run time migrations.""" """Base class for run time migrations."""
@ -2033,7 +2043,7 @@ class BaseRunTimeMigration(ABC):
def do_migrate(self, instance: Recorder, session: Session) -> None: def do_migrate(self, instance: Recorder, session: Session) -> None:
"""Start migration if needed.""" """Start migration if needed."""
if self.needs_migrate(session): if self.needs_migrate(instance, session):
instance.queue_task(self.task(self)) instance.queue_task(self.task(self))
else: else:
self.migration_done(instance) self.migration_done(instance)
@ -2047,10 +2057,12 @@ class BaseRunTimeMigration(ABC):
"""Will be called after migrate returns True or if migration is not needed.""" """Will be called after migrate returns True or if migration is not needed."""
@abstractmethod @abstractmethod
def needs_migrate_query(self) -> StatementLambdaElement: def needs_migrate_impl(
"""Return the query to check if the migration needs to run.""" self, instance: Recorder, session: Session
) -> NeedsMigrateResult:
"""Return if the migration needs to run and if it is done."""
def needs_migrate(self, session: Session) -> bool: def needs_migrate(self, instance: Recorder, session: Session) -> bool:
"""Return if the migration needs to run. """Return if the migration needs to run.
If the migration needs to run, it will return True. If the migration needs to run, it will return True.
@ -2068,13 +2080,30 @@ class BaseRunTimeMigration(ABC):
# We do not know if the migration is done from the # We do not know if the migration is done from the
# migration changes table so we must check the data # migration changes table so we must check the data
# This is the slow path # This is the slow path
if not execute_stmt_lambda_element(session, self.needs_migrate_query()): needs_migrate = self.needs_migrate_impl(instance, session)
if needs_migrate.migration_done:
_mark_migration_done(session, self.__class__) _mark_migration_done(session, self.__class__)
return False return needs_migrate.needs_migrate
return True
class StatesContextIDMigration(BaseRunTimeMigration): class BaseRunTimeMigrationWithQuery(BaseRunTimeMigration):
"""Base class for run time migrations."""
# Variant of BaseRunTimeMigration for migrations whose "does this need to
# run?" check is a single query: subclasses provide the query through
# needs_migrate_query() and this class converts its result into a
# NeedsMigrateResult.
@abstractmethod
def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
def needs_migrate_impl(
self, instance: Recorder, session: Session
) -> NeedsMigrateResult:
"""Return if the migration needs to run."""
# Runs the subclass query on the given session. A truthy result means the
# migration needs to run; a falsy result means it does not and the
# migration is reported as done. `instance` is not used here.
needs_migrate = execute_stmt_lambda_element(session, self.needs_migrate_query())
return NeedsMigrateResult(
needs_migrate=bool(needs_migrate), migration_done=not needs_migrate
)
class StatesContextIDMigration(BaseRunTimeMigrationWithQuery):
"""Migration to migrate states context_ids to binary format.""" """Migration to migrate states context_ids to binary format."""
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
@ -2123,7 +2152,7 @@ class StatesContextIDMigration(BaseRunTimeMigration):
return has_states_context_ids_to_migrate() return has_states_context_ids_to_migrate()
class EventsContextIDMigration(BaseRunTimeMigration): class EventsContextIDMigration(BaseRunTimeMigrationWithQuery):
"""Migration to migrate events context_ids to binary format.""" """Migration to migrate events context_ids to binary format."""
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
@ -2172,7 +2201,7 @@ class EventsContextIDMigration(BaseRunTimeMigration):
return has_events_context_ids_to_migrate() return has_events_context_ids_to_migrate()
class EventTypeIDMigration(BaseRunTimeMigration): class EventTypeIDMigration(BaseRunTimeMigrationWithQuery):
"""Migration to migrate event_type to event_type_ids.""" """Migration to migrate event_type to event_type_ids."""
required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION
@ -2255,7 +2284,7 @@ class EventTypeIDMigration(BaseRunTimeMigration):
return has_event_type_to_migrate() return has_event_type_to_migrate()
class EntityIDMigration(BaseRunTimeMigration): class EntityIDMigration(BaseRunTimeMigrationWithQuery):
"""Migration to migrate entity_ids to states_meta.""" """Migration to migrate entity_ids to states_meta."""
required_schema_version = STATES_META_SCHEMA_VERSION required_schema_version = STATES_META_SCHEMA_VERSION
@ -2367,6 +2396,48 @@ class EntityIDMigration(BaseRunTimeMigration):
return has_entity_ids_to_migrate() return has_entity_ids_to_migrate()
class EventIDPostMigration(BaseRunTimeMigration):
"""Migration to remove old event_id index from states."""
# The "needs migrate" decision is made by inspecting the database schema
# (legacy index and legacy event_id foreign key on TABLE_STATES) rather than
# by running a data query.
migration_id = "event_id_post_migration"
task = MigrationTask
@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
# Delegates to the module-level cleanup_legacy_states_event_ids helper.
return cleanup_legacy_states_event_ids(instance)
@staticmethod
def _legacy_event_id_foreign_key_exists(instance: Recorder) -> bool:
"""Check if the legacy event_id foreign key exists."""
engine = instance.engine
assert engine is not None
inspector = sqlalchemy.inspect(engine)
# True when at least one foreign key reported for TABLE_STATES constrains
# exactly the single column "event_id".
return bool(
next(
(
fk
for fk in inspector.get_foreign_keys(TABLE_STATES)
if fk["constrained_columns"] == ["event_id"]
),
None,
)
)
def needs_migrate_impl(
self, instance: Recorder, session: Session
) -> NeedsMigrateResult:
"""Return if the migration needs to run."""
# NOTE(review): the check this replaces in Recorder was wrapped in
# contextlib.suppress(SQLAlchemyError); here an error raised by the index
# or foreign key lookup propagates to the caller - confirm that is intended.
#
# Schema version at or below the threshold: nothing is queued and the
# migration is not recorded as done.
# presumably self.schema_version is set by the base class constructor - verify.
if self.schema_version <= LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
return NeedsMigrateResult(needs_migrate=False, migration_done=False)
# Either leftover (the legacy index or the legacy foreign key) is enough to
# require the cleanup.
if get_index_by_name(
session, TABLE_STATES, LEGACY_STATES_EVENT_ID_INDEX
) is not None or self._legacy_event_id_foreign_key_exists(instance):
# Flag the recorder as still using the legacy index while the cleanup is
# pending.
instance.use_legacy_events_index = True
return NeedsMigrateResult(needs_migrate=True, migration_done=False)
# Neither leftover was found: nothing to run, and the migration is reported
# as done.
return NeedsMigrateResult(needs_migrate=False, migration_done=True)
def _mark_migration_done( def _mark_migration_done(
session: Session, migration: type[BaseRunTimeMigration] session: Session, migration: type[BaseRunTimeMigration]
) -> None: ) -> None:

View File

@ -371,20 +371,6 @@ class EntityIDPostMigrationTask(RecorderTask):
instance.queue_task(EntityIDPostMigrationTask()) instance.queue_task(EntityIDPostMigrationTask())
@dataclass(slots=True)
class EventIdMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to cleanup legacy event_ids in the states table.
This task should only be queued if the ix_states_event_id index exists
since it is used to scan the states table and it will be removed after this
task is run if its no longer needed.
"""
def run(self, instance: Recorder) -> None:
"""Clean up the legacy event_id index on states."""
instance._cleanup_legacy_states_event_ids() # noqa: SLF001
@dataclass(slots=True) @dataclass(slots=True)
class RefreshEventTypesTask(RecorderTask): class RefreshEventTypesTask(RecorderTask):
"""An object to insert into the recorder queue to refresh event types.""" """An object to insert into the recorder queue to refresh event types."""

View File

@ -3,7 +3,7 @@
from datetime import timedelta from datetime import timedelta
import importlib import importlib
import sys import sys
from unittest.mock import DEFAULT, patch from unittest.mock import patch
import pytest import pytest
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
@ -107,10 +107,9 @@ async def test_migrate_times(
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple( patch("homeassistant.components.recorder.Recorder._post_migrate_entity_ids"),
"homeassistant.components.recorder.Recorder", patch(
_post_migrate_entity_ids=DEFAULT, "homeassistant.components.recorder.migration.cleanup_legacy_states_event_ids"
_cleanup_legacy_states_event_ids=DEFAULT,
), ),
): ):
async with ( async with (
@ -259,10 +258,9 @@ async def test_migrate_can_resume_entity_id_post_migration(
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple( patch("homeassistant.components.recorder.Recorder._post_migrate_entity_ids"),
"homeassistant.components.recorder.Recorder", patch(
_post_migrate_entity_ids=DEFAULT, "homeassistant.components.recorder.migration.cleanup_legacy_states_event_ids"
_cleanup_legacy_states_event_ids=DEFAULT,
), ),
): ):
async with ( async with (
@ -314,6 +312,7 @@ async def test_migrate_can_resume_entity_id_post_migration(
await hass.async_stop() await hass.async_stop()
@pytest.mark.parametrize("enable_migrate_event_ids", [True])
@pytest.mark.parametrize("persistent_database", [True]) @pytest.mark.parametrize("persistent_database", [True])
@pytest.mark.usefixtures("hass_storage") # Prevent test hass from writing to storage @pytest.mark.usefixtures("hass_storage") # Prevent test hass from writing to storage
async def test_migrate_can_resume_ix_states_event_id_removed( async def test_migrate_can_resume_ix_states_event_id_removed(
@ -381,10 +380,9 @@ async def test_migrate_can_resume_ix_states_event_id_removed(
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test),
patch.multiple( patch("homeassistant.components.recorder.Recorder._post_migrate_entity_ids"),
"homeassistant.components.recorder.Recorder", patch(
_post_migrate_entity_ids=DEFAULT, "homeassistant.components.recorder.migration.cleanup_legacy_states_event_ids"
_cleanup_legacy_states_event_ids=DEFAULT,
), ),
): ):
async with ( async with (
@ -440,6 +438,7 @@ async def test_migrate_can_resume_ix_states_event_id_removed(
states_indexes = await instance.async_add_executor_job(_get_states_index_names) states_indexes = await instance.async_add_executor_job(_get_states_index_names)
states_index_names = {index["name"] for index in states_indexes} states_index_names = {index["name"] for index in states_indexes}
assert instance.use_legacy_events_index is False
assert "ix_states_entity_id_last_updated_ts" not in states_index_names assert "ix_states_entity_id_last_updated_ts" not in states_index_names
assert "ix_states_event_id" not in states_index_names assert "ix_states_event_id" not in states_index_names
assert await instance.async_add_executor_job(_get_event_id_foreign_keys) is None assert await instance.async_add_executor_job(_get_event_id_foreign_keys) is None

View File

@ -1475,9 +1475,9 @@ async def async_test_recorder(
migration.EntityIDMigration.migrate_data if enable_migrate_entity_ids else None migration.EntityIDMigration.migrate_data if enable_migrate_entity_ids else None
) )
legacy_event_id_foreign_key_exists = ( legacy_event_id_foreign_key_exists = (
recorder.Recorder._legacy_event_id_foreign_key_exists migration.EventIDPostMigration._legacy_event_id_foreign_key_exists
if enable_migrate_event_ids if enable_migrate_event_ids
else None else lambda _: None
) )
with ( with (
patch( patch(
@ -1516,7 +1516,7 @@ async def async_test_recorder(
autospec=True, autospec=True,
), ),
patch( patch(
"homeassistant.components.recorder.Recorder._legacy_event_id_foreign_key_exists", "homeassistant.components.recorder.migration.EventIDPostMigration._legacy_event_id_foreign_key_exists",
side_effect=legacy_event_id_foreign_key_exists, side_effect=legacy_event_id_foreign_key_exists,
autospec=True, autospec=True,
), ),