Don't run recorder data migration on new databases (#133412)

* Don't run recorder data migration on new databases

* Add tests
This commit is contained in:
Erik Montnemery 2024-12-17 20:02:12 +01:00 committed by GitHub
parent 633433709f
commit d22668a166
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 330 additions and 47 deletions

View File

@ -970,6 +970,7 @@ class Recorder(threading.Thread):
# which does not need migration or repair. # which does not need migration or repair.
new_schema_status = migration.SchemaValidationStatus( new_schema_status = migration.SchemaValidationStatus(
current_version=SCHEMA_VERSION, current_version=SCHEMA_VERSION,
initial_version=SCHEMA_VERSION,
migration_needed=False, migration_needed=False,
non_live_data_migration_needed=False, non_live_data_migration_needed=False,
schema_errors=set(), schema_errors=set(),

View File

@ -180,7 +180,27 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
raise ex raise ex
def _get_schema_version(session: Session) -> int | None: def _get_initial_schema_version(session: Session) -> int | None:
"""Get the schema version the database was created with."""
res = (
session.query(SchemaChanges.schema_version)
.order_by(SchemaChanges.change_id.asc())
.first()
)
return getattr(res, "schema_version", None)
def get_initial_schema_version(session_maker: Callable[[], Session]) -> int | None:
"""Get the schema version the database was created with."""
try:
with session_scope(session=session_maker(), read_only=True) as session:
return _get_initial_schema_version(session)
except Exception:
_LOGGER.exception("Error when determining DB schema version")
return None
def _get_current_schema_version(session: Session) -> int | None:
"""Get the schema version.""" """Get the schema version."""
res = ( res = (
session.query(SchemaChanges.schema_version) session.query(SchemaChanges.schema_version)
@ -190,11 +210,11 @@ def _get_schema_version(session: Session) -> int | None:
return getattr(res, "schema_version", None) return getattr(res, "schema_version", None)
def get_schema_version(session_maker: Callable[[], Session]) -> int | None: def get_current_schema_version(session_maker: Callable[[], Session]) -> int | None:
"""Get the schema version.""" """Get the schema version."""
try: try:
with session_scope(session=session_maker(), read_only=True) as session: with session_scope(session=session_maker(), read_only=True) as session:
return _get_schema_version(session) return _get_current_schema_version(session)
except Exception: except Exception:
_LOGGER.exception("Error when determining DB schema version") _LOGGER.exception("Error when determining DB schema version")
return None return None
@ -205,6 +225,7 @@ class SchemaValidationStatus:
"""Store schema validation status.""" """Store schema validation status."""
current_version: int current_version: int
initial_version: int
migration_needed: bool migration_needed: bool
non_live_data_migration_needed: bool non_live_data_migration_needed: bool
schema_errors: set[str] schema_errors: set[str]
@ -227,8 +248,9 @@ def validate_db_schema(
""" """
schema_errors: set[str] = set() schema_errors: set[str] = set()
current_version = get_schema_version(session_maker) current_version = get_current_schema_version(session_maker)
if current_version is None: initial_version = get_initial_schema_version(session_maker)
if current_version is None or initial_version is None:
return None return None
if is_current := _schema_is_current(current_version): if is_current := _schema_is_current(current_version):
@ -238,11 +260,15 @@ def validate_db_schema(
schema_migration_needed = not is_current schema_migration_needed = not is_current
_non_live_data_migration_needed = non_live_data_migration_needed( _non_live_data_migration_needed = non_live_data_migration_needed(
instance, session_maker, current_version instance,
session_maker,
initial_schema_version=initial_version,
start_schema_version=current_version,
) )
return SchemaValidationStatus( return SchemaValidationStatus(
current_version=current_version, current_version=current_version,
initial_version=initial_version,
non_live_data_migration_needed=_non_live_data_migration_needed, non_live_data_migration_needed=_non_live_data_migration_needed,
migration_needed=schema_migration_needed or _non_live_data_migration_needed, migration_needed=schema_migration_needed or _non_live_data_migration_needed,
schema_errors=schema_errors, schema_errors=schema_errors,
@ -377,17 +403,26 @@ def _get_migration_changes(session: Session) -> dict[str, int]:
def non_live_data_migration_needed( def non_live_data_migration_needed(
instance: Recorder, instance: Recorder,
session_maker: Callable[[], Session], session_maker: Callable[[], Session],
schema_version: int, *,
initial_schema_version: int,
start_schema_version: int,
) -> bool: ) -> bool:
"""Return True if non-live data migration is needed. """Return True if non-live data migration is needed.
:param initial_schema_version: The schema version the database was created with.
:param start_schema_version: The schema version when starting the migration.
This must only be called if database schema is current. This must only be called if database schema is current.
""" """
migration_needed = False migration_needed = False
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
migration_changes = _get_migration_changes(session) migration_changes = _get_migration_changes(session)
for migrator_cls in NON_LIVE_DATA_MIGRATORS: for migrator_cls in NON_LIVE_DATA_MIGRATORS:
migrator = migrator_cls(schema_version, migration_changes) migrator = migrator_cls(
initial_schema_version=initial_schema_version,
start_schema_version=start_schema_version,
migration_changes=migration_changes,
)
migration_needed |= migrator.needs_migrate(instance, session) migration_needed |= migrator.needs_migrate(instance, session)
return migration_needed return migration_needed
@ -406,7 +441,11 @@ def migrate_data_non_live(
migration_changes = _get_migration_changes(session) migration_changes = _get_migration_changes(session)
for migrator_cls in NON_LIVE_DATA_MIGRATORS: for migrator_cls in NON_LIVE_DATA_MIGRATORS:
migrator = migrator_cls(schema_status.start_version, migration_changes) migrator = migrator_cls(
initial_schema_version=schema_status.initial_version,
start_schema_version=schema_status.start_version,
migration_changes=migration_changes,
)
migrator.migrate_all(instance, session_maker) migrator.migrate_all(instance, session_maker)
@ -423,7 +462,11 @@ def migrate_data_live(
migration_changes = _get_migration_changes(session) migration_changes = _get_migration_changes(session)
for migrator_cls in LIVE_DATA_MIGRATORS: for migrator_cls in LIVE_DATA_MIGRATORS:
migrator = migrator_cls(schema_status.start_version, migration_changes) migrator = migrator_cls(
initial_schema_version=schema_status.initial_version,
start_schema_version=schema_status.start_version,
migration_changes=migration_changes,
)
migrator.queue_migration(instance, session) migrator.queue_migration(instance, session)
@ -2233,7 +2276,7 @@ def initialize_database(session_maker: Callable[[], Session]) -> bool:
"""Initialize a new database.""" """Initialize a new database."""
try: try:
with session_scope(session=session_maker(), read_only=True) as session: with session_scope(session=session_maker(), read_only=True) as session:
if _get_schema_version(session) is not None: if _get_current_schema_version(session) is not None:
return True return True
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
@ -2277,13 +2320,25 @@ class BaseMigration(ABC):
"""Base class for migrations.""" """Base class for migrations."""
index_to_drop: tuple[str, str] | None = None index_to_drop: tuple[str, str] | None = None
required_schema_version = 0 required_schema_version = 0 # Schema version required to run migration queries
max_initial_schema_version: int # Skip migration if db created after this version
migration_version = 1 migration_version = 1
migration_id: str migration_id: str
def __init__(self, schema_version: int, migration_changes: dict[str, int]) -> None: def __init__(
"""Initialize a new BaseRunTimeMigration.""" self,
self.schema_version = schema_version *,
initial_schema_version: int,
start_schema_version: int,
migration_changes: dict[str, int],
) -> None:
"""Initialize a new BaseRunTimeMigration.
:param initial_schema_version: The schema version the database was created with.
:param start_schema_version: The schema version when starting the migration.
"""
self.initial_schema_version = initial_schema_version
self.start_schema_version = start_schema_version
self.migration_changes = migration_changes self.migration_changes = migration_changes
@abstractmethod @abstractmethod
@ -2324,7 +2379,15 @@ class BaseMigration(ABC):
mark the migration as done in the database if its not already mark the migration as done in the database if its not already
marked as done. marked as done.
""" """
if self.schema_version < self.required_schema_version: if self.initial_schema_version > self.max_initial_schema_version:
_LOGGER.debug(
"Data migration '%s' not needed, database created with version %s "
"after migrator was added",
self.migration_id,
self.initial_schema_version,
)
return False
if self.start_schema_version < self.required_schema_version:
# Schema is too old, we must have to migrate # Schema is too old, we must have to migrate
_LOGGER.info( _LOGGER.info(
"Data migration '%s' needed, schema too old", self.migration_id "Data migration '%s' needed, schema too old", self.migration_id
@ -2426,6 +2489,7 @@ class StatesContextIDMigration(BaseMigrationWithQuery, BaseOffLineMigration):
"""Migration to migrate states context_ids to binary format.""" """Migration to migrate states context_ids to binary format."""
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
max_initial_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION - 1
migration_id = "state_context_id_as_binary" migration_id = "state_context_id_as_binary"
migration_version = 2 migration_version = 2
index_to_drop = ("states", "ix_states_context_id") index_to_drop = ("states", "ix_states_context_id")
@ -2469,6 +2533,7 @@ class EventsContextIDMigration(BaseMigrationWithQuery, BaseOffLineMigration):
"""Migration to migrate events context_ids to binary format.""" """Migration to migrate events context_ids to binary format."""
required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
max_initial_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION - 1
migration_id = "event_context_id_as_binary" migration_id = "event_context_id_as_binary"
migration_version = 2 migration_version = 2
index_to_drop = ("events", "ix_events_context_id") index_to_drop = ("events", "ix_events_context_id")
@ -2512,6 +2577,7 @@ class EventTypeIDMigration(BaseMigrationWithQuery, BaseOffLineMigration):
"""Migration to migrate event_type to event_type_ids.""" """Migration to migrate event_type to event_type_ids."""
required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION
max_initial_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION - 1
migration_id = "event_type_id_migration" migration_id = "event_type_id_migration"
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
@ -2581,6 +2647,7 @@ class EntityIDMigration(BaseMigrationWithQuery, BaseOffLineMigration):
"""Migration to migrate entity_ids to states_meta.""" """Migration to migrate entity_ids to states_meta."""
required_schema_version = STATES_META_SCHEMA_VERSION required_schema_version = STATES_META_SCHEMA_VERSION
max_initial_schema_version = STATES_META_SCHEMA_VERSION - 1
migration_id = "entity_id_migration" migration_id = "entity_id_migration"
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
@ -2660,6 +2727,7 @@ class EventIDPostMigration(BaseRunTimeMigration):
"""Migration to remove old event_id index from states.""" """Migration to remove old event_id index from states."""
migration_id = "event_id_post_migration" migration_id = "event_id_post_migration"
max_initial_schema_version = LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION - 1
task = MigrationTask task = MigrationTask
migration_version = 2 migration_version = 2
@ -2728,7 +2796,7 @@ class EventIDPostMigration(BaseRunTimeMigration):
self, instance: Recorder, session: Session self, instance: Recorder, session: Session
) -> DataMigrationStatus: ) -> DataMigrationStatus:
"""Return if the migration needs to run.""" """Return if the migration needs to run."""
if self.schema_version <= LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION: if self.start_schema_version <= LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
return DataMigrationStatus(needs_migrate=False, migration_done=False) return DataMigrationStatus(needs_migrate=False, migration_done=False)
if get_index_by_name( if get_index_by_name(
session, TABLE_STATES, LEGACY_STATES_EVENT_ID_INDEX session, TABLE_STATES, LEGACY_STATES_EVENT_ID_INDEX
@ -2745,6 +2813,7 @@ class EntityIDPostMigration(BaseMigrationWithQuery, BaseOffLineMigration):
""" """
migration_id = "entity_id_post_migration" migration_id = "entity_id_post_migration"
max_initial_schema_version = STATES_META_SCHEMA_VERSION - 1
index_to_drop = (TABLE_STATES, LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX) index_to_drop = (TABLE_STATES, LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX)
def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus: def migrate_data_impl(self, instance: Recorder) -> DataMigrationStatus:
@ -2758,8 +2827,8 @@ class EntityIDPostMigration(BaseMigrationWithQuery, BaseOffLineMigration):
NON_LIVE_DATA_MIGRATORS: tuple[type[BaseOffLineMigration], ...] = ( NON_LIVE_DATA_MIGRATORS: tuple[type[BaseOffLineMigration], ...] = (
StatesContextIDMigration, # Introduced in HA Core 2023.4 StatesContextIDMigration, # Introduced in HA Core 2023.4 by PR #88942
EventsContextIDMigration, # Introduced in HA Core 2023.4 EventsContextIDMigration, # Introduced in HA Core 2023.4 by PR #88942
EventTypeIDMigration, # Introduced in HA Core 2023.4 by PR #89465 EventTypeIDMigration, # Introduced in HA Core 2023.4 by PR #89465
EntityIDMigration, # Introduced in HA Core 2023.4 by PR #89557 EntityIDMigration, # Introduced in HA Core 2023.4 by PR #89557
EntityIDPostMigration, # Introduced in HA Core 2023.4 by PR #89557 EntityIDPostMigration, # Introduced in HA Core 2023.4 by PR #89557

View File

@ -964,12 +964,17 @@ async def test_recorder_setup_failure(hass: HomeAssistant) -> None:
hass.stop() hass.stop()
async def test_recorder_validate_schema_failure(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
"function_to_patch", ["_get_current_schema_version", "_get_initial_schema_version"]
)
async def test_recorder_validate_schema_failure(
hass: HomeAssistant, function_to_patch: str
) -> None:
"""Test some exceptions.""" """Test some exceptions."""
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
with ( with (
patch( patch(
"homeassistant.components.recorder.migration._get_schema_version" f"homeassistant.components.recorder.migration.{function_to_patch}"
) as inspect_schema_version, ) as inspect_schema_version,
patch("homeassistant.components.recorder.core.time.sleep"), patch("homeassistant.components.recorder.core.time.sleep"),
): ):

View File

@ -97,6 +97,7 @@ async def test_schema_update_calls(
session_maker, session_maker,
migration.SchemaValidationStatus( migration.SchemaValidationStatus(
current_version=0, current_version=0,
initial_version=0,
migration_needed=True, migration_needed=True,
non_live_data_migration_needed=True, non_live_data_migration_needed=True,
schema_errors=set(), schema_errors=set(),
@ -111,6 +112,7 @@ async def test_schema_update_calls(
session_maker, session_maker,
migration.SchemaValidationStatus( migration.SchemaValidationStatus(
current_version=42, current_version=42,
initial_version=0,
migration_needed=True, migration_needed=True,
non_live_data_migration_needed=True, non_live_data_migration_needed=True,
schema_errors=set(), schema_errors=set(),

View File

@ -1,8 +1,9 @@
"""Test run time migrations are remembered in the migration_changes table.""" """Test run time migrations are remembered in the migration_changes table."""
from collections.abc import Callable
import importlib import importlib
import sys import sys
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
@ -10,6 +11,7 @@ from sqlalchemy.orm import Session
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import core, migration, statistics from homeassistant.components.recorder import core, migration, statistics
from homeassistant.components.recorder.db_schema import SCHEMA_VERSION
from homeassistant.components.recorder.migration import MigrationTask from homeassistant.components.recorder.migration import MigrationTask
from homeassistant.components.recorder.queries import get_migration_changes from homeassistant.components.recorder.queries import get_migration_changes
from homeassistant.components.recorder.util import ( from homeassistant.components.recorder.util import (
@ -25,7 +27,8 @@ from tests.common import async_test_home_assistant
from tests.typing import RecorderInstanceGenerator from tests.typing import RecorderInstanceGenerator
CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine" CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine"
SCHEMA_MODULE = "tests.components.recorder.db_schema_32" SCHEMA_MODULE_32 = "tests.components.recorder.db_schema_32"
SCHEMA_MODULE_CURRENT = "homeassistant.components.recorder.db_schema"
@pytest.fixture @pytest.fixture
@ -46,26 +49,190 @@ def _get_migration_id(hass: HomeAssistant) -> dict[str, int]:
return dict(execute_stmt_lambda_element(session, get_migration_changes())) return dict(execute_stmt_lambda_element(session, get_migration_changes()))
def _create_engine_test(*args, **kwargs): def _create_engine_test(
schema_module: str, *, initial_version: int | None = None
) -> Callable:
"""Test version of create_engine that initializes with old schema. """Test version of create_engine that initializes with old schema.
This simulates an existing db with the old schema. This simulates an existing db with the old schema.
""" """
importlib.import_module(SCHEMA_MODULE)
old_db_schema = sys.modules[SCHEMA_MODULE] def _create_engine_test(*args, **kwargs):
engine = create_engine(*args, **kwargs) """Test version of create_engine that initializes with old schema.
old_db_schema.Base.metadata.create_all(engine)
with Session(engine) as session: This simulates an existing db with the old schema.
session.add( """
recorder.db_schema.StatisticsRuns(start=statistics.get_start_time()) importlib.import_module(schema_module)
) old_db_schema = sys.modules[schema_module]
session.add( engine = create_engine(*args, **kwargs)
recorder.db_schema.SchemaChanges( old_db_schema.Base.metadata.create_all(engine)
schema_version=old_db_schema.SCHEMA_VERSION with Session(engine) as session:
session.add(
recorder.db_schema.StatisticsRuns(start=statistics.get_start_time())
) )
if initial_version is not None:
session.add(
recorder.db_schema.SchemaChanges(schema_version=initial_version)
)
session.add(
recorder.db_schema.SchemaChanges(
schema_version=old_db_schema.SCHEMA_VERSION
)
)
session.commit()
return engine
return _create_engine_test
@pytest.mark.usefixtures("hass_storage") # Prevent test hass from writing to storage
@pytest.mark.parametrize(
("initial_version", "expected_migrator_calls"),
[
(
27,
{
"state_context_id_as_binary": 1,
"event_context_id_as_binary": 1,
"event_type_id_migration": 1,
"entity_id_migration": 1,
"event_id_post_migration": 1,
"entity_id_post_migration": 1,
},
),
(
28,
{
"state_context_id_as_binary": 1,
"event_context_id_as_binary": 1,
"event_type_id_migration": 1,
"entity_id_migration": 1,
"event_id_post_migration": 0,
"entity_id_post_migration": 1,
},
),
(
36,
{
"state_context_id_as_binary": 0,
"event_context_id_as_binary": 0,
"event_type_id_migration": 1,
"entity_id_migration": 1,
"event_id_post_migration": 0,
"entity_id_post_migration": 1,
},
),
(
37,
{
"state_context_id_as_binary": 0,
"event_context_id_as_binary": 0,
"event_type_id_migration": 0,
"entity_id_migration": 1,
"event_id_post_migration": 0,
"entity_id_post_migration": 1,
},
),
(
38,
{
"state_context_id_as_binary": 0,
"event_context_id_as_binary": 0,
"event_type_id_migration": 0,
"entity_id_migration": 0,
"event_id_post_migration": 0,
"entity_id_post_migration": 0,
},
),
(
SCHEMA_VERSION,
{
"state_context_id_as_binary": 0,
"event_context_id_as_binary": 0,
"event_type_id_migration": 0,
"entity_id_migration": 0,
"event_id_post_migration": 0,
"entity_id_post_migration": 0,
},
),
],
)
async def test_data_migrator_new_database(
async_test_recorder: RecorderInstanceGenerator,
initial_version: int,
expected_migrator_calls: dict[str, int],
) -> None:
"""Test that the data migrators are not executed on a new database."""
config = {recorder.CONF_COMMIT_INTERVAL: 1}
def needs_migrate_mock() -> Mock:
return Mock(
spec_set=[],
return_value=migration.DataMigrationStatus(
needs_migrate=False, migration_done=True
),
) )
session.commit()
return engine migrator_mocks = {
"state_context_id_as_binary": needs_migrate_mock(),
"event_context_id_as_binary": needs_migrate_mock(),
"event_type_id_migration": needs_migrate_mock(),
"entity_id_migration": needs_migrate_mock(),
"event_id_post_migration": needs_migrate_mock(),
"entity_id_post_migration": needs_migrate_mock(),
}
with (
patch.object(
migration.StatesContextIDMigration,
"needs_migrate_impl",
side_effect=migrator_mocks["state_context_id_as_binary"],
),
patch.object(
migration.EventsContextIDMigration,
"needs_migrate_impl",
side_effect=migrator_mocks["event_context_id_as_binary"],
),
patch.object(
migration.EventTypeIDMigration,
"needs_migrate_impl",
side_effect=migrator_mocks["event_type_id_migration"],
),
patch.object(
migration.EntityIDMigration,
"needs_migrate_impl",
side_effect=migrator_mocks["entity_id_migration"],
),
patch.object(
migration.EventIDPostMigration,
"needs_migrate_impl",
side_effect=migrator_mocks["event_id_post_migration"],
),
patch.object(
migration.EntityIDPostMigration,
"needs_migrate_impl",
side_effect=migrator_mocks["entity_id_post_migration"],
),
patch(
CREATE_ENGINE_TARGET,
new=_create_engine_test(
SCHEMA_MODULE_CURRENT, initial_version=initial_version
),
),
):
async with (
async_test_home_assistant() as hass,
async_test_recorder(hass, config),
):
await hass.async_block_till_done()
await async_wait_recording_done(hass)
await _async_wait_migration_done(hass)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
await hass.async_stop()
for migrator, mock in migrator_mocks.items():
assert len(mock.mock_calls) == expected_migrator_calls[migrator]
@pytest.mark.parametrize("enable_migrate_state_context_ids", [True]) @pytest.mark.parametrize("enable_migrate_state_context_ids", [True])
@ -84,8 +251,8 @@ async def test_migration_changes_prevent_trying_to_migrate_again(
""" """
config = {recorder.CONF_COMMIT_INTERVAL: 1} config = {recorder.CONF_COMMIT_INTERVAL: 1}
importlib.import_module(SCHEMA_MODULE) importlib.import_module(SCHEMA_MODULE_32)
old_db_schema = sys.modules[SCHEMA_MODULE] old_db_schema = sys.modules[SCHEMA_MODULE_32]
# Start with db schema that needs migration (version 32) # Start with db schema that needs migration (version 32)
with ( with (
@ -98,7 +265,7 @@ async def test_migration_changes_prevent_trying_to_migrate_again(
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes), patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test), patch(CREATE_ENGINE_TARGET, new=_create_engine_test(SCHEMA_MODULE_32)),
): ):
async with ( async with (
async_test_home_assistant() as hass, async_test_home_assistant() as hass,

View File

@ -30,7 +30,9 @@ SCHEMA_MODULE_30 = "tests.components.recorder.db_schema_30"
SCHEMA_MODULE_32 = "tests.components.recorder.db_schema_32" SCHEMA_MODULE_32 = "tests.components.recorder.db_schema_32"
def _create_engine_test(schema_module: str) -> Callable: def _create_engine_test(
schema_module: str, *, initial_version: int | None = None
) -> Callable:
"""Test version of create_engine that initializes with old schema. """Test version of create_engine that initializes with old schema.
This simulates an existing db with the old schema. This simulates an existing db with the old schema.
@ -49,6 +51,10 @@ def _create_engine_test(schema_module: str) -> Callable:
session.add( session.add(
recorder.db_schema.StatisticsRuns(start=statistics.get_start_time()) recorder.db_schema.StatisticsRuns(start=statistics.get_start_time())
) )
if initial_version is not None:
session.add(
recorder.db_schema.SchemaChanges(schema_version=initial_version)
)
session.add( session.add(
recorder.db_schema.SchemaChanges( recorder.db_schema.SchemaChanges(
schema_version=old_db_schema.SCHEMA_VERSION schema_version=old_db_schema.SCHEMA_VERSION
@ -70,7 +76,10 @@ async def test_migrate_times(
async_test_recorder: RecorderInstanceGenerator, async_test_recorder: RecorderInstanceGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test we can migrate times in the events and states tables.""" """Test we can migrate times in the events and states tables.
Also tests entity id post migration.
"""
importlib.import_module(SCHEMA_MODULE_30) importlib.import_module(SCHEMA_MODULE_30)
old_db_schema = sys.modules[SCHEMA_MODULE_30] old_db_schema = sys.modules[SCHEMA_MODULE_30]
now = dt_util.utcnow() now = dt_util.utcnow()
@ -122,7 +131,13 @@ async def test_migrate_times(
patch.object(core, "EventData", old_db_schema.EventData), patch.object(core, "EventData", old_db_schema.EventData),
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test(SCHEMA_MODULE_30)), patch(
CREATE_ENGINE_TARGET,
new=_create_engine_test(
SCHEMA_MODULE_30,
initial_version=27, # Set to 27 for the entity id post migration to run
),
),
): ):
async with ( async with (
async_test_home_assistant() as hass, async_test_home_assistant() as hass,
@ -274,7 +289,13 @@ async def test_migrate_can_resume_entity_id_post_migration(
patch.object(core, "EventData", old_db_schema.EventData), patch.object(core, "EventData", old_db_schema.EventData),
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test(SCHEMA_MODULE_32)), patch(
CREATE_ENGINE_TARGET,
new=_create_engine_test(
SCHEMA_MODULE_32,
initial_version=27, # Set to 27 for the entity id post migration to run
),
),
): ):
async with ( async with (
async_test_home_assistant() as hass, async_test_home_assistant() as hass,
@ -394,7 +415,13 @@ async def test_migrate_can_resume_ix_states_event_id_removed(
patch.object(core, "EventData", old_db_schema.EventData), patch.object(core, "EventData", old_db_schema.EventData),
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test(SCHEMA_MODULE_32)), patch(
CREATE_ENGINE_TARGET,
new=_create_engine_test(
SCHEMA_MODULE_32,
initial_version=27, # Set to 27 for the entity id post migration to run
),
),
): ):
async with ( async with (
async_test_home_assistant() as hass, async_test_home_assistant() as hass,
@ -527,7 +554,13 @@ async def test_out_of_disk_space_while_rebuild_states_table(
patch.object(core, "EventData", old_db_schema.EventData), patch.object(core, "EventData", old_db_schema.EventData),
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test(SCHEMA_MODULE_32)), patch(
CREATE_ENGINE_TARGET,
new=_create_engine_test(
SCHEMA_MODULE_32,
initial_version=27, # Set to 27 for the entity id post migration to run
),
),
): ):
async with ( async with (
async_test_home_assistant() as hass, async_test_home_assistant() as hass,
@ -705,7 +738,13 @@ async def test_out_of_disk_space_while_removing_foreign_key(
patch.object(core, "EventData", old_db_schema.EventData), patch.object(core, "EventData", old_db_schema.EventData),
patch.object(core, "States", old_db_schema.States), patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events), patch.object(core, "Events", old_db_schema.Events),
patch(CREATE_ENGINE_TARGET, new=_create_engine_test(SCHEMA_MODULE_32)), patch(
CREATE_ENGINE_TARGET,
new=_create_engine_test(
SCHEMA_MODULE_32,
initial_version=27, # Set to 27 for the entity id post migration to run
),
),
): ):
async with ( async with (
async_test_home_assistant() as hass, async_test_home_assistant() as hass,