From 04cc2ae264c6ed43482e24ccf3cae24d191122f9 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 13 Oct 2022 13:01:27 +0200 Subject: [PATCH] Correct initialization of new databases (#80234) --- homeassistant/components/recorder/core.py | 2 +- .../components/recorder/migration.py | 46 +++++++++++-------- tests/components/recorder/test_init.py | 2 +- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index f7d2b774aeb..d5e095d8104 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -703,7 +703,7 @@ class Recorder(threading.Thread): while tries <= self.db_max_retries: try: self._setup_connection() - return True + return migration.initialize_database(self.get_session) except UnsupportedDialect: break except Exception as err: # pylint: disable=broad-except diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 227500aaf0f..22a3b382c7d 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -6,7 +6,7 @@ import contextlib from dataclasses import dataclass from datetime import timedelta import logging -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import sqlalchemy from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text @@ -62,24 +62,17 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) raise ex +def _get_schema_version(session: Session) -> int | None: + """Get the schema version.""" + res = session.query(SchemaChanges).order_by(SchemaChanges.change_id.desc()).first() + return getattr(res, "schema_version", None) + + def get_schema_version(session_maker: Callable[[], Session]) -> int | None: """Get the schema version.""" try: with session_scope(session=session_maker()) as session: - res = ( - session.query(SchemaChanges) - .order_by(SchemaChanges.change_id.desc()) - .first() - ) - current_version = getattr(res, "schema_version", None) - - if current_version is None: - current_version = _inspect_schema_version(session) - _LOGGER.debug( - "No schema version found. Inspected version: %s", current_version - ) - - return cast(int, current_version) + return _get_schema_version(session) except Exception as err: # pylint: disable=broad-except _LOGGER.exception("Error when determining DB schema version: %s", err) return None @@ -797,8 +790,10 @@ def _apply_update( # noqa: C901 raise ValueError(f"No schema migration defined for version {new_version}") -def _inspect_schema_version(session: Session) -> int: - """Determine the schema version by inspecting the db structure. +def _initialize_database(session: Session) -> bool: + """Initialize a new database, or a database created before introducing schema changes. + + The function determines the schema version by inspecting the db structure. When the schema version is not present in the db, either db was just created with the correct schema, or this is a db created before schema @@ -814,9 +809,22 @@ def _inspect_schema_version(session: Session) -> int: # Schema addition from version 1 detected. New DB. session.add(StatisticsRuns(start=get_start_time())) session.add(SchemaChanges(schema_version=SCHEMA_VERSION)) - return SCHEMA_VERSION + return True # Version 1 schema changes not found, this db needs to be migrated. current_version = SchemaChanges(schema_version=0) session.add(current_version) - return cast(int, current_version.schema_version) + return True + + +def initialize_database(session_maker: Callable[[], Session]) -> bool: + """Initialize a new database, or a database created before introducing schema changes.""" + try: + with session_scope(session=session_maker()) as session: + if _get_schema_version(session) is not None: + return True + return _initialize_database(session) + + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error when initialise database: %s", err) + return False diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 54e82516373..9939fc7fb46 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -669,7 +669,7 @@ def test_recorder_validate_schema_failure(hass): """Test some exceptions.""" recorder_helper.async_initialize_recorder(hass) with patch( - "homeassistant.components.recorder.migration._inspect_schema_version" + "homeassistant.components.recorder.migration._get_schema_version" ) as inspect_schema_version, patch( "homeassistant.components.recorder.core.time.sleep" ):