mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Refactor recorder migration (#80175)
* Refactor recorder migration * Improve test coverage
This commit is contained in:
parent
ca4c4774ca
commit
466c4656ca
@ -588,24 +588,31 @@ class Recorder(threading.Thread):
|
||||
|
||||
def run(self) -> None:
|
||||
"""Start processing events to save."""
|
||||
current_version = self._setup_recorder()
|
||||
setup_result = self._setup_recorder()
|
||||
|
||||
if current_version is None:
|
||||
if not setup_result:
|
||||
# Give up if we could not connect
|
||||
self.hass.add_job(self.async_connection_failed)
|
||||
return
|
||||
|
||||
self.schema_version = current_version
|
||||
schema_status = migration.validate_db_schema(self.hass, self.get_session)
|
||||
if schema_status is None:
|
||||
# Give up if we could not validate the schema
|
||||
self.hass.add_job(self.async_connection_failed)
|
||||
return
|
||||
self.schema_version = schema_status.current_version
|
||||
|
||||
schema_is_current = migration.schema_is_current(current_version)
|
||||
if schema_is_current:
|
||||
schema_is_valid = migration.schema_is_valid(schema_status)
|
||||
|
||||
if schema_is_valid:
|
||||
self._setup_run()
|
||||
else:
|
||||
self.migration_in_progress = True
|
||||
self.migration_is_live = migration.live_migration(current_version)
|
||||
self.migration_is_live = migration.live_migration(schema_status)
|
||||
|
||||
self.hass.add_job(self.async_connection_success)
|
||||
|
||||
if self.migration_is_live or schema_is_current:
|
||||
if self.migration_is_live or schema_is_valid:
|
||||
# If the migrate is live or the schema is current, we need to
|
||||
# wait for startup to complete. If its not live, we need to continue
|
||||
# on.
|
||||
@ -623,8 +630,8 @@ class Recorder(threading.Thread):
|
||||
self.hass.add_job(self.async_set_db_ready)
|
||||
return
|
||||
|
||||
if not schema_is_current:
|
||||
if self._migrate_schema_and_setup_run(current_version):
|
||||
if not schema_is_valid:
|
||||
if self._migrate_schema_and_setup_run(schema_status):
|
||||
self.schema_version = SCHEMA_VERSION
|
||||
if not self._event_listener:
|
||||
# If the schema migration takes so long that the end
|
||||
@ -689,14 +696,14 @@ class Recorder(threading.Thread):
|
||||
# happens to rollback and recover
|
||||
self._reopen_event_session()
|
||||
|
||||
def _setup_recorder(self) -> None | int:
|
||||
"""Create connect to the database and get the schema version."""
|
||||
def _setup_recorder(self) -> bool:
|
||||
"""Create a connection to the database."""
|
||||
tries = 1
|
||||
|
||||
while tries <= self.db_max_retries:
|
||||
try:
|
||||
self._setup_connection()
|
||||
return migration.get_schema_version(self.get_session)
|
||||
return True
|
||||
except UnsupportedDialect:
|
||||
break
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
@ -708,14 +715,16 @@ class Recorder(threading.Thread):
|
||||
tries += 1
|
||||
time.sleep(self.db_retry_wait)
|
||||
|
||||
return None
|
||||
return False
|
||||
|
||||
@callback
|
||||
def _async_migration_started(self) -> None:
|
||||
"""Set the migration started event."""
|
||||
self.async_migration_event.set()
|
||||
|
||||
def _migrate_schema_and_setup_run(self, current_version: int) -> bool:
|
||||
def _migrate_schema_and_setup_run(
|
||||
self, schema_status: migration.SchemaValidationStatus
|
||||
) -> bool:
|
||||
"""Migrate schema to the latest version."""
|
||||
persistent_notification.create(
|
||||
self.hass,
|
||||
@ -727,7 +736,7 @@ class Recorder(threading.Thread):
|
||||
|
||||
try:
|
||||
migration.migrate_schema(
|
||||
self, self.hass, self.engine, self.get_session, current_version
|
||||
self, self.hass, self.engine, self.get_session, schema_status
|
||||
)
|
||||
except exc.DatabaseError as err:
|
||||
if self._handle_database_error(err):
|
||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast
|
||||
@ -61,8 +62,9 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
|
||||
raise ex
|
||||
|
||||
|
||||
def get_schema_version(session_maker: Callable[[], Session]) -> int:
|
||||
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
|
||||
"""Get the schema version."""
|
||||
try:
|
||||
with session_scope(session=session_maker()) as session:
|
||||
res = (
|
||||
session.query(SchemaChanges)
|
||||
@ -78,16 +80,47 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int:
|
||||
)
|
||||
|
||||
return cast(int, current_version)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error when determining DB schema version: %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def schema_is_current(current_version: int) -> bool:
|
||||
@dataclass
|
||||
class SchemaValidationStatus:
|
||||
"""Store schema validation status."""
|
||||
|
||||
current_version: int
|
||||
|
||||
|
||||
def _schema_is_current(current_version: int) -> bool:
|
||||
"""Check if the schema is current."""
|
||||
return current_version == SCHEMA_VERSION
|
||||
|
||||
|
||||
def live_migration(current_version: int) -> bool:
|
||||
def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
|
||||
"""Check if the schema is valid."""
|
||||
return _schema_is_current(schema_status.current_version)
|
||||
|
||||
|
||||
def validate_db_schema(
|
||||
hass: HomeAssistant, session_maker: Callable[[], Session]
|
||||
) -> SchemaValidationStatus | None:
|
||||
"""Check if the schema is valid.
|
||||
|
||||
This checks that the schema is the current version as well as for some common schema
|
||||
errors caused by manual migration between database engines, for example importing an
|
||||
SQLite database to MariaDB.
|
||||
"""
|
||||
current_version = get_schema_version(session_maker)
|
||||
if current_version is None:
|
||||
return None
|
||||
|
||||
return SchemaValidationStatus(current_version)
|
||||
|
||||
|
||||
def live_migration(schema_status: SchemaValidationStatus) -> bool:
|
||||
"""Check if live migration is possible."""
|
||||
return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
|
||||
return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
|
||||
|
||||
|
||||
def migrate_schema(
|
||||
@ -95,13 +128,14 @@ def migrate_schema(
|
||||
hass: HomeAssistant,
|
||||
engine: Engine,
|
||||
session_maker: Callable[[], Session],
|
||||
current_version: int,
|
||||
schema_status: SchemaValidationStatus,
|
||||
) -> None:
|
||||
"""Check if the schema needs to be upgraded."""
|
||||
current_version = schema_status.current_version
|
||||
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
|
||||
db_ready = False
|
||||
for version in range(current_version, SCHEMA_VERSION):
|
||||
if live_migration(version) and not db_ready:
|
||||
if live_migration(SchemaValidationStatus(version)) and not db_ready:
|
||||
db_ready = True
|
||||
instance.migration_is_live = True
|
||||
hass.add_job(instance.async_set_db_ready)
|
||||
|
@ -665,6 +665,23 @@ def test_recorder_setup_failure(hass):
|
||||
hass.stop()
|
||||
|
||||
|
||||
def test_recorder_validate_schema_failure(hass):
|
||||
"""Test some exceptions."""
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
with patch(
|
||||
"homeassistant.components.recorder.migration._inspect_schema_version"
|
||||
) as inspect_schema_version, patch(
|
||||
"homeassistant.components.recorder.core.time.sleep"
|
||||
):
|
||||
inspect_schema_version.side_effect = ImportError("driver not found")
|
||||
rec = _default_recorder(hass)
|
||||
rec.async_initialize()
|
||||
rec.start()
|
||||
rec.join()
|
||||
|
||||
hass.stop()
|
||||
|
||||
|
||||
def test_recorder_setup_failure_without_event_listener(hass):
|
||||
"""Test recorder setup failure when the event listener is not setup."""
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
|
@ -134,14 +134,16 @@ async def test_database_migration_encounters_corruption(hass):
|
||||
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
|
||||
|
||||
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
|
||||
"homeassistant.components.recorder.migration.schema_is_current",
|
||||
side_effect=[False, True],
|
||||
"homeassistant.components.recorder.migration._schema_is_current",
|
||||
side_effect=[False],
|
||||
), patch(
|
||||
"homeassistant.components.recorder.migration.migrate_schema",
|
||||
side_effect=sqlite3_exception,
|
||||
), patch(
|
||||
"homeassistant.components.recorder.core.move_away_broken_database"
|
||||
) as move_away:
|
||||
) as move_away, patch(
|
||||
"homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics",
|
||||
):
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
await async_setup_component(
|
||||
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
|
||||
@ -159,8 +161,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass):
|
||||
assert recorder.util.async_migration_in_progress(hass) is False
|
||||
|
||||
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
|
||||
"homeassistant.components.recorder.migration.schema_is_current",
|
||||
side_effect=[False, True],
|
||||
"homeassistant.components.recorder.migration._schema_is_current",
|
||||
side_effect=[False],
|
||||
), patch(
|
||||
"homeassistant.components.recorder.migration.migrate_schema",
|
||||
side_effect=DatabaseError("statement", {}, []),
|
||||
|
Loading…
x
Reference in New Issue
Block a user