mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Correct initialization of new databases (#80234)
This commit is contained in:
parent
acb1477673
commit
04cc2ae264
@ -703,7 +703,7 @@ class Recorder(threading.Thread):
|
|||||||
while tries <= self.db_max_retries:
|
while tries <= self.db_max_retries:
|
||||||
try:
|
try:
|
||||||
self._setup_connection()
|
self._setup_connection()
|
||||||
return True
|
return migration.initialize_database(self.get_session)
|
||||||
except UnsupportedDialect:
|
except UnsupportedDialect:
|
||||||
break
|
break
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
@ -6,7 +6,7 @@ import contextlib
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, cast
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
|
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
|
||||||
@ -62,24 +62,17 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
|
|||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
|
|
||||||
|
def _get_schema_version(session: Session) -> int | None:
|
||||||
|
"""Get the schema version."""
|
||||||
|
res = session.query(SchemaChanges).order_by(SchemaChanges.change_id.desc()).first()
|
||||||
|
return getattr(res, "schema_version", None)
|
||||||
|
|
||||||
|
|
||||||
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
|
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
|
||||||
"""Get the schema version."""
|
"""Get the schema version."""
|
||||||
try:
|
try:
|
||||||
with session_scope(session=session_maker()) as session:
|
with session_scope(session=session_maker()) as session:
|
||||||
res = (
|
return _get_schema_version(session)
|
||||||
session.query(SchemaChanges)
|
|
||||||
.order_by(SchemaChanges.change_id.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
current_version = getattr(res, "schema_version", None)
|
|
||||||
|
|
||||||
if current_version is None:
|
|
||||||
current_version = _inspect_schema_version(session)
|
|
||||||
_LOGGER.debug(
|
|
||||||
"No schema version found. Inspected version: %s", current_version
|
|
||||||
)
|
|
||||||
|
|
||||||
return cast(int, current_version)
|
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
_LOGGER.exception("Error when determining DB schema version: %s", err)
|
_LOGGER.exception("Error when determining DB schema version: %s", err)
|
||||||
return None
|
return None
|
||||||
@ -797,8 +790,10 @@ def _apply_update( # noqa: C901
|
|||||||
raise ValueError(f"No schema migration defined for version {new_version}")
|
raise ValueError(f"No schema migration defined for version {new_version}")
|
||||||
|
|
||||||
|
|
||||||
def _inspect_schema_version(session: Session) -> int:
|
def _initialize_database(session: Session) -> bool:
|
||||||
"""Determine the schema version by inspecting the db structure.
|
"""Initialize a new database, or a database created before introducing schema changes.
|
||||||
|
|
||||||
|
The function determines the schema version by inspecting the db structure.
|
||||||
|
|
||||||
When the schema version is not present in the db, either db was just
|
When the schema version is not present in the db, either db was just
|
||||||
created with the correct schema, or this is a db created before schema
|
created with the correct schema, or this is a db created before schema
|
||||||
@ -814,9 +809,22 @@ def _inspect_schema_version(session: Session) -> int:
|
|||||||
# Schema addition from version 1 detected. New DB.
|
# Schema addition from version 1 detected. New DB.
|
||||||
session.add(StatisticsRuns(start=get_start_time()))
|
session.add(StatisticsRuns(start=get_start_time()))
|
||||||
session.add(SchemaChanges(schema_version=SCHEMA_VERSION))
|
session.add(SchemaChanges(schema_version=SCHEMA_VERSION))
|
||||||
return SCHEMA_VERSION
|
return True
|
||||||
|
|
||||||
# Version 1 schema changes not found, this db needs to be migrated.
|
# Version 1 schema changes not found, this db needs to be migrated.
|
||||||
current_version = SchemaChanges(schema_version=0)
|
current_version = SchemaChanges(schema_version=0)
|
||||||
session.add(current_version)
|
session.add(current_version)
|
||||||
return cast(int, current_version.schema_version)
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_database(session_maker: Callable[[], Session]) -> bool:
|
||||||
|
"""Initialize a new database, or a database created before introducing schema changes."""
|
||||||
|
try:
|
||||||
|
with session_scope(session=session_maker()) as session:
|
||||||
|
if _get_schema_version(session) is not None:
|
||||||
|
return True
|
||||||
|
return _initialize_database(session)
|
||||||
|
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception("Error when initialise database: %s", err)
|
||||||
|
return False
|
||||||
|
@ -669,7 +669,7 @@ def test_recorder_validate_schema_failure(hass):
|
|||||||
"""Test some exceptions."""
|
"""Test some exceptions."""
|
||||||
recorder_helper.async_initialize_recorder(hass)
|
recorder_helper.async_initialize_recorder(hass)
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.recorder.migration._inspect_schema_version"
|
"homeassistant.components.recorder.migration._get_schema_version"
|
||||||
) as inspect_schema_version, patch(
|
) as inspect_schema_version, patch(
|
||||||
"homeassistant.components.recorder.core.time.sleep"
|
"homeassistant.components.recorder.core.time.sleep"
|
||||||
):
|
):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user