Revert "Refactor recorder migration"

This reverts commit 69e10e59821f7e5ca1d4d305079f059774b67864.
This commit is contained in:
Erik 2022-10-12 15:12:12 +02:00
parent 3a5b66fd60
commit 4a1c40f09b
3 changed files with 39 additions and 84 deletions

View File

@ -588,31 +588,24 @@ class Recorder(threading.Thread):
def run(self) -> None:
"""Start processing events to save."""
setup_result = self._setup_recorder()
current_version = self._setup_recorder()
if not setup_result:
# Give up if we could not connect
if current_version is None:
self.hass.add_job(self.async_connection_failed)
return
schema_status = migration.validate_db_schema(self.hass, self.get_session)
if schema_status is None:
# Give up if we could not validate the schema
self.hass.add_job(self.async_connection_failed)
return
self.schema_version = schema_status.current_version
self.schema_version = current_version
schema_is_valid = migration.schema_is_valid(schema_status)
if schema_is_valid:
schema_is_current = migration.schema_is_current(current_version)
if schema_is_current:
self._setup_run()
else:
self.migration_in_progress = True
self.migration_is_live = migration.live_migration(schema_status)
self.migration_is_live = migration.live_migration(current_version)
self.hass.add_job(self.async_connection_success)
if self.migration_is_live or schema_is_valid:
if self.migration_is_live or schema_is_current:
# If the migrate is live or the schema is current, we need to
# wait for startup to complete. If its not live, we need to continue
# on.
@ -630,8 +623,8 @@ class Recorder(threading.Thread):
self.hass.add_job(self.async_set_db_ready)
return
if not schema_is_valid:
if self._migrate_schema_and_setup_run(schema_status):
if not schema_is_current:
if self._migrate_schema_and_setup_run(current_version):
self.schema_version = SCHEMA_VERSION
if not self._event_listener:
# If the schema migration takes so long that the end
@ -696,14 +689,14 @@ class Recorder(threading.Thread):
# happens to rollback and recover
self._reopen_event_session()
def _setup_recorder(self) -> bool:
"""Create a connection to the database."""
def _setup_recorder(self) -> None | int:
"""Create connect to the database and get the schema version."""
tries = 1
while tries <= self.db_max_retries:
try:
self._setup_connection()
return True
return migration.get_schema_version(self.get_session)
except UnsupportedDialect:
break
except Exception as err: # pylint: disable=broad-except
@ -715,16 +708,14 @@ class Recorder(threading.Thread):
tries += 1
time.sleep(self.db_retry_wait)
return False
return None
@callback
def _async_migration_started(self) -> None:
"""Set the migration started event."""
self.async_migration_event.set()
def _migrate_schema_and_setup_run(
self, schema_status: migration.SchemaValidationStatus
) -> bool:
def _migrate_schema_and_setup_run(self, current_version: int) -> bool:
"""Migrate schema to the latest version."""
persistent_notification.create(
self.hass,
@ -736,7 +727,7 @@ class Recorder(threading.Thread):
try:
migration.migrate_schema(
self, self.hass, self.engine, self.get_session, schema_status
self, self.hass, self.engine, self.get_session, current_version
)
except exc.DatabaseError as err:
if self._handle_database_error(err):

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
from dataclasses import dataclass
from datetime import timedelta
import logging
from typing import TYPE_CHECKING, cast
@ -62,65 +61,33 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
raise ex
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
def get_schema_version(session_maker: Callable[[], Session]) -> int:
"""Get the schema version."""
try:
with session_scope(session=session_maker()) as session:
res = (
session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc())
.first()
with session_scope(session=session_maker()) as session:
res = (
session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc())
.first()
)
current_version = getattr(res, "schema_version", None)
if current_version is None:
current_version = _inspect_schema_version(session)
_LOGGER.debug(
"No schema version found. Inspected version: %s", current_version
)
current_version = getattr(res, "schema_version", None)
if current_version is None:
current_version = _inspect_schema_version(session)
_LOGGER.debug(
"No schema version found. Inspected version: %s", current_version
)
return cast(int, current_version)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error when determining DB schema version: %s", err)
return None
return cast(int, current_version)
@dataclass
class SchemaValidationStatus:
"""Store schema validation status."""
current_version: int
def _schema_is_current(current_version: int) -> bool:
def schema_is_current(current_version: int) -> bool:
"""Check if the schema is current."""
return current_version == SCHEMA_VERSION
def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
"""Check if the schema is valid."""
return _schema_is_current(schema_status.current_version)
def validate_db_schema(
hass: HomeAssistant, session_maker: Callable[[], Session]
) -> SchemaValidationStatus | None:
"""Check if the schema is valid.
This checks that the schema is the current version as well as for some common schema
errors caused by manual migration between database engines, for example importing an
SQLite database to MariaDB.
"""
current_version = get_schema_version(session_maker)
if current_version is None:
return None
return SchemaValidationStatus(current_version)
def live_migration(schema_status: SchemaValidationStatus) -> bool:
def live_migration(current_version: int) -> bool:
"""Check if live migration is possible."""
return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
def migrate_schema(
@ -128,14 +95,13 @@ def migrate_schema(
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
schema_status: SchemaValidationStatus,
current_version: int,
) -> None:
"""Check if the schema needs to be upgraded."""
current_version = schema_status.current_version
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
db_ready = False
for version in range(current_version, SCHEMA_VERSION):
if live_migration(SchemaValidationStatus(version)) and not db_ready:
if live_migration(version) and not db_ready:
db_ready = True
instance.migration_is_live = True
hass.add_job(instance.async_set_db_ready)

View File

@ -134,16 +134,14 @@ async def test_database_migration_encounters_corruption(hass):
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.migration._schema_is_current",
side_effect=[False],
"homeassistant.components.recorder.migration.schema_is_current",
side_effect=[False, True],
), patch(
"homeassistant.components.recorder.migration.migrate_schema",
side_effect=sqlite3_exception,
), patch(
"homeassistant.components.recorder.core.move_away_broken_database"
) as move_away, patch(
"homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics",
):
) as move_away:
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
@ -161,8 +159,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass):
assert recorder.util.async_migration_in_progress(hass) is False
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.migration._schema_is_current",
side_effect=[False],
"homeassistant.components.recorder.migration.schema_is_current",
side_effect=[False, True],
), patch(
"homeassistant.components.recorder.migration.migrate_schema",
side_effect=DatabaseError("statement", {}, []),