diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 2530b303e15..f7d2b774aeb 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -588,24 +588,31 @@ class Recorder(threading.Thread): def run(self) -> None: """Start processing events to save.""" - current_version = self._setup_recorder() + setup_result = self._setup_recorder() - if current_version is None: + if not setup_result: + # Give up if we could not connect self.hass.add_job(self.async_connection_failed) return - self.schema_version = current_version + schema_status = migration.validate_db_schema(self.hass, self.get_session) + if schema_status is None: + # Give up if we could not validate the schema + self.hass.add_job(self.async_connection_failed) + return + self.schema_version = schema_status.current_version - schema_is_current = migration.schema_is_current(current_version) - if schema_is_current: + schema_is_valid = migration.schema_is_valid(schema_status) + + if schema_is_valid: self._setup_run() else: self.migration_in_progress = True - self.migration_is_live = migration.live_migration(current_version) + self.migration_is_live = migration.live_migration(schema_status) self.hass.add_job(self.async_connection_success) - if self.migration_is_live or schema_is_current: + if self.migration_is_live or schema_is_valid: # If the migrate is live or the schema is current, we need to # wait for startup to complete. If its not live, we need to continue # on. @@ -623,8 +630,8 @@ class Recorder(threading.Thread): self.hass.add_job(self.async_set_db_ready) return - if not schema_is_current: - if self._migrate_schema_and_setup_run(current_version): + if not schema_is_valid: + if self._migrate_schema_and_setup_run(schema_status): self.schema_version = SCHEMA_VERSION if not self._event_listener: # If the schema migration takes so long that the end @@ -689,14 +696,14 @@ class Recorder(threading.Thread): # happens to rollback and recover self._reopen_event_session() - def _setup_recorder(self) -> None | int: - """Create connect to the database and get the schema version.""" + def _setup_recorder(self) -> bool: + """Create a connection to the database.""" tries = 1 while tries <= self.db_max_retries: try: self._setup_connection() - return migration.get_schema_version(self.get_session) + return True except UnsupportedDialect: break except Exception as err: # pylint: disable=broad-except @@ -708,14 +715,16 @@ class Recorder(threading.Thread): tries += 1 time.sleep(self.db_retry_wait) - return None + return False @callback def _async_migration_started(self) -> None: """Set the migration started event.""" self.async_migration_event.set() - def _migrate_schema_and_setup_run(self, current_version: int) -> bool: + def _migrate_schema_and_setup_run( + self, schema_status: migration.SchemaValidationStatus + ) -> bool: """Migrate schema to the latest version.""" persistent_notification.create( self.hass, @@ -727,7 +736,7 @@ class Recorder(threading.Thread): try: migration.migrate_schema( - self, self.hass, self.engine, self.get_session, current_version + self, self.hass, self.engine, self.get_session, schema_status ) except exc.DatabaseError as err: if self._handle_database_error(err): diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 3482f9aa942..227500aaf0f 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable import contextlib +from dataclasses import dataclass from datetime import timedelta import logging from typing import TYPE_CHECKING, cast @@ -61,33 +62,65 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) raise ex -def get_schema_version(session_maker: Callable[[], Session]) -> int: +def get_schema_version(session_maker: Callable[[], Session]) -> int | None: """Get the schema version.""" - with session_scope(session=session_maker()) as session: - res = ( - session.query(SchemaChanges) - .order_by(SchemaChanges.change_id.desc()) - .first() - ) - current_version = getattr(res, "schema_version", None) - - if current_version is None: - current_version = _inspect_schema_version(session) - _LOGGER.debug( - "No schema version found. Inspected version: %s", current_version + try: + with session_scope(session=session_maker()) as session: + res = ( + session.query(SchemaChanges) + .order_by(SchemaChanges.change_id.desc()) + .first() ) + current_version = getattr(res, "schema_version", None) - return cast(int, current_version) + if current_version is None: + current_version = _inspect_schema_version(session) + _LOGGER.debug( + "No schema version found. Inspected version: %s", current_version + ) + + return cast(int, current_version) + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error when determining DB schema version: %s", err) + return None -def schema_is_current(current_version: int) -> bool: +@dataclass +class SchemaValidationStatus: + """Store schema validation status.""" + + current_version: int + + +def _schema_is_current(current_version: int) -> bool: """Check if the schema is current.""" return current_version == SCHEMA_VERSION -def live_migration(current_version: int) -> bool: +def schema_is_valid(schema_status: SchemaValidationStatus) -> bool: + """Check if the schema is valid.""" + return _schema_is_current(schema_status.current_version) + + +def validate_db_schema( + hass: HomeAssistant, session_maker: Callable[[], Session] +) -> SchemaValidationStatus | None: + """Check if the schema is valid. + + This checks that the schema is the current version as well as for some common schema + errors caused by manual migration between database engines, for example importing an + SQLite database to MariaDB. + """ + current_version = get_schema_version(session_maker) + if current_version is None: + return None + + return SchemaValidationStatus(current_version) + + +def live_migration(schema_status: SchemaValidationStatus) -> bool: """Check if live migration is possible.""" - return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION + return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION def migrate_schema( @@ -95,13 +128,14 @@ def migrate_schema( hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session], - current_version: int, + schema_status: SchemaValidationStatus, ) -> None: """Check if the schema needs to be upgraded.""" + current_version = schema_status.current_version _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version) db_ready = False for version in range(current_version, SCHEMA_VERSION): - if live_migration(version) and not db_ready: + if live_migration(SchemaValidationStatus(version)) and not db_ready: db_ready = True instance.migration_is_live = True hass.add_job(instance.async_set_db_ready) diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 815af89198d..977e32e9a71 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -665,6 +665,23 @@ def test_recorder_setup_failure(hass): hass.stop() +def test_recorder_validate_schema_failure(hass): + """Test some exceptions.""" + recorder_helper.async_initialize_recorder(hass) + with patch( + "homeassistant.components.recorder.migration._inspect_schema_version" + ) as inspect_schema_version, patch( + "homeassistant.components.recorder.core.time.sleep" + ): + inspect_schema_version.side_effect = ImportError("driver not found") + rec = _default_recorder(hass) + rec.async_initialize() + rec.start() + rec.join() + + hass.stop() + + def test_recorder_setup_failure_without_event_listener(hass): """Test recorder setup failure when the event listener is not setup.""" recorder_helper.async_initialize_recorder(hass) diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 9e0609de5b6..45268ae819b 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -134,14 +134,16 @@ async def test_database_migration_encounters_corruption(hass): sqlite3_exception.__cause__ = sqlite3.DatabaseError() with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( - "homeassistant.components.recorder.migration.schema_is_current", - side_effect=[False, True], + "homeassistant.components.recorder.migration._schema_is_current", + side_effect=[False], ), patch( "homeassistant.components.recorder.migration.migrate_schema", side_effect=sqlite3_exception, ), patch( "homeassistant.components.recorder.core.move_away_broken_database" - ) as move_away: + ) as move_away, patch( + "homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics", + ): recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} @@ -159,8 +161,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass): assert recorder.util.async_migration_in_progress(hass) is False with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( - "homeassistant.components.recorder.migration.schema_is_current", - side_effect=[False, True], + "homeassistant.components.recorder.migration._schema_is_current", + side_effect=[False], ), patch( "homeassistant.components.recorder.migration.migrate_schema", side_effect=DatabaseError("statement", {}, []),