Correct condition signalling non-live DB migration is in progress (#129464)

This commit is contained in:
Erik Montnemery 2024-10-29 23:26:52 +01:00 committed by GitHub
parent 963829712d
commit db5cb6233c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 82 additions and 48 deletions

View File

@ -964,6 +964,7 @@ class Recorder(threading.Thread):
new_schema_status = migration.SchemaValidationStatus( new_schema_status = migration.SchemaValidationStatus(
current_version=SCHEMA_VERSION, current_version=SCHEMA_VERSION,
migration_needed=False, migration_needed=False,
non_live_data_migration_needed=False,
schema_errors=set(), schema_errors=set(),
start_version=SCHEMA_VERSION, start_version=SCHEMA_VERSION,
) )

View File

@ -200,12 +200,13 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
return None return None
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class SchemaValidationStatus: class SchemaValidationStatus:
"""Store schema validation status.""" """Store schema validation status."""
current_version: int current_version: int
migration_needed: bool migration_needed: bool
non_live_data_migration_needed: bool
schema_errors: set[str] schema_errors: set[str]
start_version: int start_version: int
@ -235,12 +236,17 @@ def validate_db_schema(
# columns may otherwise not exist etc. # columns may otherwise not exist etc.
schema_errors = _find_schema_errors(hass, instance, session_maker) schema_errors = _find_schema_errors(hass, instance, session_maker)
migration_needed = not is_current or non_live_data_migration_needed( schema_migration_needed = not is_current
_non_live_data_migration_needed = non_live_data_migration_needed(
instance, session_maker, current_version instance, session_maker, current_version
) )
return SchemaValidationStatus( return SchemaValidationStatus(
current_version, migration_needed, schema_errors, current_version current_version=current_version,
non_live_data_migration_needed=_non_live_data_migration_needed,
migration_needed=schema_migration_needed or _non_live_data_migration_needed,
schema_errors=schema_errors,
start_version=current_version,
) )
@ -257,7 +263,10 @@ def _find_schema_errors(
def live_migration(schema_status: SchemaValidationStatus) -> bool: def live_migration(schema_status: SchemaValidationStatus) -> bool:
"""Check if live migration is possible.""" """Check if live migration is possible."""
return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION return (
schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
and not schema_status.non_live_data_migration_needed
)
def pre_migrate_schema(engine: Engine) -> None: def pre_migrate_schema(engine: Engine) -> None:

View File

@ -95,7 +95,13 @@ async def test_schema_update_calls(
hass, hass,
engine, engine,
session_maker, session_maker,
migration.SchemaValidationStatus(0, True, set(), 0), migration.SchemaValidationStatus(
current_version=0,
migration_needed=True,
non_live_data_migration_needed=True,
schema_errors=set(),
start_version=0,
),
42, 42,
), ),
call( call(
@ -103,7 +109,13 @@ async def test_schema_update_calls(
hass, hass,
engine, engine,
session_maker, session_maker,
migration.SchemaValidationStatus(42, True, set(), 0), migration.SchemaValidationStatus(
current_version=42,
migration_needed=True,
non_live_data_migration_needed=True,
schema_errors=set(),
start_version=0,
),
db_schema.SCHEMA_VERSION, db_schema.SCHEMA_VERSION,
), ),
] ]

View File

@ -49,6 +49,7 @@ from .common import (
async_recorder_block_till_done, async_recorder_block_till_done,
async_wait_recording_done, async_wait_recording_done,
) )
from .conftest import instrument_migration
from tests.common import async_test_home_assistant from tests.common import async_test_home_assistant
from tests.typing import RecorderInstanceGenerator from tests.typing import RecorderInstanceGenerator
@ -266,11 +267,14 @@ async def test_migrate_events_context_ids(
return {event.event_type: _object_as_dict(event) for event in events} return {event.event_type: _object_as_dict(event) for event in events}
# Run again with new schema, let migration run # Run again with new schema, let migration run
with freeze_time(now): async with async_test_home_assistant() as hass:
async with ( with freeze_time(now), instrument_migration(hass) as instrumented_migration:
async_test_home_assistant() as hass, async with async_test_recorder(
async_test_recorder(hass) as instance, hass, wait_recorder=False, wait_recorder_setup=False
): ) as instance:
# Check the context ID migrator is considered non-live
assert recorder.util.async_migration_is_live(hass) is False
instrumented_migration.migration_stall.set()
instance.recorder_and_worker_thread_ids.add(threading.get_ident()) instance.recorder_and_worker_thread_ids.add(threading.get_ident())
await hass.async_block_till_done() await hass.async_block_till_done()
@ -288,7 +292,8 @@ async def test_migrate_events_context_ids(
# Check the index which will be removed by the migrator no longer exists # Check the index which will be removed by the migrator no longer exists
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
assert ( assert (
get_index_by_name(session, "events", "ix_events_context_id") is None get_index_by_name(session, "events", "ix_events_context_id")
is None
) )
await hass.async_stop() await hass.async_stop()
@ -602,10 +607,14 @@ async def test_migrate_states_context_ids(
return {state.entity_id: _object_as_dict(state) for state in events} return {state.entity_id: _object_as_dict(state) for state in events}
# Run again with new schema, let migration run # Run again with new schema, let migration run
async with ( async with async_test_home_assistant() as hass:
async_test_home_assistant() as hass, with instrument_migration(hass) as instrumented_migration:
async_test_recorder(hass) as instance, async with async_test_recorder(
): hass, wait_recorder=False, wait_recorder_setup=False
) as instance:
# Check the context ID migrator is considered non-live
assert recorder.util.async_migration_is_live(hass) is False
instrumented_migration.migration_stall.set()
instance.recorder_and_worker_thread_ids.add(threading.get_ident()) instance.recorder_and_worker_thread_ids.add(threading.get_ident())
await hass.async_block_till_done() await hass.async_block_till_done()
@ -622,7 +631,10 @@ async def test_migrate_states_context_ids(
# Check the index which will be removed by the migrator no longer exists # Check the index which will be removed by the migrator no longer exists
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
assert get_index_by_name(session, "states", "ix_states_context_id") is None assert (
get_index_by_name(session, "states", "ix_states_context_id")
is None
)
await hass.async_stop() await hass.async_stop()
await hass.async_block_till_done() await hass.async_block_till_done()