Refactor recorder migration (#80175)

* Refactor recorder migration

* Improve test coverage
This commit is contained in:
Erik Montnemery 2022-10-13 08:11:54 +02:00 committed by GitHub
parent ca4c4774ca
commit 466c4656ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 39 deletions

View File

@ -588,24 +588,31 @@ class Recorder(threading.Thread):
def run(self) -> None:
"""Start processing events to save."""
current_version = self._setup_recorder()
setup_result = self._setup_recorder()
if current_version is None:
if not setup_result:
# Give up if we could not connect
self.hass.add_job(self.async_connection_failed)
return
self.schema_version = current_version
schema_status = migration.validate_db_schema(self.hass, self.get_session)
if schema_status is None:
# Give up if we could not validate the schema
self.hass.add_job(self.async_connection_failed)
return
self.schema_version = schema_status.current_version
schema_is_current = migration.schema_is_current(current_version)
if schema_is_current:
schema_is_valid = migration.schema_is_valid(schema_status)
if schema_is_valid:
self._setup_run()
else:
self.migration_in_progress = True
self.migration_is_live = migration.live_migration(current_version)
self.migration_is_live = migration.live_migration(schema_status)
self.hass.add_job(self.async_connection_success)
if self.migration_is_live or schema_is_current:
if self.migration_is_live or schema_is_valid:
# If the migrate is live or the schema is current, we need to
# wait for startup to complete. If its not live, we need to continue
# on.
@ -623,8 +630,8 @@ class Recorder(threading.Thread):
self.hass.add_job(self.async_set_db_ready)
return
if not schema_is_current:
if self._migrate_schema_and_setup_run(current_version):
if not schema_is_valid:
if self._migrate_schema_and_setup_run(schema_status):
self.schema_version = SCHEMA_VERSION
if not self._event_listener:
# If the schema migration takes so long that the end
@ -689,14 +696,14 @@ class Recorder(threading.Thread):
# happens to rollback and recover
self._reopen_event_session()
def _setup_recorder(self) -> None | int:
"""Create connect to the database and get the schema version."""
def _setup_recorder(self) -> bool:
"""Create a connection to the database."""
tries = 1
while tries <= self.db_max_retries:
try:
self._setup_connection()
return migration.get_schema_version(self.get_session)
return True
except UnsupportedDialect:
break
except Exception as err: # pylint: disable=broad-except
@ -708,14 +715,16 @@ class Recorder(threading.Thread):
tries += 1
time.sleep(self.db_retry_wait)
return None
return False
@callback
def _async_migration_started(self) -> None:
"""Set the migration started event."""
self.async_migration_event.set()
def _migrate_schema_and_setup_run(self, current_version: int) -> bool:
def _migrate_schema_and_setup_run(
self, schema_status: migration.SchemaValidationStatus
) -> bool:
"""Migrate schema to the latest version."""
persistent_notification.create(
self.hass,
@ -727,7 +736,7 @@ class Recorder(threading.Thread):
try:
migration.migrate_schema(
self, self.hass, self.engine, self.get_session, current_version
self, self.hass, self.engine, self.get_session, schema_status
)
except exc.DatabaseError as err:
if self._handle_database_error(err):

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
from dataclasses import dataclass
from datetime import timedelta
import logging
from typing import TYPE_CHECKING, cast
@ -61,8 +62,9 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
raise ex
def get_schema_version(session_maker: Callable[[], Session]) -> int:
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
"""Get the schema version."""
try:
with session_scope(session=session_maker()) as session:
res = (
session.query(SchemaChanges)
@ -78,16 +80,47 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int:
)
return cast(int, current_version)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error when determining DB schema version: %s", err)
return None
def schema_is_current(current_version: int) -> bool:
@dataclass
class SchemaValidationStatus:
"""Store schema validation status."""
current_version: int
def _schema_is_current(current_version: int) -> bool:
"""Check if the schema is current."""
return current_version == SCHEMA_VERSION
def live_migration(current_version: int) -> bool:
def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
"""Check if the schema is valid."""
return _schema_is_current(schema_status.current_version)
def validate_db_schema(
hass: HomeAssistant, session_maker: Callable[[], Session]
) -> SchemaValidationStatus | None:
"""Check if the schema is valid.
This checks that the schema is the current version as well as for some common schema
errors caused by manual migration between database engines, for example importing an
SQLite database to MariaDB.
"""
current_version = get_schema_version(session_maker)
if current_version is None:
return None
return SchemaValidationStatus(current_version)
def live_migration(schema_status: SchemaValidationStatus) -> bool:
"""Check if live migration is possible."""
return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
def migrate_schema(
@ -95,13 +128,14 @@ def migrate_schema(
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
current_version: int,
schema_status: SchemaValidationStatus,
) -> None:
"""Check if the schema needs to be upgraded."""
current_version = schema_status.current_version
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
db_ready = False
for version in range(current_version, SCHEMA_VERSION):
if live_migration(version) and not db_ready:
if live_migration(SchemaValidationStatus(version)) and not db_ready:
db_ready = True
instance.migration_is_live = True
hass.add_job(instance.async_set_db_ready)

View File

@ -665,6 +665,23 @@ def test_recorder_setup_failure(hass):
hass.stop()
def test_recorder_validate_schema_failure(hass):
"""Test some exceptions."""
recorder_helper.async_initialize_recorder(hass)
with patch(
"homeassistant.components.recorder.migration._inspect_schema_version"
) as inspect_schema_version, patch(
"homeassistant.components.recorder.core.time.sleep"
):
inspect_schema_version.side_effect = ImportError("driver not found")
rec = _default_recorder(hass)
rec.async_initialize()
rec.start()
rec.join()
hass.stop()
def test_recorder_setup_failure_without_event_listener(hass):
"""Test recorder setup failure when the event listener is not setup."""
recorder_helper.async_initialize_recorder(hass)

View File

@ -134,14 +134,16 @@ async def test_database_migration_encounters_corruption(hass):
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.migration.schema_is_current",
side_effect=[False, True],
"homeassistant.components.recorder.migration._schema_is_current",
side_effect=[False],
), patch(
"homeassistant.components.recorder.migration.migrate_schema",
side_effect=sqlite3_exception,
), patch(
"homeassistant.components.recorder.core.move_away_broken_database"
) as move_away:
) as move_away, patch(
"homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics",
):
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
@ -159,8 +161,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass):
assert recorder.util.async_migration_in_progress(hass) is False
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.migration.schema_is_current",
side_effect=[False, True],
"homeassistant.components.recorder.migration._schema_is_current",
side_effect=[False],
), patch(
"homeassistant.components.recorder.migration.migrate_schema",
side_effect=DatabaseError("statement", {}, []),