From 4a1c40f09ba18876f31f6aef4b2ed3806fff5bf3 Mon Sep 17 00:00:00 2001 From: Erik Date: Wed, 12 Oct 2022 15:12:12 +0200 Subject: [PATCH] Revert "Refactor recorder migration" This reverts commit 69e10e59821f7e5ca1d4d305079f059774b67864. --- homeassistant/components/recorder/core.py | 39 ++++------ .../components/recorder/migration.py | 72 +++++-------------- tests/components/recorder/test_migrate.py | 12 ++-- 3 files changed, 39 insertions(+), 84 deletions(-) diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index f7d2b774aeb..2530b303e15 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -588,31 +588,24 @@ class Recorder(threading.Thread): def run(self) -> None: """Start processing events to save.""" - setup_result = self._setup_recorder() + current_version = self._setup_recorder() - if not setup_result: - # Give up if we could not connect + if current_version is None: self.hass.add_job(self.async_connection_failed) return - schema_status = migration.validate_db_schema(self.hass, self.get_session) - if schema_status is None: - # Give up if we could not validate the schema - self.hass.add_job(self.async_connection_failed) - return - self.schema_version = schema_status.current_version + self.schema_version = current_version - schema_is_valid = migration.schema_is_valid(schema_status) - - if schema_is_valid: + schema_is_current = migration.schema_is_current(current_version) + if schema_is_current: self._setup_run() else: self.migration_in_progress = True - self.migration_is_live = migration.live_migration(schema_status) + self.migration_is_live = migration.live_migration(current_version) self.hass.add_job(self.async_connection_success) - if self.migration_is_live or schema_is_valid: + if self.migration_is_live or schema_is_current: # If the migrate is live or the schema is current, we need to # wait for startup to complete. If its not live, we need to continue # on. @@ -630,8 +623,8 @@ class Recorder(threading.Thread): self.hass.add_job(self.async_set_db_ready) return - if not schema_is_valid: - if self._migrate_schema_and_setup_run(schema_status): + if not schema_is_current: + if self._migrate_schema_and_setup_run(current_version): self.schema_version = SCHEMA_VERSION if not self._event_listener: # If the schema migration takes so long that the end @@ -696,14 +689,14 @@ class Recorder(threading.Thread): # happens to rollback and recover self._reopen_event_session() - def _setup_recorder(self) -> bool: - """Create a connection to the database.""" + def _setup_recorder(self) -> None | int: + """Create connect to the database and get the schema version.""" tries = 1 while tries <= self.db_max_retries: try: self._setup_connection() - return True + return migration.get_schema_version(self.get_session) except UnsupportedDialect: break except Exception as err: # pylint: disable=broad-except @@ -715,16 +708,14 @@ class Recorder(threading.Thread): tries += 1 time.sleep(self.db_retry_wait) - return False + return None @callback def _async_migration_started(self) -> None: """Set the migration started event.""" self.async_migration_event.set() - def _migrate_schema_and_setup_run( - self, schema_status: migration.SchemaValidationStatus - ) -> bool: + def _migrate_schema_and_setup_run(self, current_version: int) -> bool: """Migrate schema to the latest version.""" persistent_notification.create( self.hass, @@ -736,7 +727,7 @@ class Recorder(threading.Thread): try: migration.migrate_schema( - self, self.hass, self.engine, self.get_session, schema_status + self, self.hass, self.engine, self.get_session, current_version ) except exc.DatabaseError as err: if self._handle_database_error(err): diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 227500aaf0f..3482f9aa942 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -3,7 +3,6 @@ from __future__ import annotations from collections.abc import Callable, Iterable import contextlib -from dataclasses import dataclass from datetime import timedelta import logging from typing import TYPE_CHECKING, cast @@ -62,65 +61,33 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) raise ex -def get_schema_version(session_maker: Callable[[], Session]) -> int | None: +def get_schema_version(session_maker: Callable[[], Session]) -> int: """Get the schema version.""" - try: - with session_scope(session=session_maker()) as session: - res = ( - session.query(SchemaChanges) - .order_by(SchemaChanges.change_id.desc()) - .first() + with session_scope(session=session_maker()) as session: + res = ( + session.query(SchemaChanges) + .order_by(SchemaChanges.change_id.desc()) + .first() + ) + current_version = getattr(res, "schema_version", None) + + if current_version is None: + current_version = _inspect_schema_version(session) + _LOGGER.debug( + "No schema version found. Inspected version: %s", current_version ) - current_version = getattr(res, "schema_version", None) - if current_version is None: - current_version = _inspect_schema_version(session) - _LOGGER.debug( - "No schema version found. Inspected version: %s", current_version - ) - - return cast(int, current_version) - except Exception as err: # pylint: disable=broad-except - _LOGGER.exception("Error when determining DB schema version: %s", err) - return None + return cast(int, current_version) -@dataclass -class SchemaValidationStatus: - """Store schema validation status.""" - - current_version: int - - -def _schema_is_current(current_version: int) -> bool: +def schema_is_current(current_version: int) -> bool: """Check if the schema is current.""" return current_version == SCHEMA_VERSION -def schema_is_valid(schema_status: SchemaValidationStatus) -> bool: - """Check if the schema is valid.""" - return _schema_is_current(schema_status.current_version) - - -def validate_db_schema( - hass: HomeAssistant, session_maker: Callable[[], Session] -) -> SchemaValidationStatus | None: - """Check if the schema is valid. - - This checks that the schema is the current version as well as for some common schema - errors caused by manual migration between database engines, for example importing an - SQLite database to MariaDB. - """ - current_version = get_schema_version(session_maker) - if current_version is None: - return None - - return SchemaValidationStatus(current_version) - - -def live_migration(schema_status: SchemaValidationStatus) -> bool: +def live_migration(current_version: int) -> bool: """Check if live migration is possible.""" - return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION + return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION def migrate_schema( @@ -128,14 +95,13 @@ def migrate_schema( hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session], - schema_status: SchemaValidationStatus, + current_version: int, ) -> None: """Check if the schema needs to be upgraded.""" - current_version = schema_status.current_version _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version) db_ready = False for version in range(current_version, SCHEMA_VERSION): - if live_migration(SchemaValidationStatus(version)) and not db_ready: + if live_migration(version) and not db_ready: db_ready = True instance.migration_is_live = True hass.add_job(instance.async_set_db_ready) diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 45268ae819b..9e0609de5b6 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -134,16 +134,14 @@ async def test_database_migration_encounters_corruption(hass): sqlite3_exception.__cause__ = sqlite3.DatabaseError() with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( - "homeassistant.components.recorder.migration._schema_is_current", - side_effect=[False], + "homeassistant.components.recorder.migration.schema_is_current", + side_effect=[False, True], ), patch( "homeassistant.components.recorder.migration.migrate_schema", side_effect=sqlite3_exception, ), patch( "homeassistant.components.recorder.core.move_away_broken_database" - ) as move_away, patch( - "homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics", - ): + ) as move_away: recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} @@ -161,8 +159,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass): assert recorder.util.async_migration_in_progress(hass) is False with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( - "homeassistant.components.recorder.migration._schema_is_current", - side_effect=[False], + "homeassistant.components.recorder.migration.schema_is_current", + side_effect=[False, True], ), patch( "homeassistant.components.recorder.migration.migrate_schema", side_effect=DatabaseError("statement", {}, []),