From fd6ffef52f337df71542b48565a95300c0ab2766 Mon Sep 17 00:00:00 2001
From: Erik Montnemery
Date: Fri, 22 Jul 2022 15:11:34 +0200
Subject: [PATCH] Support non-live database migration (#72433)

* Support non-live database migration

* Tweak startup order, add test

* Address review comments

* Fix typo

* Clarify comment about promoting dependencies

* Tweak

* Fix merge mistake

* Fix some tests

* Fix additional test

* Fix additional test

* Adjust tests

* Improve test coverage
---
 homeassistant/bootstrap.py                    |  75 +-
 homeassistant/components/recorder/__init__.py |   1 -
 homeassistant/components/recorder/core.py     |  42 +-
 .../components/recorder/migration.py          |  15 +-
 homeassistant/components/recorder/models.py   |  10 +
 .../components/recorder/statistics.py         |   6 +-
 homeassistant/components/recorder/tasks.py    |   2 +-
 homeassistant/components/recorder/util.py     |  14 +-
 .../components/recorder/websocket_api.py      |   4 +-
 homeassistant/helpers/recorder.py             |  26 +-
 tests/common.py                               |   3 +
 tests/components/default_config/test_init.py  |   2 +
 tests/components/recorder/db_schema_25.py     | 673 ++++++++++++++++++
 tests/components/recorder/test_init.py        |  25 +-
 tests/components/recorder/test_migrate.py     |  47 +-
 tests/components/recorder/test_statistics.py  |   5 +
 .../recorder/test_statistics_v23_migration.py |   9 +
 .../components/recorder/test_websocket_api.py |  15 +-
 tests/conftest.py                             |   4 +-
 tests/test_bootstrap.py                       |  76 ++
 20 files changed, 993 insertions(+), 61 deletions(-)
 create mode 100644 tests/components/recorder/db_schema_25.py

diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py
index b8ec5987142..d2858bfcdf1 100644
--- a/homeassistant/bootstrap.py
+++ b/homeassistant/bootstrap.py
@@ -24,7 +24,7 @@ from .const import (
     SIGNAL_BOOTSTRAP_INTEGRATONS,
 )
 from .exceptions import HomeAssistantError
-from .helpers import area_registry, device_registry, entity_registry
+from .helpers import area_registry, device_registry, entity_registry, recorder
 from .helpers.dispatcher import async_dispatcher_send
 from .helpers.typing import ConfigType
 from .setup import (
@@ -66,6 +66,15 @@ LOGGING_INTEGRATIONS = {
     # Error logging
     "system_log",
     "sentry",
+}
+FRONTEND_INTEGRATIONS = {
+    # Get the frontend up and running as soon as possible so problem
+    # integrations can be removed and database migration status is
+    # visible in frontend
+    "frontend",
+}
+RECORDER_INTEGRATIONS = {
+    # Setup after frontend
     # To record data
     "recorder",
 }
@@ -83,10 +92,6 @@ STAGE_1_INTEGRATIONS = {
     "cloud",
     # Ensure supervisor is available
     "hassio",
-    # Get the frontend up and running as soon
-    # as possible so problem integrations can
-    # be removed
-    "frontend",
 }
 
 
@@ -504,11 +509,43 @@ async def _async_set_up_integrations(
 
     _LOGGER.info("Domains to be set up: %s", domains_to_setup)
 
+    def _cache_uname_processor() -> None:
+        """Cache the result of platform.uname().processor in the executor.
+
+        Multiple modules call this function at startup which
+        executes a blocking subprocess call. This is a problem for the
+        asyncio event loop. By priming the cache of uname we can
+        avoid the blocking call in the event loop.
+        """
+        platform.uname().processor  # pylint: disable=expression-not-assigned
+
+    # Load the registries and cache the result of platform.uname().processor
+    await asyncio.gather(
+        device_registry.async_load(hass),
+        entity_registry.async_load(hass),
+        area_registry.async_load(hass),
+        hass.async_add_executor_job(_cache_uname_processor),
+    )
+
+    # Initialize recorder
+    if "recorder" in domains_to_setup:
+        recorder.async_initialize_recorder(hass)
+
     # Load logging as soon as possible
     if logging_domains := domains_to_setup & LOGGING_INTEGRATIONS:
         _LOGGER.info("Setting up logging: %s", logging_domains)
         await async_setup_multi_components(hass, logging_domains, config)
 
+    # Setup frontend
+    if frontend_domains := domains_to_setup & FRONTEND_INTEGRATIONS:
+        _LOGGER.info("Setting up frontend: %s", frontend_domains)
+        await async_setup_multi_components(hass, frontend_domains, config)
+
+    # Setup recorder
+    if recorder_domains := domains_to_setup & RECORDER_INTEGRATIONS:
+        _LOGGER.info("Setting up recorder: %s", recorder_domains)
+        await async_setup_multi_components(hass, recorder_domains, config)
+
     # Start up debuggers. Start these first in case they want to wait.
     if debuggers := domains_to_setup & DEBUGGER_INTEGRATIONS:
         _LOGGER.debug("Setting up debuggers: %s", debuggers)
@@ -518,7 +555,8 @@ async def _async_set_up_integrations(
     stage_1_domains: set[str] = set()
 
     # Find all dependencies of any dependency of any stage 1 integration that
-    # we plan on loading and promote them to stage 1
+    # we plan on loading and promote them to stage 1. This is done only to
+    # avoid misleading log messages.
     deps_promotion: set[str] = STAGE_1_INTEGRATIONS
     while deps_promotion:
         old_deps_promotion = deps_promotion
@@ -535,24 +573,13 @@ async def _async_set_up_integrations(
 
             deps_promotion.update(dep_itg.all_dependencies)
 
-    stage_2_domains = domains_to_setup - logging_domains - debuggers - stage_1_domains
-
-    def _cache_uname_processor() -> None:
-        """Cache the result of platform.uname().processor in the executor.
-
-        Multiple modules call this function at startup which
-        executes a blocking subprocess call. This is a problem for the
-        asyncio event loop. By primeing the cache of uname we can
-        avoid the blocking call in the event loop.
-        """
-        platform.uname().processor  # pylint: disable=expression-not-assigned
-
-    # Load the registries
-    await asyncio.gather(
-        device_registry.async_load(hass),
-        entity_registry.async_load(hass),
-        area_registry.async_load(hass),
-        hass.async_add_executor_job(_cache_uname_processor),
+    stage_2_domains = (
+        domains_to_setup
+        - logging_domains
+        - frontend_domains
+        - recorder_domains
+        - debuggers
+        - stage_1_domains
     )
 
     # Start setup
diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py
index 238b013f366..f9ed5f59333 100644
--- a/homeassistant/components/recorder/__init__.py
+++ b/homeassistant/components/recorder/__init__.py
@@ -123,7 +123,6 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
 
 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     """Set up the recorder."""
-    hass.data[DOMAIN] = {}
     exclude_attributes_by_domain: dict[str, set[str]] = {}
     hass.data[EXCLUDE_ATTRIBUTES] = exclude_attributes_by_domain
     conf = config[DOMAIN]
diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py
index f3ae79f9909..92e10b47126 100644
--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py
@@ -43,6 +43,7 @@ import homeassistant.util.dt as dt_util
 from . import migration, statistics
 from .const import (
     DB_WORKER_PREFIX,
+    DOMAIN,
     KEEPALIVE_TIME,
     MAX_QUEUE_BACKLOG,
     MYSQLDB_URL_PREFIX,
@@ -166,7 +167,12 @@ class Recorder(threading.Thread):
         self.db_max_retries = db_max_retries
         self.db_retry_wait = db_retry_wait
         self.engine_version: AwesomeVersion | None = None
+        # Database connection is ready, but non-live migration may be in progress
+        db_connected: asyncio.Future[bool] = hass.data[DOMAIN].db_connected
+        self.async_db_connected: asyncio.Future[bool] = db_connected
+        # Database is ready to use but live migration may be in progress
         self.async_db_ready: asyncio.Future[bool] = asyncio.Future()
+        # Database is ready to use and all migration steps completed (used by tests)
         self.async_recorder_ready = asyncio.Event()
         self._queue_watch = threading.Event()
         self.engine: Engine | None = None
@@ -188,6 +194,7 @@ class Recorder(threading.Thread):
         self._completed_first_database_setup: bool | None = None
         self.async_migration_event = asyncio.Event()
         self.migration_in_progress = False
+        self.migration_is_live = False
         self._database_lock_task: DatabaseLockTask | None = None
         self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
         self._exclude_attributes_by_domain = exclude_attributes_by_domain
@@ -289,7 +296,8 @@ class Recorder(threading.Thread):
 
     def _stop_executor(self) -> None:
         """Stop the executor."""
-        assert self._db_executor is not None
+        if self._db_executor is None:
+            return
         self._db_executor.shutdown()
         self._db_executor = None
 
@@ -410,6 +418,7 @@ class Recorder(threading.Thread):
     @callback
     def async_connection_failed(self) -> None:
         """Connect failed tasks."""
+        self.async_db_connected.set_result(False)
         self.async_db_ready.set_result(False)
         persistent_notification.async_create(
             self.hass,
@@ -420,13 +429,29 @@ class Recorder(threading.Thread):
 
     @callback
     def async_connection_success(self) -> None:
-        """Connect success tasks."""
+        """Connecting to the database succeeded; schema version and migration mode are now known.
+
+        The database may not yet be ready for use in case of a non-live migration.
+ """ + self.async_db_connected.set_result(True) + + @callback + def async_set_recorder_ready(self) -> None: + """Database live and ready for use. + + Called after non-live migration steps are finished. + """ + if self.async_db_ready.done(): + return self.async_db_ready.set_result(True) self.async_start_executor() @callback - def _async_recorder_ready(self) -> None: - """Finish start and mark recorder ready.""" + def _async_set_recorder_ready_migration_done(self) -> None: + """Finish start and mark recorder ready. + + Called after all migration steps are finished. + """ self._async_setup_periodic_tasks() self.async_recorder_ready.set() @@ -548,6 +573,7 @@ class Recorder(threading.Thread): self._setup_run() else: self.migration_in_progress = True + self.migration_is_live = migration.live_migration(current_version) self.hass.add_job(self.async_connection_success) @@ -557,6 +583,7 @@ class Recorder(threading.Thread): # Make sure we cleanly close the run if # we restart before startup finishes self._shutdown() + self.hass.add_job(self.async_set_recorder_ready) return # We wait to start the migration until startup has finished @@ -577,11 +604,14 @@ class Recorder(threading.Thread): "Database Migration Failed", "recorder_database_migration", ) + self.hass.add_job(self.async_set_recorder_ready) self._shutdown() return + self.hass.add_job(self.async_set_recorder_ready) + _LOGGER.debug("Recorder processing the queue") - self.hass.add_job(self._async_recorder_ready) + self.hass.add_job(self._async_set_recorder_ready_migration_done) self._run_event_loop() def _run_event_loop(self) -> None: @@ -659,7 +689,7 @@ class Recorder(threading.Thread): try: migration.migrate_schema( - self.hass, self.engine, self.get_session, current_version + self, self.hass, self.engine, self.get_session, current_version ) except exc.DatabaseError as err: if self._handle_database_error(err): diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 7e11e62502d..de6fd8f01fe 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Iterable import contextlib from datetime import timedelta import logging -from typing import cast +from typing import Any, cast import sqlalchemy from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text @@ -40,6 +40,8 @@ from .statistics import ( ) from .util import session_scope +LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0 + _LOGGER = logging.getLogger(__name__) @@ -78,7 +80,13 @@ def schema_is_current(current_version: int) -> bool: return current_version == SCHEMA_VERSION +def live_migration(current_version: int) -> bool: + """Check if live migration is possible.""" + return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION + + def migrate_schema( + instance: Any, hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session], @@ -86,7 +94,12 @@ def migrate_schema( ) -> None: """Check if the schema needs to be upgraded.""" _LOGGER.warning("Database is about to upgrade. 
Schema version: %s", current_version) + db_ready = False for version in range(current_version, SCHEMA_VERSION): + if live_migration(version) and not db_ready: + db_ready = True + instance.migration_is_live = True + hass.add_job(instance.async_set_recorder_ready) new_version = version + 1 _LOGGER.info("Upgrading recorder db schema to version %s", new_version) _apply_update(hass, engine, session_maker, new_version, current_version) diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index ff53d9be3d1..98c9fc7c9b2 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -1,6 +1,8 @@ """Models for Recorder.""" from __future__ import annotations +import asyncio +from dataclasses import dataclass, field from datetime import datetime import logging from typing import Any, TypedDict, overload @@ -30,6 +32,14 @@ class UnsupportedDialect(Exception): """The dialect or its version is not supported.""" +@dataclass +class RecorderData: + """Recorder data stored in hass.data.""" + + recorder_platforms: dict[str, Any] = field(default_factory=dict) + db_connected: asyncio.Future = field(default_factory=asyncio.Future) + + class StatisticResult(TypedDict): """Statistic result data class. diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index bce77e8a31e..fcd8e4f3930 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -576,7 +576,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool: platform_stats: list[StatisticResult] = [] current_metadata: dict[str, tuple[int, StatisticMetaData]] = {} # Collect statistics from all platforms implementing support - for domain, platform in instance.hass.data[DOMAIN].items(): + for domain, platform in instance.hass.data[DOMAIN].recorder_platforms.items(): if not hasattr(platform, "compile_statistics"): continue compiled: PlatformCompiledStatistics = platform.compile_statistics( @@ -851,7 +851,7 @@ def list_statistic_ids( } # Query all integrations with a registered recorder platform - for platform in hass.data[DOMAIN].values(): + for platform in hass.data[DOMAIN].recorder_platforms.values(): if not hasattr(platform, "list_statistic_ids"): continue platform_statistic_ids = platform.list_statistic_ids( @@ -1339,7 +1339,7 @@ def _sorted_statistics_to_dict( def validate_statistics(hass: HomeAssistant) -> dict[str, list[ValidationIssue]]: """Validate statistics.""" platform_validation: dict[str, list[ValidationIssue]] = {} - for platform in hass.data[DOMAIN].values(): + for platform in hass.data[DOMAIN].recorder_platforms.values(): if not hasattr(platform, "validate_statistics"): continue platform_validation.update(platform.validate_statistics(hass)) diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index 6d1c9c360ab..cdb97d9d67c 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -249,7 +249,7 @@ class AddRecorderPlatformTask(RecorderTask): domain = self.domain platform = self.platform - platforms: dict[str, Any] = hass.data[DOMAIN] + platforms: dict[str, Any] = hass.data[DOMAIN].recorder_platforms platforms[domain] = platform if hasattr(self.platform, "exclude_attributes"): hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass) diff --git a/homeassistant/components/recorder/util.py 
b/homeassistant/components/recorder/util.py index fdf42665ef5..ddc8747f79b 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -552,7 +552,7 @@ def write_lock_db_sqlite(instance: Recorder) -> Generator[None, None, None]: def async_migration_in_progress(hass: HomeAssistant) -> bool: - """Determine is a migration is in progress. + """Determine if a migration is in progress. This is a thin wrapper that allows us to change out the implementation later. @@ -563,6 +563,18 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool: return instance.migration_in_progress +def async_migration_is_live(hass: HomeAssistant) -> bool: + """Determine if a migration is live. + + This is a thin wrapper that allows us to change + out the implementation later. + """ + if DATA_INSTANCE not in hass.data: + return False + instance: Recorder = hass.data[DATA_INSTANCE] + return instance.migration_is_live + + def second_sunday(year: int, month: int) -> date: """Return the datetime.date for the second sunday of a month.""" second = date(year, month, FIRST_POSSIBLE_SUNDAY) diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index c143d8b4f0b..16813944780 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -17,7 +17,7 @@ from .statistics import ( list_statistic_ids, validate_statistics, ) -from .util import async_migration_in_progress, get_instance +from .util import async_migration_in_progress, async_migration_is_live, get_instance _LOGGER: logging.Logger = logging.getLogger(__package__) @@ -193,6 +193,7 @@ def ws_info( backlog = instance.backlog if instance else None migration_in_progress = async_migration_in_progress(hass) + migration_is_live = async_migration_is_live(hass) recording = instance.recording if instance else False thread_alive = instance.is_alive() if instance else False @@ -200,6 +201,7 @@ def ws_info( "backlog": backlog, "max_backlog": MAX_QUEUE_BACKLOG, "migration_in_progress": migration_in_progress, + "migration_is_live": migration_is_live, "recording": recording, "thread_running": thread_alive, } diff --git a/homeassistant/helpers/recorder.py b/homeassistant/helpers/recorder.py index a51d9de59e2..2049300e460 100644 --- a/homeassistant/helpers/recorder.py +++ b/homeassistant/helpers/recorder.py @@ -1,7 +1,8 @@ """Helpers to check recorder.""" +import asyncio -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback def async_migration_in_progress(hass: HomeAssistant) -> bool: @@ -12,3 +13,26 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool: from homeassistant.components import recorder return recorder.util.async_migration_in_progress(hass) + + +@callback +def async_initialize_recorder(hass: HomeAssistant) -> None: + """Initialize recorder data.""" + # pylint: disable-next=import-outside-toplevel + from homeassistant.components.recorder import const, models + + hass.data[const.DOMAIN] = models.RecorderData() + + +async def async_wait_recorder(hass: HomeAssistant) -> bool: + """Wait for recorder to initialize and return connection status. + + Returns False immediately if the recorder is not enabled. 
+ """ + # pylint: disable-next=import-outside-toplevel + from homeassistant.components.recorder import const + + if const.DOMAIN not in hass.data: + return False + db_connected: asyncio.Future[bool] = hass.data[const.DOMAIN].db_connected + return await db_connected diff --git a/tests/common.py b/tests/common.py index 80f0913cace..acc50e26889 100644 --- a/tests/common.py +++ b/tests/common.py @@ -50,6 +50,7 @@ from homeassistant.helpers import ( entity_platform, entity_registry, intent, + recorder as recorder_helper, restore_state, storage, ) @@ -914,6 +915,8 @@ def init_recorder_component(hass, add_config=None): with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( "homeassistant.components.recorder.migration.migrate_schema" ): + if recorder.DOMAIN not in hass.data: + recorder_helper.async_initialize_recorder(hass) assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: config}) assert recorder.DOMAIN in hass.config.components _LOGGER.info( diff --git a/tests/components/default_config/test_init.py b/tests/components/default_config/test_init.py index 7701eb55b90..d82b4109839 100644 --- a/tests/components/default_config/test_init.py +++ b/tests/components/default_config/test_init.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest +from homeassistant.helpers import recorder as recorder_helper from homeassistant.setup import async_setup_component from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 @@ -24,4 +25,5 @@ def recorder_url_mock(): async def test_setup(hass, mock_zeroconf, mock_get_source_ip): """Test setup.""" + recorder_helper.async_initialize_recorder(hass) assert await async_setup_component(hass, "default_config", {"foo": "bar"}) diff --git a/tests/components/recorder/db_schema_25.py b/tests/components/recorder/db_schema_25.py new file mode 100644 index 00000000000..43aa245a761 --- /dev/null +++ b/tests/components/recorder/db_schema_25.py @@ -0,0 +1,673 @@ +"""Models for SQLAlchemy.""" +from __future__ import annotations + +from datetime import datetime, timedelta +import json +import logging +from typing import Any, TypedDict, cast, overload + +from fnvhash import fnv1a_32 +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Identity, + Index, + Integer, + String, + Text, + distinct, +) +from sqlalchemy.dialects import mysql, oracle, postgresql +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy.orm.session import Session + +from homeassistant.components.recorder.const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP +from homeassistant.const import ( + MAX_LENGTH_EVENT_CONTEXT_ID, + MAX_LENGTH_EVENT_EVENT_TYPE, + MAX_LENGTH_EVENT_ORIGIN, + MAX_LENGTH_STATE_ENTITY_ID, + MAX_LENGTH_STATE_STATE, +) +from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +from homeassistant.helpers.typing import UNDEFINED, UndefinedType +import homeassistant.util.dt as dt_util + +# SQLAlchemy Schema +# pylint: disable=invalid-name +Base = declarative_base() + +SCHEMA_VERSION = 25 + +_LOGGER = logging.getLogger(__name__) + +DB_TIMEZONE = "+00:00" + +TABLE_EVENTS = "events" +TABLE_STATES = "states" +TABLE_STATE_ATTRIBUTES = "state_attributes" +TABLE_RECORDER_RUNS = "recorder_runs" +TABLE_SCHEMA_CHANGES = "schema_changes" +TABLE_STATISTICS = "statistics" +TABLE_STATISTICS_META = "statistics_meta" +TABLE_STATISTICS_RUNS = "statistics_runs" 
+TABLE_STATISTICS_SHORT_TERM = "statistics_short_term" + +ALL_TABLES = [ + TABLE_STATES, + TABLE_EVENTS, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, + TABLE_STATISTICS, + TABLE_STATISTICS_META, + TABLE_STATISTICS_RUNS, + TABLE_STATISTICS_SHORT_TERM, +] + +EMPTY_JSON_OBJECT = "{}" + + +DATETIME_TYPE = DateTime(timezone=True).with_variant( + mysql.DATETIME(timezone=True, fsp=6), "mysql" +) +DOUBLE_TYPE = ( + Float() + .with_variant(mysql.DOUBLE(asdecimal=False), "mysql") + .with_variant(oracle.DOUBLE_PRECISION(), "oracle") + .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") +) + + +class Events(Base): # type: ignore[misc,valid-type] + """Event history data.""" + + __table_args__ = ( + # Used for fetching events at a specific time + # see logbook + Index("ix_events_event_type_time_fired", "event_type", "time_fired"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_EVENTS + event_id = Column(Integer, Identity(), primary_key=True) + event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE)) + event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) + time_fired = Column(DATETIME_TYPE, index=True) + context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event( + event: Event, event_data: UndefinedType | None = UNDEFINED + ) -> Events: + """Create an event database object from a native event.""" + return Events( + event_type=event.event_type, + event_data=JSON_DUMP(event.data) if event_data is UNDEFINED else event_data, + origin=str(event.origin.value), + time_fired=event.time_fired, + context_id=event.context.id, + context_user_id=event.context.user_id, + context_parent_id=event.context.parent_id, + ) + + def to_native(self, validate_entity_id: bool = True) -> Event | None: + """Convert to a native HA Event.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id, + parent_id=self.context_parent_id, + ) + try: + return Event( + self.event_type, + json.loads(self.event_data), + EventOrigin(self.origin), + process_timestamp(self.time_fired), + context=context, + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting to event: %s", self) + return None + + +class States(Base): # type: ignore[misc,valid-type] + """State change history.""" + + __table_args__ = ( + # Used for fetching the state of entities at a specific time + # (get_states in history.py) + Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATES + state_id = Column(Integer, Identity(), primary_key=True) + entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID)) + state = Column(String(MAX_LENGTH_STATE_STATE)) + attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + event_id = Column( + Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True + ) + last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow) + last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True) + old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True) + attributes_id = Column( + Integer, 
ForeignKey("state_attributes.attributes_id"), index=True + ) + event = relationship("Events", uselist=False) + old_state = relationship("States", remote_side=[state_id]) + state_attributes = relationship("StateAttributes") + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> States: + """Create object from a state_changed event.""" + entity_id = event.data["entity_id"] + state: State | None = event.data.get("new_state") + dbstate = States(entity_id=entity_id, attributes=None) + + # None state means the state was removed from the state machine + if state is None: + dbstate.state = "" + dbstate.last_changed = event.time_fired + dbstate.last_updated = event.time_fired + else: + dbstate.state = state.state + dbstate.last_changed = state.last_changed + dbstate.last_updated = state.last_updated + + return dbstate + + def to_native(self, validate_entity_id: bool = True) -> State | None: + """Convert to an HA state object.""" + try: + return State( + self.entity_id, + self.state, + # Join the state_attributes table on attributes_id to get the attributes + # for newer states + json.loads(self.attributes) if self.attributes else {}, + process_timestamp(self.last_changed), + process_timestamp(self.last_updated), + # Join the events table on event_id to get the context instead + # as it will always be there for state_changed events + context=Context(id=None), # type: ignore[arg-type] + validate_entity_id=validate_entity_id, + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting row to state: %s", self) + return None + + +class StateAttributes(Base): # type: ignore[misc,valid-type] + """State attribute change history.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATE_ATTRIBUTES + attributes_id = Column(Integer, Identity(), primary_key=True) + hash = Column(BigInteger, index=True) + # Note that this is not named attributes to avoid confusion with the states table + shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> StateAttributes: + """Create object from a state_changed event.""" + state: State | None = event.data.get("new_state") + # None state means the state was removed from the state machine + dbstate = StateAttributes( + shared_attrs="{}" if state is None else JSON_DUMP(state.attributes) + ) + dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs) + return dbstate + + @staticmethod + def shared_attrs_from_event( + event: Event, exclude_attrs_by_domain: dict[str, set[str]] + ) -> str: + """Create shared_attrs from a state_changed event.""" + state: State | None = event.data.get("new_state") + # None state means the state was removed from the state machine + if state is None: + return "{}" + domain = split_entity_id(state.entity_id)[0] + exclude_attrs = ( + exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS + ) + return JSON_DUMP( + {k: v for k, v in state.attributes.items() if k not in exclude_attrs} + ) + + @staticmethod + def hash_shared_attrs(shared_attrs: str) -> int: + """Return the hash of json encoded shared attributes.""" + return cast(int, fnv1a_32(shared_attrs.encode("utf-8"))) + + def to_native(self) -> dict[str, Any]: + """Convert to an 
HA state object.""" + try: + return cast(dict[str, Any], json.loads(self.shared_attrs)) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting row to state attributes: %s", self) + return {} + + +class StatisticResult(TypedDict): + """Statistic result data class. + + Allows multiple datapoints for the same statistic_id. + """ + + meta: StatisticMetaData + stat: StatisticData + + +class StatisticDataBase(TypedDict): + """Mandatory fields for statistic data class.""" + + start: datetime + + +class StatisticData(StatisticDataBase, total=False): + """Statistic data class.""" + + mean: float + min: float + max: float + last_reset: datetime | None + state: float + sum: float + + +class StatisticsBase: + """Statistics base class.""" + + id = Column(Integer, Identity(), primary_key=True) + created = Column(DATETIME_TYPE, default=dt_util.utcnow) + + @declared_attr # type: ignore[misc] + def metadata_id(self) -> Column: + """Define the metadata_id column for sub classes.""" + return Column( + Integer, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + index=True, + ) + + start = Column(DATETIME_TYPE, index=True) + mean = Column(DOUBLE_TYPE) + min = Column(DOUBLE_TYPE) + max = Column(DOUBLE_TYPE) + last_reset = Column(DATETIME_TYPE) + state = Column(DOUBLE_TYPE) + sum = Column(DOUBLE_TYPE) + + @classmethod + def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase: + """Create object from a statistics.""" + return cls( # type: ignore[call-arg,misc] + metadata_id=metadata_id, + **stats, + ) + + +class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type] + """Long term statistics.""" + + duration = timedelta(hours=1) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True), + ) + __tablename__ = TABLE_STATISTICS + + +class StatisticsShortTerm(Base, StatisticsBase): # type: ignore[misc,valid-type] + """Short term statistics.""" + + duration = timedelta(minutes=5) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_short_term_statistic_id_start", + "metadata_id", + "start", + unique=True, + ), + ) + __tablename__ = TABLE_STATISTICS_SHORT_TERM + + +class StatisticMetaData(TypedDict): + """Statistic meta data class.""" + + has_mean: bool + has_sum: bool + name: str | None + source: str + statistic_id: str + unit_of_measurement: str | None + + +class StatisticsMeta(Base): # type: ignore[misc,valid-type] + """Statistics meta data.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATISTICS_META + id = Column(Integer, Identity(), primary_key=True) + statistic_id = Column(String(255), index=True) + source = Column(String(32)) + unit_of_measurement = Column(String(255)) + has_mean = Column(Boolean) + has_sum = Column(Boolean) + name = Column(String(255)) + + @staticmethod + def from_meta(meta: StatisticMetaData) -> StatisticsMeta: + """Create object from meta data.""" + return StatisticsMeta(**meta) + + +class RecorderRuns(Base): # type: ignore[misc,valid-type] + """Representation of recorder run.""" + + __table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),) + __tablename__ = TABLE_RECORDER_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True), default=dt_util.utcnow) + end = 
Column(DateTime(timezone=True)) + closed_incorrect = Column(Boolean, default=False) + created = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + end = ( + f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None + ) + return ( + f"" + ) + + def entity_ids(self, point_in_time: datetime | None = None) -> list[str]: + """Return the entity ids that existed in this run. + + Specify point_in_time if you want to know which existed at that point + in time inside the run. + """ + session = Session.object_session(self) + + assert session is not None, "RecorderRuns need to be persisted" + + query = session.query(distinct(States.entity_id)).filter( + States.last_updated >= self.start + ) + + if point_in_time is not None: + query = query.filter(States.last_updated < point_in_time) + elif self.end is not None: + query = query.filter(States.last_updated < self.end) + + return [row[0] for row in query] + + def to_native(self, validate_entity_id: bool = True) -> RecorderRuns: + """Return self, native format is this model.""" + return self + + +class SchemaChanges(Base): # type: ignore[misc,valid-type] + """Representation of schema version changes.""" + + __tablename__ = TABLE_SCHEMA_CHANGES + change_id = Column(Integer, Identity(), primary_key=True) + schema_version = Column(Integer) + changed = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +class StatisticsRuns(Base): # type: ignore[misc,valid-type] + """Representation of statistics run.""" + + __tablename__ = TABLE_STATISTICS_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True)) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +@overload +def process_timestamp(ts: None) -> None: + ... + + +@overload +def process_timestamp(ts: datetime) -> datetime: + ... + + +def process_timestamp(ts: datetime | None) -> datetime | None: + """Process a timestamp into datetime object.""" + if ts is None: + return None + if ts.tzinfo is None: + return ts.replace(tzinfo=dt_util.UTC) + + return dt_util.as_utc(ts) + + +@overload +def process_timestamp_to_utc_isoformat(ts: None) -> None: + ... + + +@overload +def process_timestamp_to_utc_isoformat(ts: datetime) -> str: + ... 
+ + +def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None: + """Process a timestamp into UTC isotime.""" + if ts is None: + return None + if ts.tzinfo == dt_util.UTC: + return ts.isoformat() + if ts.tzinfo is None: + return f"{ts.isoformat()}{DB_TIMEZONE}" + return ts.astimezone(dt_util.UTC).isoformat() + + +class LazyState(State): + """A lazy version of core State.""" + + __slots__ = [ + "_row", + "_attributes", + "_last_changed", + "_last_updated", + "_context", + "_attr_cache", + ] + + def __init__( # pylint: disable=super-init-not-called + self, row: Row, attr_cache: dict[str, dict[str, Any]] | None = None + ) -> None: + """Init the lazy state.""" + self._row = row + self.entity_id: str = self._row.entity_id + self.state = self._row.state or "" + self._attributes: dict[str, Any] | None = None + self._last_changed: datetime | None = None + self._last_updated: datetime | None = None + self._context: Context | None = None + self._attr_cache = attr_cache + + @property # type: ignore[override] + def attributes(self) -> dict[str, Any]: # type: ignore[override] + """State attributes.""" + if self._attributes is None: + source = self._row.shared_attrs or self._row.attributes + if self._attr_cache is not None and ( + attributes := self._attr_cache.get(source) + ): + self._attributes = attributes + return attributes + if source == EMPTY_JSON_OBJECT or source is None: + self._attributes = {} + return self._attributes + try: + self._attributes = json.loads(source) + except ValueError: + # When json.loads fails + _LOGGER.exception( + "Error converting row to state attributes: %s", self._row + ) + self._attributes = {} + if self._attr_cache is not None: + self._attr_cache[source] = self._attributes + return self._attributes + + @attributes.setter + def attributes(self, value: dict[str, Any]) -> None: + """Set attributes.""" + self._attributes = value + + @property # type: ignore[override] + def context(self) -> Context: # type: ignore[override] + """State context.""" + if self._context is None: + self._context = Context(id=None) # type: ignore[arg-type] + return self._context + + @context.setter + def context(self, value: Context) -> None: + """Set context.""" + self._context = value + + @property # type: ignore[override] + def last_changed(self) -> datetime: # type: ignore[override] + """Last changed datetime.""" + if self._last_changed is None: + self._last_changed = process_timestamp(self._row.last_changed) + return self._last_changed + + @last_changed.setter + def last_changed(self, value: datetime) -> None: + """Set last changed datetime.""" + self._last_changed = value + + @property # type: ignore[override] + def last_updated(self) -> datetime: # type: ignore[override] + """Last updated datetime.""" + if self._last_updated is None: + if (last_updated := self._row.last_updated) is not None: + self._last_updated = process_timestamp(last_updated) + else: + self._last_updated = self.last_changed + return self._last_updated + + @last_updated.setter + def last_updated(self, value: datetime) -> None: + """Set last updated datetime.""" + self._last_updated = value + + def as_dict(self) -> dict[str, Any]: # type: ignore[override] + """Return a dict representation of the LazyState. + + Async friendly. + + To be used for JSON serialization. 
+ """ + if self._last_changed is None and self._last_updated is None: + last_changed_isoformat = process_timestamp_to_utc_isoformat( + self._row.last_changed + ) + if ( + self._row.last_updated is None + or self._row.last_changed == self._row.last_updated + ): + last_updated_isoformat = last_changed_isoformat + else: + last_updated_isoformat = process_timestamp_to_utc_isoformat( + self._row.last_updated + ) + else: + last_changed_isoformat = self.last_changed.isoformat() + if self.last_changed == self.last_updated: + last_updated_isoformat = last_changed_isoformat + else: + last_updated_isoformat = self.last_updated.isoformat() + return { + "entity_id": self.entity_id, + "state": self.state, + "attributes": self._attributes or self.attributes, + "last_changed": last_changed_isoformat, + "last_updated": last_updated_isoformat, + } + + def __eq__(self, other: Any) -> bool: + """Return the comparison.""" + return ( + other.__class__ in [self.__class__, State] + and self.entity_id == other.entity_id + and self.state == other.state + and self.attributes == other.attributes + ) diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 82444f86a05..0c3a41ab8ef 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -51,6 +51,7 @@ from homeassistant.const import ( STATE_UNLOCKED, ) from homeassistant.core import CoreState, Event, HomeAssistant, callback +from homeassistant.helpers import recorder as recorder_helper from homeassistant.setup import async_setup_component, setup_component from homeassistant.util import dt as dt_util @@ -100,9 +101,10 @@ async def test_shutdown_before_startup_finishes( } hass.state = CoreState.not_running - instance = await async_setup_recorder_instance(hass, config) - await instance.async_db_ready - await hass.async_block_till_done() + recorder_helper.async_initialize_recorder(hass) + hass.create_task(async_setup_recorder_instance(hass, config)) + await recorder_helper.async_wait_recorder(hass) + instance = get_instance(hass) session = await hass.async_add_executor_job(instance.get_session) @@ -125,9 +127,11 @@ async def test_canceled_before_startup_finishes( ): """Test recorder shuts down when its startup future is canceled out from under it.""" hass.state = CoreState.not_running - await async_setup_recorder_instance(hass) + recorder_helper.async_initialize_recorder(hass) + hass.create_task(async_setup_recorder_instance(hass)) + await recorder_helper.async_wait_recorder(hass) + instance = get_instance(hass) - await instance.async_db_ready instance._hass_started.cancel() with patch.object(instance, "engine"): await hass.async_block_till_done() @@ -170,7 +174,9 @@ async def test_state_gets_saved_when_set_before_start_event( hass.state = CoreState.not_running - await async_setup_recorder_instance(hass) + recorder_helper.async_initialize_recorder(hass) + hass.create_task(async_setup_recorder_instance(hass)) + await recorder_helper.async_wait_recorder(hass) entity_id = "test.recorder" state = "restoring_from_db" @@ -643,6 +649,7 @@ def test_saving_state_and_removing_entity(hass, hass_recorder): def test_recorder_setup_failure(hass): """Test some exceptions.""" + recorder_helper.async_initialize_recorder(hass) with patch.object(Recorder, "_setup_connection") as setup, patch( "homeassistant.components.recorder.core.time.sleep" ): @@ -657,6 +664,7 @@ def test_recorder_setup_failure(hass): def test_recorder_setup_failure_without_event_listener(hass): """Test recorder setup failure when the 
event listener is not setup.""" + recorder_helper.async_initialize_recorder(hass) with patch.object(Recorder, "_setup_connection") as setup, patch( "homeassistant.components.recorder.core.time.sleep" ): @@ -985,6 +993,7 @@ def test_compile_missing_statistics(tmpdir): ): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}}) hass.start() wait_recording_done(hass) @@ -1006,6 +1015,7 @@ def test_compile_missing_statistics(tmpdir): ): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}}) hass.start() wait_recording_done(hass) @@ -1197,6 +1207,7 @@ def test_service_disable_run_information_recorded(tmpdir): dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}}) hass.start() wait_recording_done(hass) @@ -1218,6 +1229,7 @@ def test_service_disable_run_information_recorded(tmpdir): hass.stop() hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}}) hass.start() wait_recording_done(hass) @@ -1246,6 +1258,7 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog): test_db_file = await hass.async_add_executor_job(_create_tmpdir_for_test_db) dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" + recorder_helper.async_initialize_recorder(hass) assert await async_setup_component( hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl, CONF_COMMIT_INTERVAL: 0}} ) diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index a57bd246f8b..bbac01bb5d3 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -27,6 +27,7 @@ from homeassistant.components.recorder.db_schema import ( States, ) from homeassistant.components.recorder.util import session_scope +from homeassistant.helpers import recorder as recorder_helper import homeassistant.util.dt as dt_util from .common import async_wait_recording_done, create_engine_test @@ -53,6 +54,7 @@ async def test_schema_update_calls(hass): "homeassistant.components.recorder.migration._apply_update", wraps=migration._apply_update, ) as update: + recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} ) @@ -74,10 +76,11 @@ async def test_migration_in_progress(hass): """Test that we can check for migration in progress.""" assert recorder.util.async_migration_in_progress(hass) is False - with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True,), patch( + with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( "homeassistant.components.recorder.core.create_engine", new=create_engine_test, ): + recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} ) @@ -105,6 +108,7 @@ async def test_database_migration_failed(hass): "homeassistant.components.persistent_notification.dismiss", side_effect=pn.dismiss, ) as mock_dismiss: + recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} ) @@ -136,6 +140,7 @@ async def test_database_migration_encounters_corruption(hass): ), patch( 
"homeassistant.components.recorder.core.move_away_broken_database" ) as move_away: + recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} ) @@ -165,6 +170,7 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass): "homeassistant.components.persistent_notification.dismiss", side_effect=pn.dismiss, ) as mock_dismiss: + recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", {"recorder": {"db_url": "sqlite://"}} ) @@ -189,6 +195,7 @@ async def test_events_during_migration_are_queued(hass): "homeassistant.components.recorder.core.create_engine", new=create_engine_test, ): + recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", @@ -219,6 +226,7 @@ async def test_events_during_migration_queue_exhausted(hass): "homeassistant.components.recorder.core.create_engine", new=create_engine_test, ), patch.object(recorder.core, "MAX_QUEUE_BACKLOG", 1): + recorder_helper.async_initialize_recorder(hass) await async_setup_component( hass, "recorder", @@ -247,8 +255,11 @@ async def test_events_during_migration_queue_exhausted(hass): assert len(db_states) == 2 -@pytest.mark.parametrize("start_version", [0, 16, 18, 22]) -async def test_schema_migrate(hass, start_version): +@pytest.mark.parametrize( + "start_version,live", + [(0, True), (16, True), (18, True), (22, True), (25, True)], +) +async def test_schema_migrate(hass, start_version, live): """Test the full schema migration logic. We're just testing that the logic can execute successfully here without @@ -259,7 +270,8 @@ async def test_schema_migrate(hass, start_version): migration_done = threading.Event() migration_stall = threading.Event() migration_version = None - real_migration = recorder.migration.migrate_schema + real_migrate_schema = recorder.migration.migrate_schema + real_apply_update = recorder.migration._apply_update def _create_engine_test(*args, **kwargs): """Test version of create_engine that initializes with old schema. 
@@ -284,14 +296,12 @@ async def test_schema_migrate(hass, start_version): start=self.run_history.recording_start, created=dt_util.utcnow() ) - def _instrument_migration(*args): + def _instrument_migrate_schema(*args): """Control migration progress and check results.""" nonlocal migration_done nonlocal migration_version - nonlocal migration_stall - migration_stall.wait() try: - real_migration(*args) + real_migrate_schema(*args) except Exception: migration_done.set() raise @@ -307,6 +317,12 @@ async def test_schema_migrate(hass, start_version): migration_version = res.schema_version migration_done.set() + def _instrument_apply_update(*args): + """Control migration progress.""" + nonlocal migration_stall + migration_stall.wait() + real_apply_update(*args) + with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( "homeassistant.components.recorder.core.create_engine", new=_create_engine_test, @@ -316,12 +332,21 @@ async def test_schema_migrate(hass, start_version): autospec=True, ) as setup_run, patch( "homeassistant.components.recorder.migration.migrate_schema", - wraps=_instrument_migration, + wraps=_instrument_migrate_schema, + ), patch( + "homeassistant.components.recorder.migration._apply_update", + wraps=_instrument_apply_update, ): - await async_setup_component( - hass, "recorder", {"recorder": {"db_url": "sqlite://"}} + recorder_helper.async_initialize_recorder(hass) + hass.async_create_task( + async_setup_component( + hass, "recorder", {"recorder": {"db_url": "sqlite://"}} + ) ) + await recorder_helper.async_wait_recorder(hass) + assert recorder.util.async_migration_in_progress(hass) is True + assert recorder.util.async_migration_is_live(hass) == live migration_stall.set() await hass.async_block_till_done() await hass.async_add_executor_job(migration_done.wait) diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index ee76b40a15b..8db4587f1cf 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -31,6 +31,7 @@ from homeassistant.components.recorder.util import session_scope from homeassistant.const import TEMP_CELSIUS from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import recorder as recorder_helper from homeassistant.setup import setup_component import homeassistant.util.dt as dt_util @@ -1128,6 +1129,7 @@ def test_delete_metadata_duplicates(caplog, tmpdir): "homeassistant.components.recorder.core.create_engine", new=_create_engine_28 ): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) wait_recording_done(hass) wait_recording_done(hass) @@ -1158,6 +1160,7 @@ def test_delete_metadata_duplicates(caplog, tmpdir): # Test that the duplicates are removed during migration from schema 28 hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) hass.start() wait_recording_done(hass) @@ -1217,6 +1220,7 @@ def test_delete_metadata_duplicates_many(caplog, tmpdir): "homeassistant.components.recorder.core.create_engine", new=_create_engine_28 ): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) wait_recording_done(hass) wait_recording_done(hass) @@ -1249,6 +1253,7 @@ def 
test_delete_metadata_duplicates_many(caplog, tmpdir): # Test that the duplicates are removed during migration from schema 28 hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) hass.start() wait_recording_done(hass) diff --git a/tests/components/recorder/test_statistics_v23_migration.py b/tests/components/recorder/test_statistics_v23_migration.py index 50311a987d6..a7cc2b35e61 100644 --- a/tests/components/recorder/test_statistics_v23_migration.py +++ b/tests/components/recorder/test_statistics_v23_migration.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from homeassistant.components import recorder from homeassistant.components.recorder import SQLITE_URL_PREFIX, statistics from homeassistant.components.recorder.util import session_scope +from homeassistant.helpers import recorder as recorder_helper from homeassistant.setup import setup_component import homeassistant.util.dt as dt_util @@ -179,6 +180,7 @@ def test_delete_duplicates(caplog, tmpdir): recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION ), patch(CREATE_ENGINE_TARGET, new=_create_engine_test): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) wait_recording_done(hass) wait_recording_done(hass) @@ -206,6 +208,7 @@ def test_delete_duplicates(caplog, tmpdir): # Test that the duplicates are removed during migration from schema 23 hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) hass.start() wait_recording_done(hass) @@ -347,6 +350,7 @@ def test_delete_duplicates_many(caplog, tmpdir): recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION ), patch(CREATE_ENGINE_TARGET, new=_create_engine_test): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) wait_recording_done(hass) wait_recording_done(hass) @@ -380,6 +384,7 @@ def test_delete_duplicates_many(caplog, tmpdir): # Test that the duplicates are removed during migration from schema 23 hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) hass.start() wait_recording_done(hass) @@ -492,6 +497,7 @@ def test_delete_duplicates_non_identical(caplog, tmpdir): recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION ), patch(CREATE_ENGINE_TARGET, new=_create_engine_test): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) wait_recording_done(hass) wait_recording_done(hass) @@ -515,6 +521,7 @@ def test_delete_duplicates_non_identical(caplog, tmpdir): # Test that the duplicates are removed during migration from schema 23 hass = get_test_home_assistant() hass.config.config_dir = tmpdir + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) hass.start() wait_recording_done(hass) @@ -592,6 +599,7 @@ def test_delete_duplicates_short_term(caplog, tmpdir): recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION ), patch(CREATE_ENGINE_TARGET, new=_create_engine_test): hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": 
{"db_url": dburl}}) wait_recording_done(hass) wait_recording_done(hass) @@ -614,6 +622,7 @@ def test_delete_duplicates_short_term(caplog, tmpdir): # Test that the duplicates are removed during migration from schema 23 hass = get_test_home_assistant() hass.config.config_dir = tmpdir + recorder_helper.async_initialize_recorder(hass) setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) hass.start() wait_recording_done(hass) diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index 283883030fa..b604cc53e6c 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -15,6 +15,7 @@ from homeassistant.components.recorder.statistics import ( list_statistic_ids, statistics_during_period, ) +from homeassistant.helpers import recorder as recorder_helper from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util from homeassistant.util.unit_system import METRIC_SYSTEM @@ -274,6 +275,7 @@ async def test_recorder_info(hass, hass_ws_client, recorder_mock): "backlog": 0, "max_backlog": 40000, "migration_in_progress": False, + "migration_is_live": False, "recording": True, "thread_running": True, } @@ -296,6 +298,7 @@ async def test_recorder_info_bad_recorder_config(hass, hass_ws_client): client = await hass_ws_client() with patch("homeassistant.components.recorder.migration.migrate_schema"): + recorder_helper.async_initialize_recorder(hass) assert not await async_setup_component( hass, recorder.DOMAIN, {recorder.DOMAIN: config} ) @@ -318,7 +321,7 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client): migration_done = threading.Event() - real_migration = recorder.migration.migrate_schema + real_migration = recorder.migration._apply_update def stalled_migration(*args): """Make migration stall.""" @@ -334,12 +337,16 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client): ), patch.object( recorder.core, "MAX_QUEUE_BACKLOG", 1 ), patch( - "homeassistant.components.recorder.migration.migrate_schema", + "homeassistant.components.recorder.migration._apply_update", wraps=stalled_migration, ): - await async_setup_component( - hass, "recorder", {"recorder": {"db_url": "sqlite://"}} + recorder_helper.async_initialize_recorder(hass) + hass.create_task( + async_setup_component( + hass, "recorder", {"recorder": {"db_url": "sqlite://"}} + ) ) + await recorder_helper.async_wait_recorder(hass) hass.states.async_set("my.entity", "on", {}) await hass.async_block_till_done() diff --git a/tests/conftest.py b/tests/conftest.py index dc5f3069332..9ca29c60658 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,7 @@ from homeassistant.components.websocket_api.auth import ( from homeassistant.components.websocket_api.http import URL from homeassistant.const import HASSIO_USER_NAME from homeassistant.core import CoreState, HomeAssistant -from homeassistant.helpers import config_entry_oauth2_flow +from homeassistant.helpers import config_entry_oauth2_flow, recorder as recorder_helper from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util, location @@ -790,6 +790,8 @@ async def _async_init_recorder_component(hass, add_config=None): with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( "homeassistant.components.recorder.migration.migrate_schema" ): + if recorder.DOMAIN not in hass.data: 
+ recorder_helper.async_initialize_recorder(hass) assert await async_setup_component( hass, recorder.DOMAIN, {recorder.DOMAIN: config} ) diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 232d8fb6bbf..06f800af7f3 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -211,6 +211,82 @@ async def test_setup_after_deps_in_stage_1_ignored(hass): assert order == ["cloud", "an_after_dep", "normal_integration"] +@pytest.mark.parametrize("load_registries", [False]) +async def test_setup_frontend_before_recorder(hass): + """Test frontend is setup before recorder.""" + order = [] + + def gen_domain_setup(domain): + async def async_setup(hass, config): + order.append(domain) + return True + + return async_setup + + mock_integration( + hass, + MockModule( + domain="normal_integration", + async_setup=gen_domain_setup("normal_integration"), + partial_manifest={"after_dependencies": ["an_after_dep"]}, + ), + ) + mock_integration( + hass, + MockModule( + domain="an_after_dep", + async_setup=gen_domain_setup("an_after_dep"), + ), + ) + mock_integration( + hass, + MockModule( + domain="frontend", + async_setup=gen_domain_setup("frontend"), + partial_manifest={ + "dependencies": ["http"], + "after_dependencies": ["an_after_dep"], + }, + ), + ) + mock_integration( + hass, + MockModule( + domain="http", + async_setup=gen_domain_setup("http"), + ), + ) + mock_integration( + hass, + MockModule( + domain="recorder", + async_setup=gen_domain_setup("recorder"), + ), + ) + + await bootstrap._async_set_up_integrations( + hass, + { + "frontend": {}, + "http": {}, + "recorder": {}, + "normal_integration": {}, + "an_after_dep": {}, + }, + ) + + assert "frontend" in hass.config.components + assert "normal_integration" in hass.config.components + assert "recorder" in hass.config.components + assert order == [ + "http", + "frontend", + "recorder", + "an_after_dep", + "normal_integration", + ] + + @pytest.mark.parametrize("load_registries", [False]) async def test_setup_after_deps_via_platform(hass): """Test after_dependencies set up via platform."""
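
Notes with illustrative sketches (added commentary, not part of the patch).

The heart of this change is a second future, db_connected, which is created in hass.data before any component is set up and resolved as soon as the database connection and schema inspection are done, even if migration is still running. Below is a minimal, self-contained model of the resulting three-stage handshake. The RecorderData name and the stage comments mirror the patch; the fake_recorder coroutine and the main() driver are invented here purely for illustration.

import asyncio
from dataclasses import dataclass, field


@dataclass
class RecorderData:
    """Stand-in for recorder.models.RecorderData stored in hass.data."""

    db_connected: asyncio.Future = field(default_factory=asyncio.Future)


async def fake_recorder(data: RecorderData, db_ready: asyncio.Future) -> None:
    """Model the recorder thread's startup phases."""
    # Phase 1: connect and read the schema version -> async_connection_success()
    data.db_connected.set_result(True)
    # Phase 2: any non-live migration steps run here, while the rest of
    # bootstrap has already moved on.
    await asyncio.sleep(0.1)
    # Phase 3: database usable -> async_set_recorder_ready()
    db_ready.set_result(True)


async def main() -> None:
    data = RecorderData()
    db_ready: asyncio.Future = asyncio.get_running_loop().create_future()
    recorder_task = asyncio.create_task(fake_recorder(data, db_ready))

    # Like bootstrap, wait only for the connection before continuing; the
    # database may still be migrating (possibly non-live) at this point.
    assert await data.db_connected
    print("connected, bootstrap continues")
    assert await db_ready
    print("database ready for use")
    await recorder_task


asyncio.run(main())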
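
migrate_schema() now decides, per schema step, whether the database can already be opened for use. A condensed, runnable sketch of that gate follows; apply_update() and FakeInstance are simplified stand-ins for the patch's real helpers (which dispatch via hass.add_job), and SCHEMA_VERSION is illustrative. Note that with LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0, as shipped here, every migration is still live; the constant is scaffolding so a future schema change can raise the threshold and force a non-live phase.

LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0  # patch ships 0: all steps live for now
SCHEMA_VERSION = 29                    # illustrative target version


def live_migration(current_version: int) -> bool:
    """Check if migrating from current_version can be done live."""
    return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION


def apply_update(new_version: int) -> None:
    """Stand-in for migration._apply_update()."""
    print(f"schema upgraded to version {new_version}")


class FakeInstance:
    """Stand-in for the Recorder instance passed into migrate_schema()."""

    migration_is_live = False

    def async_set_recorder_ready(self) -> None:
        print("database opened for use; remaining steps run live")


def migrate_schema(instance: FakeInstance, current_version: int) -> None:
    db_ready = False
    for version in range(current_version, SCHEMA_VERSION):
        # As soon as the *remaining* steps are live-capable, unblock the
        # database even though migration keeps running in the background.
        if live_migration(version) and not db_ready:
            db_ready = True
            instance.migration_is_live = True
            instance.async_set_recorder_ready()
        apply_update(version + 1)


migrate_schema(FakeInstance(), current_version=26)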
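
For reference, the resulting bootstrap order can be condensed as below. The set names match bootstrap.py; the driver is a simplification of _async_set_up_integrations() with all async plumbing omitted, and the sample domains are invented.

LOGGING_INTEGRATIONS = {"system_log", "sentry"}
FRONTEND_INTEGRATIONS = {"frontend"}   # first, so migration status is visible
RECORDER_INTEGRATIONS = {"recorder"}   # right after frontend


def setup_order(domains_to_setup, debuggers, stage_1_domains):
    logging_domains = domains_to_setup & LOGGING_INTEGRATIONS
    frontend_domains = domains_to_setup & FRONTEND_INTEGRATIONS
    recorder_domains = domains_to_setup & RECORDER_INTEGRATIONS
    stage_2_domains = (
        domains_to_setup
        - logging_domains
        - frontend_domains
        - recorder_domains
        - debuggers
        - stage_1_domains
    )
    return [logging_domains, frontend_domains, recorder_domains,
            debuggers, stage_1_domains, stage_2_domains]


print(setup_order({"recorder", "frontend", "sentry", "zone", "sun"},
                  debuggers=set(), stage_1_domains=set()))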
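
The websocket API change is easiest to see from a sample exchange. The result keys match ws_info() in this patch and the envelope follows Home Assistant's standard websocket result messages; the values shown are illustrative. When migration_in_progress is true and migration_is_live is false, the frontend can tell users that history is unavailable until the non-live migration finishes.

# Client request and server response for the extended "recorder/info" command.
request = {"id": 1, "type": "recorder/info"}
response = {
    "id": 1,
    "type": "result",
    "success": True,
    "result": {
        "backlog": 0,
        "max_backlog": 40000,
        "migration_in_progress": True,
        "migration_is_live": False,  # new field added by this patch
        "recording": False,
        "thread_running": True,
    },
}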
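
Finally, most of the test churn follows one pattern: initialize the recorder data holder, start setup as a background task, then wait on db_connected rather than awaiting setup, since async_db_ready no longer resolves until any non-live migration has finished. A distilled sketch, assuming a hass test instance and the in-memory recorder config used throughout this patch's tests:

from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import async_setup_component


async def start_recorder_during_migration(hass):
    # 1. Create RecorderData (and its db_connected future) up front, as
    #    bootstrap now does before any component is set up.
    recorder_helper.async_initialize_recorder(hass)
    # 2. Kick off setup without awaiting it, so a stalled migration does
    #    not deadlock the test.
    hass.async_create_task(
        async_setup_component(hass, "recorder", {"recorder": {"db_url": "sqlite://"}})
    )
    # 3. Resume as soon as the database connection exists; migration may
    #    still be in progress (live or not) at this point.
    assert await recorder_helper.async_wait_recorder(hass)

The migration tests additionally wrap migration._apply_update (instead of migrate_schema, as before) with a threading.Event, so each schema step can be stalled and released from the test body while the recorder sits between "connected" and "migration finished".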