Correct initialization of new databases (#80234)

This commit is contained in:
Erik Montnemery 2022-10-13 13:01:27 +02:00 committed by GitHub
parent acb1477673
commit 04cc2ae264
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 21 deletions

View File

@ -703,7 +703,7 @@ class Recorder(threading.Thread):
while tries <= self.db_max_retries: while tries <= self.db_max_retries:
try: try:
self._setup_connection() self._setup_connection()
return True return migration.initialize_database(self.get_session)
except UnsupportedDialect: except UnsupportedDialect:
break break
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except

View File

@ -6,7 +6,7 @@ import contextlib
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING
import sqlalchemy import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@ -62,24 +62,17 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
raise ex raise ex
def _get_schema_version(session: Session) -> int | None:
"""Get the schema version."""
res = session.query(SchemaChanges).order_by(SchemaChanges.change_id.desc()).first()
return getattr(res, "schema_version", None)
def get_schema_version(session_maker: Callable[[], Session]) -> int | None: def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
"""Get the schema version.""" """Get the schema version."""
try: try:
with session_scope(session=session_maker()) as session: with session_scope(session=session_maker()) as session:
res = ( return _get_schema_version(session)
session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc())
.first()
)
current_version = getattr(res, "schema_version", None)
if current_version is None:
current_version = _inspect_schema_version(session)
_LOGGER.debug(
"No schema version found. Inspected version: %s", current_version
)
return cast(int, current_version)
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error when determining DB schema version: %s", err) _LOGGER.exception("Error when determining DB schema version: %s", err)
return None return None
@ -797,8 +790,10 @@ def _apply_update( # noqa: C901
raise ValueError(f"No schema migration defined for version {new_version}") raise ValueError(f"No schema migration defined for version {new_version}")
def _inspect_schema_version(session: Session) -> int: def _initialize_database(session: Session) -> bool:
"""Determine the schema version by inspecting the db structure. """Initialize a new database, or a database created before introducing schema changes.
The function determines the schema version by inspecting the db structure.
When the schema version is not present in the db, either db was just When the schema version is not present in the db, either db was just
created with the correct schema, or this is a db created before schema created with the correct schema, or this is a db created before schema
@ -814,9 +809,22 @@ def _inspect_schema_version(session: Session) -> int:
# Schema addition from version 1 detected. New DB. # Schema addition from version 1 detected. New DB.
session.add(StatisticsRuns(start=get_start_time())) session.add(StatisticsRuns(start=get_start_time()))
session.add(SchemaChanges(schema_version=SCHEMA_VERSION)) session.add(SchemaChanges(schema_version=SCHEMA_VERSION))
return SCHEMA_VERSION return True
# Version 1 schema changes not found, this db needs to be migrated. # Version 1 schema changes not found, this db needs to be migrated.
current_version = SchemaChanges(schema_version=0) current_version = SchemaChanges(schema_version=0)
session.add(current_version) session.add(current_version)
return cast(int, current_version.schema_version) return True
def initialize_database(session_maker: Callable[[], Session]) -> bool:
"""Initialize a new database, or a database created before introducing schema changes."""
try:
with session_scope(session=session_maker()) as session:
if _get_schema_version(session) is not None:
return True
return _initialize_database(session)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error when initialise database: %s", err)
return False

View File

@ -669,7 +669,7 @@ def test_recorder_validate_schema_failure(hass):
"""Test some exceptions.""" """Test some exceptions."""
recorder_helper.async_initialize_recorder(hass) recorder_helper.async_initialize_recorder(hass)
with patch( with patch(
"homeassistant.components.recorder.migration._inspect_schema_version" "homeassistant.components.recorder.migration._get_schema_version"
) as inspect_schema_version, patch( ) as inspect_schema_version, patch(
"homeassistant.components.recorder.core.time.sleep" "homeassistant.components.recorder.core.time.sleep"
): ):