mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Refactor recorder migration (#80175)
* Refactor recorder migration * Improve test coverage
This commit is contained in:
parent
ca4c4774ca
commit
466c4656ca
@ -588,24 +588,31 @@ class Recorder(threading.Thread):
|
|||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Start processing events to save."""
|
"""Start processing events to save."""
|
||||||
current_version = self._setup_recorder()
|
setup_result = self._setup_recorder()
|
||||||
|
|
||||||
if current_version is None:
|
if not setup_result:
|
||||||
|
# Give up if we could not connect
|
||||||
self.hass.add_job(self.async_connection_failed)
|
self.hass.add_job(self.async_connection_failed)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.schema_version = current_version
|
schema_status = migration.validate_db_schema(self.hass, self.get_session)
|
||||||
|
if schema_status is None:
|
||||||
|
# Give up if we could not validate the schema
|
||||||
|
self.hass.add_job(self.async_connection_failed)
|
||||||
|
return
|
||||||
|
self.schema_version = schema_status.current_version
|
||||||
|
|
||||||
schema_is_current = migration.schema_is_current(current_version)
|
schema_is_valid = migration.schema_is_valid(schema_status)
|
||||||
if schema_is_current:
|
|
||||||
|
if schema_is_valid:
|
||||||
self._setup_run()
|
self._setup_run()
|
||||||
else:
|
else:
|
||||||
self.migration_in_progress = True
|
self.migration_in_progress = True
|
||||||
self.migration_is_live = migration.live_migration(current_version)
|
self.migration_is_live = migration.live_migration(schema_status)
|
||||||
|
|
||||||
self.hass.add_job(self.async_connection_success)
|
self.hass.add_job(self.async_connection_success)
|
||||||
|
|
||||||
if self.migration_is_live or schema_is_current:
|
if self.migration_is_live or schema_is_valid:
|
||||||
# If the migrate is live or the schema is current, we need to
|
# If the migrate is live or the schema is current, we need to
|
||||||
# wait for startup to complete. If its not live, we need to continue
|
# wait for startup to complete. If its not live, we need to continue
|
||||||
# on.
|
# on.
|
||||||
@ -623,8 +630,8 @@ class Recorder(threading.Thread):
|
|||||||
self.hass.add_job(self.async_set_db_ready)
|
self.hass.add_job(self.async_set_db_ready)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not schema_is_current:
|
if not schema_is_valid:
|
||||||
if self._migrate_schema_and_setup_run(current_version):
|
if self._migrate_schema_and_setup_run(schema_status):
|
||||||
self.schema_version = SCHEMA_VERSION
|
self.schema_version = SCHEMA_VERSION
|
||||||
if not self._event_listener:
|
if not self._event_listener:
|
||||||
# If the schema migration takes so long that the end
|
# If the schema migration takes so long that the end
|
||||||
@ -689,14 +696,14 @@ class Recorder(threading.Thread):
|
|||||||
# happens to rollback and recover
|
# happens to rollback and recover
|
||||||
self._reopen_event_session()
|
self._reopen_event_session()
|
||||||
|
|
||||||
def _setup_recorder(self) -> None | int:
|
def _setup_recorder(self) -> bool:
|
||||||
"""Create connect to the database and get the schema version."""
|
"""Create a connection to the database."""
|
||||||
tries = 1
|
tries = 1
|
||||||
|
|
||||||
while tries <= self.db_max_retries:
|
while tries <= self.db_max_retries:
|
||||||
try:
|
try:
|
||||||
self._setup_connection()
|
self._setup_connection()
|
||||||
return migration.get_schema_version(self.get_session)
|
return True
|
||||||
except UnsupportedDialect:
|
except UnsupportedDialect:
|
||||||
break
|
break
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
@ -708,14 +715,16 @@ class Recorder(threading.Thread):
|
|||||||
tries += 1
|
tries += 1
|
||||||
time.sleep(self.db_retry_wait)
|
time.sleep(self.db_retry_wait)
|
||||||
|
|
||||||
return None
|
return False
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_migration_started(self) -> None:
|
def _async_migration_started(self) -> None:
|
||||||
"""Set the migration started event."""
|
"""Set the migration started event."""
|
||||||
self.async_migration_event.set()
|
self.async_migration_event.set()
|
||||||
|
|
||||||
def _migrate_schema_and_setup_run(self, current_version: int) -> bool:
|
def _migrate_schema_and_setup_run(
|
||||||
|
self, schema_status: migration.SchemaValidationStatus
|
||||||
|
) -> bool:
|
||||||
"""Migrate schema to the latest version."""
|
"""Migrate schema to the latest version."""
|
||||||
persistent_notification.create(
|
persistent_notification.create(
|
||||||
self.hass,
|
self.hass,
|
||||||
@ -727,7 +736,7 @@ class Recorder(threading.Thread):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
migration.migrate_schema(
|
migration.migrate_schema(
|
||||||
self, self.hass, self.engine, self.get_session, current_version
|
self, self.hass, self.engine, self.get_session, schema_status
|
||||||
)
|
)
|
||||||
except exc.DatabaseError as err:
|
except exc.DatabaseError as err:
|
||||||
if self._handle_database_error(err):
|
if self._handle_database_error(err):
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, cast
|
from typing import TYPE_CHECKING, cast
|
||||||
@ -61,8 +62,9 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
|
|||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
|
|
||||||
def get_schema_version(session_maker: Callable[[], Session]) -> int:
|
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
|
||||||
"""Get the schema version."""
|
"""Get the schema version."""
|
||||||
|
try:
|
||||||
with session_scope(session=session_maker()) as session:
|
with session_scope(session=session_maker()) as session:
|
||||||
res = (
|
res = (
|
||||||
session.query(SchemaChanges)
|
session.query(SchemaChanges)
|
||||||
@ -78,16 +80,47 @@ def get_schema_version(session_maker: Callable[[], Session]) -> int:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return cast(int, current_version)
|
return cast(int, current_version)
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception("Error when determining DB schema version: %s", err)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def schema_is_current(current_version: int) -> bool:
|
@dataclass
|
||||||
|
class SchemaValidationStatus:
|
||||||
|
"""Store schema validation status."""
|
||||||
|
|
||||||
|
current_version: int
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_is_current(current_version: int) -> bool:
|
||||||
"""Check if the schema is current."""
|
"""Check if the schema is current."""
|
||||||
return current_version == SCHEMA_VERSION
|
return current_version == SCHEMA_VERSION
|
||||||
|
|
||||||
|
|
||||||
def live_migration(current_version: int) -> bool:
|
def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
|
||||||
|
"""Check if the schema is valid."""
|
||||||
|
return _schema_is_current(schema_status.current_version)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_db_schema(
|
||||||
|
hass: HomeAssistant, session_maker: Callable[[], Session]
|
||||||
|
) -> SchemaValidationStatus | None:
|
||||||
|
"""Check if the schema is valid.
|
||||||
|
|
||||||
|
This checks that the schema is the current version as well as for some common schema
|
||||||
|
errors caused by manual migration between database engines, for example importing an
|
||||||
|
SQLite database to MariaDB.
|
||||||
|
"""
|
||||||
|
current_version = get_schema_version(session_maker)
|
||||||
|
if current_version is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return SchemaValidationStatus(current_version)
|
||||||
|
|
||||||
|
|
||||||
|
def live_migration(schema_status: SchemaValidationStatus) -> bool:
|
||||||
"""Check if live migration is possible."""
|
"""Check if live migration is possible."""
|
||||||
return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
|
return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
|
||||||
|
|
||||||
|
|
||||||
def migrate_schema(
|
def migrate_schema(
|
||||||
@ -95,13 +128,14 @@ def migrate_schema(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
engine: Engine,
|
engine: Engine,
|
||||||
session_maker: Callable[[], Session],
|
session_maker: Callable[[], Session],
|
||||||
current_version: int,
|
schema_status: SchemaValidationStatus,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Check if the schema needs to be upgraded."""
|
"""Check if the schema needs to be upgraded."""
|
||||||
|
current_version = schema_status.current_version
|
||||||
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
|
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
|
||||||
db_ready = False
|
db_ready = False
|
||||||
for version in range(current_version, SCHEMA_VERSION):
|
for version in range(current_version, SCHEMA_VERSION):
|
||||||
if live_migration(version) and not db_ready:
|
if live_migration(SchemaValidationStatus(version)) and not db_ready:
|
||||||
db_ready = True
|
db_ready = True
|
||||||
instance.migration_is_live = True
|
instance.migration_is_live = True
|
||||||
hass.add_job(instance.async_set_db_ready)
|
hass.add_job(instance.async_set_db_ready)
|
||||||
|
@ -665,6 +665,23 @@ def test_recorder_setup_failure(hass):
|
|||||||
hass.stop()
|
hass.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_recorder_validate_schema_failure(hass):
|
||||||
|
"""Test some exceptions."""
|
||||||
|
recorder_helper.async_initialize_recorder(hass)
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.recorder.migration._inspect_schema_version"
|
||||||
|
) as inspect_schema_version, patch(
|
||||||
|
"homeassistant.components.recorder.core.time.sleep"
|
||||||
|
):
|
||||||
|
inspect_schema_version.side_effect = ImportError("driver not found")
|
||||||
|
rec = _default_recorder(hass)
|
||||||
|
rec.async_initialize()
|
||||||
|
rec.start()
|
||||||
|
rec.join()
|
||||||
|
|
||||||
|
hass.stop()
|
||||||
|
|
||||||
|
|
||||||
def test_recorder_setup_failure_without_event_listener(hass):
|
def test_recorder_setup_failure_without_event_listener(hass):
|
||||||
"""Test recorder setup failure when the event listener is not setup."""
|
"""Test recorder setup failure when the event listener is not setup."""
|
||||||
recorder_helper.async_initialize_recorder(hass)
|
recorder_helper.async_initialize_recorder(hass)
|
||||||
|
@ -134,14 +134,16 @@ async def test_database_migration_encounters_corruption(hass):
|
|||||||
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
|
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
|
||||||
|
|
||||||
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
|
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
|
||||||
"homeassistant.components.recorder.migration.schema_is_current",
|
"homeassistant.components.recorder.migration._schema_is_current",
|
||||||
side_effect=[False, True],
|
side_effect=[False],
|
||||||
), patch(
|
), patch(
|
||||||
"homeassistant.components.recorder.migration.migrate_schema",
|
"homeassistant.components.recorder.migration.migrate_schema",
|
||||||
side_effect=sqlite3_exception,
|
side_effect=sqlite3_exception,
|
||||||
), patch(
|
), patch(
|
||||||
"homeassistant.components.recorder.core.move_away_broken_database"
|
"homeassistant.components.recorder.core.move_away_broken_database"
|
||||||
) as move_away:
|
) as move_away, patch(
|
||||||
|
"homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics",
|
||||||
|
):
|
||||||
recorder_helper.async_initialize_recorder(hass)
|
recorder_helper.async_initialize_recorder(hass)
|
||||||
await async_setup_component(
|
await async_setup_component(
|
||||||
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
|
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
|
||||||
@ -159,8 +161,8 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass):
|
|||||||
assert recorder.util.async_migration_in_progress(hass) is False
|
assert recorder.util.async_migration_in_progress(hass) is False
|
||||||
|
|
||||||
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
|
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
|
||||||
"homeassistant.components.recorder.migration.schema_is_current",
|
"homeassistant.components.recorder.migration._schema_is_current",
|
||||||
side_effect=[False, True],
|
side_effect=[False],
|
||||||
), patch(
|
), patch(
|
||||||
"homeassistant.components.recorder.migration.migrate_schema",
|
"homeassistant.components.recorder.migration.migrate_schema",
|
||||||
side_effect=DatabaseError("statement", {}, []),
|
side_effect=DatabaseError("statement", {}, []),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user