From abf0c87e40743cd2c0b1014805b084e6834792de Mon Sep 17 00:00:00 2001
From: "J. Nick Koston"
Date: Thu, 9 Feb 2023 12:24:19 -0600
Subject: [PATCH] Migrate statistics to use timestamp columns (#87321)

---
 homeassistant/components/recorder/core.py    |  18 +-
 .../components/recorder/db_schema.py         |  45 +++-
 .../components/recorder/migration.py         | 185 ++++++++++++---
 homeassistant/components/recorder/models.py  |  14 ++
 homeassistant/components/recorder/queries.py |   3 +-
 .../components/recorder/statistics.py        | 210 +++++++++++++-----
 homeassistant/components/recorder/tasks.py   |  11 +
 tests/components/recorder/test_migrate.py    |   4 +
 tests/components/recorder/test_purge.py      |   2 +-
 9 files changed, 383 insertions(+), 109 deletions(-)

diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py
index 002618738a1..0dfcda1520d 100644
--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py
@@ -638,10 +638,8 @@ class Recorder(threading.Thread):
         else:
             persistent_notification.create(
                 self.hass,
-                (
-                    "The database migration failed, check [the logs](/config/logs)."
-                    "Database Migration Failed"
-                ),
+                "The database migration failed, check [the logs](/config/logs).",
+                "Database Migration Failed",
                 "recorder_database_migration",
             )
         self.hass.add_job(self.async_set_db_ready)
@@ -730,8 +728,10 @@ class Recorder(threading.Thread):
             (
                 "System performance will temporarily degrade during the database"
                 " upgrade. Do not power down or restart the system until the upgrade"
-                " completes. Integrations that read the database, such as logbook and"
-                " history, may return inconsistent results until the upgrade completes."
+                " completes. Integrations that read the database, such as logbook,"
+                " history, and statistics may return inconsistent results until the"
+                " upgrade completes. This notification will be automatically dismissed"
+                " when the upgrade completes."
             ),
             "Database upgrade in progress",
             "recorder_database_migration",
@@ -1027,11 +1027,7 @@ class Recorder(threading.Thread):
 
     def _post_schema_migration(self, old_version: int, new_version: int) -> None:
         """Run post schema migration tasks."""
-        assert self.engine is not None
-        assert self.event_session is not None
-        migration.post_schema_migration(
-            self.engine, self.event_session, old_version, new_version
-        )
+        migration.post_schema_migration(self, old_version, new_version)
 
     def _send_keep_alive(self) -> None:
         """Send a keep alive to keep the db connection open."""
diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py
index 5d10a459d88..19ed9fbe4bd 100644
--- a/homeassistant/components/recorder/db_schema.py
+++ b/homeassistant/components/recorder/db_schema.py
@@ -52,7 +52,12 @@ from homeassistant.helpers.json import (
 import homeassistant.util.dt as dt_util
 
 from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect
-from .models import StatisticData, StatisticMetaData, process_timestamp
+from .models import (
+    StatisticData,
+    StatisticMetaData,
+    datetime_to_timestamp_or_none,
+    process_timestamp,
+)
 
 
 # SQLAlchemy Schema
@@ -61,7 +66,7 @@ class Base(DeclarativeBase):
    """Base class for tables."""
 
 
-SCHEMA_VERSION = 33
+SCHEMA_VERSION = 35
 
 _LOGGER = logging.getLogger(__name__)
@@ -76,6 +81,8 @@ TABLE_STATISTICS_META = "statistics_meta"
 TABLE_STATISTICS_RUNS = "statistics_runs"
 TABLE_STATISTICS_SHORT_TERM = "statistics_short_term"
 
+STATISTICS_TABLES = ("statistics", "statistics_short_term")
+
 MAX_STATE_ATTRS_BYTES = 16384
 
 PSQL_DIALECT = SupportedDialect.POSTGRESQL
@@ -502,17 +509,24 @@ class StatisticsBase:
     """Statistics base class."""
 
     id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
-    created: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
+    created: Mapped[datetime] = mapped_column(
+        DATETIME_TYPE, default=dt_util.utcnow
+    )  # No longer used
+    created_ts: Mapped[float] = mapped_column(TIMESTAMP_TYPE, default=time.time)
     metadata_id: Mapped[int | None] = mapped_column(
         Integer,
         ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
         index=True,
     )
-    start: Mapped[datetime | None] = mapped_column(DATETIME_TYPE, index=True)
+    start: Mapped[datetime | None] = mapped_column(
+        DATETIME_TYPE, index=True
+    )  # No longer used
+    start_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, index=True)
     mean: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
     min: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
     max: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
     last_reset: Mapped[datetime | None] = mapped_column(DATETIME_TYPE)
+    last_reset_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE)
     state: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
     sum: Mapped[float | None] = mapped_column(DOUBLE_TYPE)
 
@@ -521,9 +535,17 @@ class StatisticsBase:
     @classmethod
     def from_stats(cls, metadata_id: int, stats: StatisticData) -> Self:
         """Create object from a statistics."""
-        return cls(  # type: ignore[call-arg,misc]
+        return cls(  # type: ignore[call-arg]
             metadata_id=metadata_id,
-            **stats,
+            start=None,
+            start_ts=dt_util.utc_to_timestamp(stats["start"]),
+            mean=stats.get("mean"),
+            min=stats.get("min"),
+            max=stats.get("max"),
+            last_reset=None,
+            last_reset_ts=datetime_to_timestamp_or_none(stats.get("last_reset")),
+            state=stats.get("state"),
+            sum=stats.get("sum"),
         )
 
 
@@ -534,7 +556,12 @@ class Statistics(Base, StatisticsBase):
     __table_args__ = (
         # Used for fetching statistics for a certain entity at a specific time
-        Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True),
+        Index(
+            "ix_statistics_statistic_id_start_ts",
+            "metadata_id",
+            "start_ts",
+            unique=True,
+        ),
     )
     __tablename__ = TABLE_STATISTICS
@@ -547,9 +574,9 @@ class StatisticsShortTerm(Base, StatisticsBase):
     __table_args__ = (
         # Used for fetching statistics for a certain entity at a specific time
         Index(
-            "ix_statistics_short_term_statistic_id_start",
+            "ix_statistics_short_term_statistic_id_start_ts",
             "metadata_id",
-            "start",
+            "start_ts",
             unique=True,
         ),
     )
diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py
index ee319aaf38b..a3a609a1b6f 100644
--- a/homeassistant/components/recorder/migration.py
+++ b/homeassistant/components/recorder/migration.py
@@ -27,6 +27,7 @@ from homeassistant.core import HomeAssistant
 from .const import SupportedDialect
 from .db_schema import (
     SCHEMA_VERSION,
+    STATISTICS_TABLES,
     TABLE_STATES,
     Base,
     SchemaChanges,
@@ -43,7 +44,11 @@ from .statistics import (
     get_start_time,
     validate_db_schema as statistics_validate_db_schema,
 )
-from .tasks import CommitTask, PostSchemaMigrationTask
+from .tasks import (
+    CommitTask,
+    PostSchemaMigrationTask,
+    StatisticsTimestampMigrationCleanupTask,
+)
 from .util import session_scope
 
 if TYPE_CHECKING:
@@ -51,6 +56,7 @@ if TYPE_CHECKING:
 
 LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0
 
+
 _LOGGER = logging.getLogger(__name__)
@@ -274,13 +280,21 @@ def _drop_index(
             "Finished dropping index %s from table %s", index_name, table_name
         )
     else:
-        if index_name in ("ix_states_entity_id", "ix_states_context_parent_id"):
+        if index_name in (
+            "ix_states_entity_id",
+            "ix_states_context_parent_id",
+            "ix_statistics_short_term_statistic_id_start",
+            "ix_statistics_statistic_id_start",
+        ):
             # ix_states_context_parent_id was only there on nightly so we do not want
             # to generate log noise or issues about it.
             #
             # ix_states_entity_id was only there for users who upgraded from schema
             # version 8 or earlier. Newer installs will not have it so we do not
             # want to generate log noise or issues about it.
+            #
+            # ix_statistics_short_term_statistic_id_start and ix_statistics_statistic_id_start
+            # were only there for users who upgraded from schema version 23 or earlier.
             return
 
         _LOGGER.warning(
@@ -764,35 +778,9 @@ def _apply_update(  # noqa: C901
         # Add name column to StatisticsMeta
         _add_columns(session_maker, "statistics_meta", ["name VARCHAR(255)"])
     elif new_version == 24:
-        # Recreate statistics indices to block duplicated statistics
-        _drop_index(session_maker, "statistics", "ix_statistics_statistic_id_start")
-        _drop_index(
-            session_maker,
-            "statistics_short_term",
-            "ix_statistics_short_term_statistic_id_start",
-        )
-        try:
-            _create_index(
-                session_maker, "statistics", "ix_statistics_statistic_id_start"
-            )
-            _create_index(
-                session_maker,
-                "statistics_short_term",
-                "ix_statistics_short_term_statistic_id_start",
-            )
-        except DatabaseError:
-            # There may be duplicated statistics entries, delete duplicated statistics
-            # and try again
-            with session_scope(session=session_maker()) as session:
-                delete_statistics_duplicates(hass, session)
-            _create_index(
-                session_maker, "statistics", "ix_statistics_statistic_id_start"
-            )
-            _create_index(
-                session_maker,
-                "statistics_short_term",
-                "ix_statistics_short_term_statistic_id_start",
-            )
+        _LOGGER.debug("Deleting duplicated statistics entries")
+        with session_scope(session=session_maker()) as session:
+            delete_statistics_duplicates(hass, session)
     elif new_version == 25:
         _add_columns(session_maker, "states", [f"attributes_id {big_int}"])
         _create_index(session_maker, "states", "ix_states_attributes_id")
@@ -881,13 +869,62 @@
         # when querying the states table.
         # https://github.com/home-assistant/core/issues/83787
         _drop_index(session_maker, "states", "ix_states_entity_id")
+    elif new_version == 34:
+        # Once we require SQLite >= 3.35.5, we should drop the columns:
+        # ALTER TABLE statistics DROP COLUMN created
+        # ALTER TABLE statistics DROP COLUMN start
+        # ALTER TABLE statistics DROP COLUMN last_reset
+        # ALTER TABLE statistics_short_term DROP COLUMN created
+        # ALTER TABLE statistics_short_term DROP COLUMN start
+        # ALTER TABLE statistics_short_term DROP COLUMN last_reset
+        _add_columns(
+            session_maker,
+            "statistics",
+            [
+                f"created_ts {timestamp_type}",
+                f"start_ts {timestamp_type}",
+                f"last_reset_ts {timestamp_type}",
+            ],
+        )
+        _add_columns(
+            session_maker,
+            "statistics_short_term",
+            [
+                f"created_ts {timestamp_type}",
+                f"start_ts {timestamp_type}",
+                f"last_reset_ts {timestamp_type}",
+            ],
+        )
+        _create_index(session_maker, "statistics", "ix_statistics_start_ts")
+        _create_index(
+            session_maker, "statistics", "ix_statistics_statistic_id_start_ts"
+        )
+        _create_index(
+            session_maker, "statistics_short_term", "ix_statistics_short_term_start_ts"
+        )
+        _create_index(
+            session_maker,
+            "statistics_short_term",
+            "ix_statistics_short_term_statistic_id_start_ts",
+        )
+        _migrate_statistics_columns_to_timestamp(session_maker, engine)
+    elif new_version == 35:
+        # Migration is done in two steps to ensure we can start using
+        # the new columns before we wipe the old ones.
+        _drop_index(session_maker, "statistics", "ix_statistics_statistic_id_start")
+        _drop_index(
+            session_maker,
+            "statistics_short_term",
+            "ix_statistics_short_term_statistic_id_start",
+        )
+        # ix_statistics_start and ix_statistics_short_term_start are still used
+        # for the post migration cleanup and can be removed in a future version.
     else:
         raise ValueError(f"No schema migration defined for version {new_version}")
 
 
 def post_schema_migration(
-    engine: Engine,
-    session: Session,
+    instance: Recorder,
     old_version: int,
     new_version: int,
 ) -> None:
@@ -905,7 +942,19 @@ def post_schema_migration(
     # In version 31 we migrated all the time_fired, last_updated, and last_changed
     # columns to be timestamps. In version 32 we need to wipe the old columns
     # since they are no longer used and take up a significant amount of space.
-    _wipe_old_string_time_columns(engine, session)
+    assert instance.event_session is not None
+    assert instance.engine is not None
+    _wipe_old_string_time_columns(instance.engine, instance.event_session)
+    if old_version < 35 <= new_version:
+        # In version 34 we migrated all the created, start, and last_reset
+        # columns to be timestamps. In version 35 we need to wipe the old columns
+        # since they are no longer used and take up a significant amount of space.
+        _wipe_old_string_statistics_columns(instance)
+
+
+def _wipe_old_string_statistics_columns(instance: Recorder) -> None:
+    """Wipe old string statistics columns to save space."""
+    instance.queue_task(StatisticsTimestampMigrationCleanupTask())
 
 
 def _wipe_old_string_time_columns(engine: Engine, session: Session) -> None:
@@ -1048,6 +1097,74 @@ def _migrate_columns_to_timestamp(
     )
 
 
+def _migrate_statistics_columns_to_timestamp(
+    session_maker: Callable[[], Session], engine: Engine
+) -> None:
+    """Migrate statistics columns to use timestamp."""
+    # Migrate all data in statistics.start to statistics.start_ts
+    # Migrate all data in statistics.created to statistics.created_ts
+    # Migrate all data in statistics.last_reset to statistics.last_reset_ts
+    # Migrate all data in statistics_short_term.start to statistics_short_term.start_ts
+    # Migrate all data in statistics_short_term.created to statistics_short_term.created_ts
+    # Migrate all data in statistics_short_term.last_reset to statistics_short_term.last_reset_ts
+    result: CursorResult | None = None
+    if engine.dialect.name == SupportedDialect.SQLITE:
+        # With SQLite we do this in one go since it is faster
+        for table in STATISTICS_TABLES:
+            with session_scope(session=session_maker()) as session:
+                session.connection().execute(
+                    text(
+                        f"UPDATE {table} set start_ts=strftime('%s',start) + "
+                        "cast(substr(start,-7) AS FLOAT), "
+                        f"created_ts=strftime('%s',created) + "
+                        "cast(substr(created,-7) AS FLOAT), "
+                        f"last_reset_ts=strftime('%s',last_reset) + "
+                        "cast(substr(last_reset,-7) AS FLOAT);"
+                    )
+                )
+    elif engine.dialect.name == SupportedDialect.MYSQL:
+        # With MySQL we do this in chunks to avoid hitting the `innodb_buffer_pool_size` limit
+        # We also need to do this in a loop since we can't be sure that we have
+        # updated all rows in the table until the rowcount is 0
+        for table in STATISTICS_TABLES:
+            result = None
+            while result is None or result.rowcount > 0:  # type: ignore[unreachable]
+                with session_scope(session=session_maker()) as session:
+                    result = session.connection().execute(
+                        text(
+                            f"UPDATE {table} set start_ts="
+                            "IF(start is NULL,0,"
+                            "UNIX_TIMESTAMP(start) "
+                            "), "
+                            "created_ts="
+                            "UNIX_TIMESTAMP(created), "
+                            "last_reset_ts="
+                            "UNIX_TIMESTAMP(last_reset) "
+                            "where start_ts is NULL "
+                            "LIMIT 250000;"
+                        )
+                    )
+    elif engine.dialect.name == SupportedDialect.POSTGRESQL:
+        # With PostgreSQL we do this in chunks to avoid using too much memory
+        # We also need to do this in a loop since we can't be sure that we have
+        # updated all rows in the table until the rowcount is 0
+        for table in STATISTICS_TABLES:
+            result = None
+            while result is None or result.rowcount > 0:  # type: ignore[unreachable]
+                with session_scope(session=session_maker()) as session:
+                    result = session.connection().execute(
+                        text(
+                            f"UPDATE {table} set start_ts="  # nosec
+                            "(case when start is NULL then 0 else EXTRACT(EPOCH FROM start) end), "
+                            "created_ts=EXTRACT(EPOCH FROM created), "
+                            "last_reset_ts=EXTRACT(EPOCH FROM last_reset) "
+                            "where id IN ( "
+                            f"SELECT id FROM {table} where start_ts is NULL LIMIT 250000 "
+                            " );"
+                        )
+                    )
+
+
 def _initialize_database(session: Session) -> bool:
     """Initialize a new database.
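Note: _migrate_statistics_columns_to_timestamp above backfills the new *_ts columns with dialect-specific SQL: SQLite converts each table in a single statement, while MySQL and PostgreSQL loop over 250000-row chunks until an UPDATE reports a rowcount of 0. A condensed sketch of the per-dialect statement construction, reduced to a single column with an illustrative helper name, and omitting the NULL guards the real queries apply:

def start_ts_update_sql(dialect: str, table: str) -> str:
    """Return an UPDATE backfilling start_ts from the datetime column."""
    if dialect == "sqlite":
        # strftime('%s', ...) yields whole seconds; the substr/cast trick
        # re-attaches the fractional ".%f" suffix of the stored string.
        return (
            f"UPDATE {table} SET start_ts=strftime('%s',start) + "
            "cast(substr(start,-7) AS FLOAT);"
        )
    if dialect == "mysql":
        return (
            f"UPDATE {table} SET start_ts=UNIX_TIMESTAMP(start) "
            "WHERE start_ts IS NULL LIMIT 250000;"
        )
    if dialect == "postgresql":
        # PostgreSQL has no UPDATE ... LIMIT, so the chunk is bounded
        # through a subquery on the primary key instead.
        return (
            f"UPDATE {table} SET start_ts=EXTRACT(EPOCH FROM start) "
            f"WHERE id IN (SELECT id FROM {table} "
            "WHERE start_ts IS NULL LIMIT 250000);"
        )
    raise ValueError(f"unsupported dialect: {dialect}")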
diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py
index d2d1968815a..40939bff1ba 100644
--- a/homeassistant/components/recorder/models.py
+++ b/homeassistant/components/recorder/models.py
@@ -124,6 +124,20 @@ def process_datetime_to_timestamp(ts: datetime) -> float:
     return ts.timestamp()
 
 
+def datetime_to_timestamp_or_none(dt: datetime | None) -> float | None:
+    """Convert a datetime to a timestamp."""
+    if dt is None:
+        return None
+    return dt_util.utc_to_timestamp(dt)
+
+
+def timestamp_to_datetime_or_none(ts: float | None) -> datetime | None:
+    """Convert a timestamp to a datetime."""
+    if not ts:
+        return None
+    return dt_util.utc_from_timestamp(ts)
+
+
 class LazyStatePreSchema31(State):
     """A lazy version of core State before schema 31."""
diff --git a/homeassistant/components/recorder/queries.py b/homeassistant/components/recorder/queries.py
index c20393c69ae..d12b6409b7c 100644
--- a/homeassistant/components/recorder/queries.py
+++ b/homeassistant/components/recorder/queries.py
@@ -604,9 +604,10 @@ def find_short_term_statistics_to_purge(
     purge_before: datetime,
 ) -> StatementLambdaElement:
     """Find short term statistics to purge."""
+    purge_before_ts = purge_before.timestamp()
     return lambda_stmt(
         lambda: select(StatisticsShortTerm.id)
-        .filter(StatisticsShortTerm.start < purge_before)
+        .filter(StatisticsShortTerm.start_ts < purge_before_ts)
         .limit(MAX_ROWS_TO_PURGE)
     )
diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index f52be1faa8f..f0214b4143f 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -59,13 +59,20 @@ from .const import (
     SupportedDialect,
 )
 from .db_schema import (
+    STATISTICS_TABLES,
     Statistics,
     StatisticsBase,
     StatisticsMeta,
     StatisticsRuns,
     StatisticsShortTerm,
 )
-from .models import StatisticData, StatisticMetaData, StatisticResult, process_timestamp
+from .models import (
+    StatisticData,
+    StatisticMetaData,
+    StatisticResult,
+    datetime_to_timestamp_or_none,
+    timestamp_to_datetime_or_none,
+)
 from .util import (
     execute,
     execute_stmt_lambda_element,
@@ -79,22 +86,22 @@ if TYPE_CHECKING:
 
 QUERY_STATISTICS = (
     Statistics.metadata_id,
-    Statistics.start,
+    Statistics.start_ts,
     Statistics.mean,
     Statistics.min,
     Statistics.max,
-    Statistics.last_reset,
+    Statistics.last_reset_ts,
     Statistics.state,
     Statistics.sum,
 )
 
 QUERY_STATISTICS_SHORT_TERM = (
     StatisticsShortTerm.metadata_id,
-    StatisticsShortTerm.start,
+    StatisticsShortTerm.start_ts,
     StatisticsShortTerm.mean,
     StatisticsShortTerm.min,
     StatisticsShortTerm.max,
-    StatisticsShortTerm.last_reset,
+    StatisticsShortTerm.last_reset_ts,
     StatisticsShortTerm.state,
     StatisticsShortTerm.sum,
 )
@@ -112,14 +119,14 @@ QUERY_STATISTICS_SUMMARY_MEAN = (
 
 QUERY_STATISTICS_SUMMARY_SUM = (
     StatisticsShortTerm.metadata_id,
-    StatisticsShortTerm.start,
-    StatisticsShortTerm.last_reset,
+    StatisticsShortTerm.start_ts,
+    StatisticsShortTerm.last_reset_ts,
     StatisticsShortTerm.state,
     StatisticsShortTerm.sum,
     func.row_number()
     .over(  # type: ignore[no-untyped-call]
         partition_by=StatisticsShortTerm.metadata_id,
-        order_by=StatisticsShortTerm.start.desc(),
+        order_by=StatisticsShortTerm.start_ts.desc(),
     )
     .label("rownum"),
 )
@@ -421,7 +428,18 @@ def _find_duplicates(
         .subquery()
     )
     query = (
-        session.query(table)
+        session.query(
+            table.id,
+            table.metadata_id,
+            table.created,
+            table.start,
+            table.mean,
+            table.min,
+            table.max,
+            table.last_reset,
+            table.state,
+            table.sum,
+        )
         .outerjoin(
             subquery,
             (subquery.c.metadata_id == table.metadata_id)
@@ -444,13 +462,24 @@ def _find_duplicates(
     def columns_to_dict(duplicate: Row) -> dict:
         """Convert a SQLAlchemy row to dict."""
         dict_ = {}
-        for key in duplicate.__mapper__.c.keys():
+        for key in (
+            "id",
+            "metadata_id",
+            "start",
+            "created",
+            "mean",
+            "min",
+            "max",
+            "last_reset",
+            "state",
+            "sum",
+        ):
             dict_[key] = getattr(duplicate, key)
         return dict_
 
     def compare_statistic_rows(row1: dict, row2: dict) -> bool:
         """Compare two statistics rows, ignoring id and created."""
-        ignore_keys = ["id", "created"]
+        ignore_keys = {"id", "created"}
         keys1 = set(row1).difference(ignore_keys)
         keys2 = set(row2).difference(ignore_keys)
         return keys1 == keys2 and all(row1[k] == row2[k] for k in keys1)
@@ -609,13 +638,13 @@ def delete_statistics_meta_duplicates(session: Session) -> None:
 
 
 def _compile_hourly_statistics_summary_mean_stmt(
-    start_time: datetime, end_time: datetime
+    start_time_ts: float, end_time_ts: float
 ) -> StatementLambdaElement:
     """Generate the summary mean statement for hourly statistics."""
     return lambda_stmt(
         lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)
-        .filter(StatisticsShortTerm.start >= start_time)
-        .filter(StatisticsShortTerm.start < end_time)
+        .filter(StatisticsShortTerm.start_ts >= start_time_ts)
+        .filter(StatisticsShortTerm.start_ts < end_time_ts)
         .group_by(StatisticsShortTerm.metadata_id)
         .order_by(StatisticsShortTerm.metadata_id)
     )
@@ -629,11 +658,13 @@ def _compile_hourly_statistics(session: Session, start: datetime) -> None:
     - sum is taken from the last 5-minute entry during the hour
     """
     start_time = start.replace(minute=0)
+    start_time_ts = start_time.timestamp()
     end_time = start_time + timedelta(hours=1)
+    end_time_ts = end_time.timestamp()
 
     # Compute last hour's average, min, max
     summary: dict[str, StatisticData] = {}
-    stmt = _compile_hourly_statistics_summary_mean_stmt(start_time, end_time)
+    stmt = _compile_hourly_statistics_summary_mean_stmt(start_time_ts, end_time_ts)
     stats = execute_stmt_lambda_element(session, stmt)
 
     if stats:
@@ -649,8 +680,8 @@ def _compile_hourly_statistics(session: Session, start: datetime) -> None:
     # Get last hour's last sum
     subquery = (
         session.query(*QUERY_STATISTICS_SUMMARY_SUM)
-        .filter(StatisticsShortTerm.start >= bindparam("start_time"))
-        .filter(StatisticsShortTerm.start < bindparam("end_time"))
+        .filter(StatisticsShortTerm.start_ts >= bindparam("start_time_ts"))
+        .filter(StatisticsShortTerm.start_ts < bindparam("end_time_ts"))
         .subquery()
     )
     query = (
@@ -658,15 +689,15 @@ def _compile_hourly_statistics(session: Session, start: datetime) -> None:
         .filter(subquery.c.rownum == 1)
         .order_by(subquery.c.metadata_id)
     )
-    stats = execute(query.params(start_time=start_time, end_time=end_time))
+    stats = execute(query.params(start_time_ts=start_time_ts, end_time_ts=end_time_ts))
 
     if stats:
         for stat in stats:
-            metadata_id, start, last_reset, state, _sum, _ = stat
+            metadata_id, start, last_reset_ts, state, _sum, _ = stat
             if metadata_id in summary:
                 summary[metadata_id].update(
                     {
-                        "last_reset": process_timestamp(last_reset),
+                        "last_reset": timestamp_to_datetime_or_none(last_reset_ts),
                         "state": state,
                         "sum": _sum,
                     }
@@ -674,7 +705,7 @@ def _compile_hourly_statistics(session: Session, start: datetime) -> None:
             else:
                 summary[metadata_id] = {
                     "start": start_time,
-                    "last_reset": process_timestamp(last_reset),
+                    "last_reset": timestamp_to_datetime_or_none(last_reset_ts),
                     "state": state,
                     "sum": _sum,
                 }
@@ -757,9 +788,10 @@ def _adjust_sum_statistics(
     adj: float,
 ) -> None:
     """Adjust statistics in the database."""
+    start_time_ts = start_time.timestamp()
     try:
         session.query(table).filter_by(metadata_id=metadata_id).filter(
-            table.start >= start_time
+            table.start_ts >= start_time_ts
         ).update(
             {
                 table.sum: table.sum + adj,
@@ -803,7 +835,9 @@ def _update_statistics(
                 table.mean: statistic.get("mean"),
                 table.min: statistic.get("min"),
                 table.max: statistic.get("max"),
-                table.last_reset: statistic.get("last_reset"),
+                table.last_reset_ts: datetime_to_timestamp_or_none(
+                    statistic.get("last_reset")
+                ),
                 table.state: statistic.get("state"),
                 table.sum: statistic.get("sum"),
             },
@@ -1150,10 +1184,11 @@ def _statistics_during_period_stmt(
 
     This prepares a lambda_stmt query, so we don't insert the parameters yet.
     """
+    start_time_ts = start_time.timestamp()
 
-    columns = select(table.metadata_id, table.start)
+    columns = select(table.metadata_id, table.start_ts)
     if "last_reset" in types:
-        columns = columns.add_columns(table.last_reset)
+        columns = columns.add_columns(table.last_reset_ts)
     if "max" in types:
         columns = columns.add_columns(table.max)
     if "mean" in types:
@@ -1165,15 +1200,16 @@ def _statistics_during_period_stmt(
     if "sum" in types:
         columns = columns.add_columns(table.sum)
 
-    stmt = lambda_stmt(lambda: columns.filter(table.start >= start_time))
+    stmt = lambda_stmt(lambda: columns.filter(table.start_ts >= start_time_ts))
     if end_time is not None:
-        stmt += lambda q: q.filter(table.start < end_time)
+        end_time_ts = end_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
     if metadata_ids:
         stmt += lambda q: q.filter(
             # https://github.com/python/mypy/issues/2608
             table.metadata_id.in_(metadata_ids)  # type:ignore[arg-type]
         )
-    stmt += lambda q: q.order_by(table.metadata_id, table.start)
+    stmt += lambda q: q.order_by(table.metadata_id, table.start_ts)
     return stmt
@@ -1204,9 +1240,11 @@ def _get_max_mean_min_statistic_in_sub_period(
         columns = columns.add_columns(func.min(table.min))
     stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
     if start_time is not None:
-        stmt += lambda q: q.filter(table.start >= start_time)
+        start_time_ts = start_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
     if end_time is not None:
-        stmt += lambda q: q.filter(table.start < end_time)
+        end_time_ts = end_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
     stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt))
     if not stats:
         return
@@ -1296,13 +1334,14 @@ def _first_statistic(
 ) -> datetime | None:
     """Return the data of the oldest statistic row for a given metadata id."""
     stmt = lambda_stmt(
-        lambda: select(table.start)
+        lambda: select(table.start_ts)
         .filter(table.metadata_id == metadata_id)
-        .order_by(table.start.asc())
+        .order_by(table.start_ts.asc())
         .limit(1)
     )
-    stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
-    return process_timestamp(stats[0].start) if stats else None
+    if stats := cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)):
+        return dt_util.utc_from_timestamp(stats[0].start_ts)
+    return None
@@ -1327,7 +1366,7 @@ def _get_oldest_sum_statistic(
         lambda: select(table.sum)
         .filter(table.metadata_id == metadata_id)
         .filter(table.sum.is_not(None))
-        .order_by(table.start.asc())
+        .order_by(table.start_ts.asc())
         .limit(1)
     )
     if start_time is not None:
@@ -1338,7 +1377,8 @@ def _get_oldest_sum_statistic(
         else:
             period = start_time.replace(minute=0, second=0, microsecond=0)
             prev_period = period - table.duration
-        stmt += lambda q: q.filter(table.start >= prev_period)
+        prev_period_ts = prev_period.timestamp()
+        stmt += lambda q: q.filter(table.start_ts >= prev_period_ts)
     stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
     return stats[0].sum if stats else None
@@ -1412,13 +1452,15 @@ def _get_newest_sum_statistic(
         )
         .filter(table.metadata_id == metadata_id)
         .filter(table.sum.is_not(None))
-        .order_by(table.start.desc())
+        .order_by(table.start_ts.desc())
         .limit(1)
     )
     if start_time is not None:
-        stmt += lambda q: q.filter(table.start >= start_time)
+        start_time_ts = start_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
     if end_time is not None:
-        stmt += lambda q: q.filter(table.start < end_time)
+        end_time_ts = end_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
     stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
     return stats[0].sum if stats else None
@@ -1696,7 +1738,7 @@ def _get_last_statistics_stmt(
     return lambda_stmt(
         lambda: select(*QUERY_STATISTICS)
        .filter_by(metadata_id=metadata_id)
-        .order_by(Statistics.metadata_id, Statistics.start.desc())
+        .order_by(Statistics.metadata_id, Statistics.start_ts.desc())
         .limit(number_of_stats)
     )
@@ -1712,7 +1754,7 @@ def _get_last_statistics_short_term_stmt(
     return lambda_stmt(
         lambda: select(*QUERY_STATISTICS_SHORT_TERM)
         .filter_by(metadata_id=metadata_id)
-        .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
+        .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts.desc())
         .limit(number_of_stats)
     )
@@ -1790,7 +1832,7 @@ def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
             StatisticsShortTerm.metadata_id,
             # https://github.com/sqlalchemy/sqlalchemy/issues/9189
             # pylint: disable-next=not-callable
-            func.max(StatisticsShortTerm.start).label("start_max"),
+            func.max(StatisticsShortTerm.start_ts).label("start_max"),
         )
         .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
         .group_by(StatisticsShortTerm.metadata_id)
@@ -1809,7 +1851,7 @@ def _latest_short_term_statistics_stmt(
             StatisticsShortTerm.metadata_id  # pylint: disable=comparison-with-callable
             == most_recent_statistic_row.c.metadata_id
         )
-        & (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max),
+        & (StatisticsShortTerm.start_ts == most_recent_statistic_row.c.start_max),
     )
     return stmt
@@ -1860,9 +1902,9 @@ def _statistics_at_time(
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
 ) -> Sequence[Row] | None:
     """Return last known statistics, earlier than start_time, for the metadata_ids."""
-    columns = select(table.metadata_id, table.start)
+    columns = select(table.metadata_id, table.start_ts)
     if "last_reset" in types:
-        columns = columns.add_columns(table.last_reset)
+        columns = columns.add_columns(table.last_reset_ts)
     if "max" in types:
         columns = columns.add_columns(table.max)
     if "mean" in types:
@@ -1874,13 +1916,14 @@ def _statistics_at_time(
     if "sum" in types:
         columns = columns.add_columns(table.sum)
 
+    start_time_ts = start_time.timestamp()
     stmt = lambda_stmt(lambda: columns)
 
     most_recent_statistic_ids = (
         # https://github.com/sqlalchemy/sqlalchemy/issues/9189
         # pylint: disable-next=not-callable
         lambda_stmt(lambda: select(func.max(table.id).label("max_id")))
-        .filter(table.start < start_time)
+        .filter(table.start_ts < start_time_ts)
         .filter(table.metadata_id.in_(metadata_ids))
         .group_by(table.metadata_id)
         .subquery()
     )
@@ -1925,7 +1968,7 @@ def _sorted_statistics_to_dict(
         stats,
         lambda stat: stat.metadata_id,  # type: ignore[no-any-return]
     ):
-        first_start_time = process_timestamp(next(group).start)
+        first_start_time = dt_util.utc_from_timestamp(next(group).start_ts)
         if start_time and first_start_time > start_time:
             need_stat_at_start_time.add(meta_id)
@@ -1940,6 +1983,8 @@ def _sorted_statistics_to_dict(
             stats_at_start_time[stat.metadata_id] = (stat,)
 
     # Append all statistic entries, and optionally do unit conversion
+    table_duration = table.duration
+    timestamp_to_datetime = dt_util.utc_from_timestamp
     for meta_id, group in groupby(
         stats,
         lambda stat: stat.metadata_id,  # type: ignore[no-any-return]
@@ -1954,11 +1999,10 @@ def _sorted_statistics_to_dict(
             convert = no_conversion
         ent_results = result[meta_id]
         for db_state in chain(stats_at_start_time.get(meta_id, ()), group):
-            start = process_timestamp(db_state.start)
-            end = start + table.duration
-            row = {
+            start = timestamp_to_datetime(db_state.start_ts)
+            row: dict[str, Any] = {
                 "start": start,
-                "end": end,
+                "end": start + table_duration,
             }
             if "mean" in types:
                 row["mean"] = convert(db_state.mean)
@@ -1967,7 +2011,9 @@ def _sorted_statistics_to_dict(
             if "max" in types:
                 row["max"] = convert(db_state.max)
             if "last_reset" in types:
-                row["last_reset"] = process_timestamp(db_state.last_reset)
+                row["last_reset"] = timestamp_to_datetime_or_none(
+                    db_state.last_reset_ts
+                )
             if "state" in types:
                 row["state"] = convert(db_state.state)
             if "sum" in types:
@@ -1996,9 +2042,10 @@ def _statistics_exists(
     start: datetime,
 ) -> int | None:
     """Return id if a statistics entry already exists."""
+    start_ts = start.timestamp()
     result = (
         session.query(table.id)
-        .filter((table.metadata_id == metadata_id) & (table.start == start))
+        .filter((table.metadata_id == metadata_id) & (table.start_ts == start_ts))
         .first()
     )
     return result.id if result else None
@@ -2515,3 +2562,60 @@ def correct_db_schema(
                 f"start {datetime_type}",
             ],
         )
+
+
+def cleanup_statistics_timestamp_migration(instance: Recorder) -> bool:
+    """Clean up the statistics migration from datetime to timestamp.
+
+    Returns False if there are more rows to update.
+    Returns True if all rows have been updated.
+    """
+    engine = instance.engine
+    assert engine is not None
+    if engine.dialect.name == SupportedDialect.SQLITE:
+        for table in STATISTICS_TABLES:
+            with session_scope(session=instance.get_session()) as session:
+                session.connection().execute(
+                    text(
+                        f"update {table} set start = NULL, created = NULL, last_reset = NULL;"
+                    )
+                )
+    elif engine.dialect.name == SupportedDialect.MYSQL:
+        for table in STATISTICS_TABLES:
+            with session_scope(session=instance.get_session()) as session:
+                if (
+                    session.connection()
+                    .execute(
+                        text(
+                            f"UPDATE {table} set start=NULL, created=NULL, last_reset=NULL where start is not NULL LIMIT 250000;"
+                        )
+                    )
+                    .rowcount
+                ):
+                    # We have more rows to update so return False
+                    # to indicate we need to run again
+                    return False
+    elif engine.dialect.name == SupportedDialect.POSTGRESQL:
+        for table in STATISTICS_TABLES:
+            with session_scope(session=instance.get_session()) as session:
+                if (
+                    session.connection()
+                    .execute(
+                        text(
+                            f"UPDATE {table} set start=NULL, created=NULL, last_reset=NULL "  # nosec
+                            f"where id in (select id from {table} where start is not NULL LIMIT 250000)"
+                        )
+                    )
+                    .rowcount
+                ):
+                    # We have more rows to update so return False
+                    # to indicate we need to run again
+                    return False
+
+    from .migration import _drop_index  # pylint: disable=import-outside-toplevel
+
+    for table in STATISTICS_TABLES:
+        _drop_index(instance.get_session, table, f"ix_{table}_start")
+    # We have no more rows to update so return True
+    # to indicate we are done
+    return True
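Note on the statistics.py query changes above: every datetime bound is converted with .timestamp() before the lambda is built, because lambda_stmt caches the compiled statement and turns variables closed over by the lambda into bound parameters; the filter must therefore compare floats against the indexed start_ts column. A minimal sketch of the same pattern against an illustrative table:

from datetime import datetime, timezone

from sqlalchemy import Column, Float, Integer, MetaData, Table, lambda_stmt, select

stats = Table(
    "stats",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("start_ts", Float, index=True),
)


def stats_after(start_time: datetime):
    """Build a cached SELECT filtering on the epoch-float column."""
    # Convert once, outside the lambda, so the cached statement only
    # ever binds a float parameter.
    start_time_ts = start_time.timestamp()
    return lambda_stmt(
        lambda: select(stats.c.id).where(stats.c.start_ts >= start_time_ts)
    )


stmt = stats_after(datetime(2023, 2, 9, tzinfo=timezone.utc))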
+ """ + engine = instance.engine + assert engine is not None + if engine.dialect.name == SupportedDialect.SQLITE: + for table in STATISTICS_TABLES: + with session_scope(session=instance.get_session()) as session: + session.connection().execute( + text( + f"update {table} set start = NULL, created = NULL, last_reset = NULL;" + ) + ) + elif engine.dialect.name == SupportedDialect.MYSQL: + for table in STATISTICS_TABLES: + with session_scope(session=instance.get_session()) as session: + if ( + session.connection() + .execute( + text( + f"UPDATE {table} set start=NULL, created=NULL, last_reset=NULL where start is not NULL LIMIT 250000;" + ) + ) + .rowcount + ): + # We have more rows to update so return False + # to indicate we need to run again + return False + elif engine.dialect.name == SupportedDialect.POSTGRESQL: + for table in STATISTICS_TABLES: + with session_scope(session=instance.get_session()) as session: + if ( + session.connection() + .execute( + text( + f"UPDATE {table} set start=NULL, created=NULL, last_reset=NULL " # nosec + f"where id in (select id from {table} where start is not NULL LIMIT 250000)" + ) + ) + .rowcount + ): + # We have more rows to update so return False + # to indicate we need to run again + return False + + from .migration import _drop_index # pylint: disable=import-outside-toplevel + + for table in STATISTICS_TABLES: + _drop_index(instance.get_session, table, f"ix_{table}_start") + # We have no more rows to update so return True + # to indicate we are done + return True diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index 63dc8e9d2e3..4d12f6b343e 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -317,3 +317,14 @@ class PostSchemaMigrationTask(RecorderTask): instance._post_schema_migration( # pylint: disable=[protected-access] self.old_version, self.new_version ) + + +@dataclass +class StatisticsTimestampMigrationCleanupTask(RecorderTask): + """An object to insert into the recorder queue to run a statistics migration cleanup task.""" + + def run(self, instance: Recorder) -> None: + """Run statistics timestamp cleanup task.""" + if not statistics.cleanup_statistics_timestamp_migration(instance): + # Schedule a new statistics migration task if this one didn't finish + instance.queue_task(StatisticsTimestampMigrationCleanupTask()) diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 7a2fdfd65d9..7918e51311b 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -370,6 +370,10 @@ async def test_schema_migrate( wraps=_instrument_apply_update, ), patch( "homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics", + ), patch( + "homeassistant.components.recorder.Recorder._process_state_changed_event_into_session", + ), patch( + "homeassistant.components.recorder.Recorder._process_non_state_changed_event_into_session", ): recorder_helper.async_initialize_recorder(hass) hass.async_create_task( diff --git a/tests/components/recorder/test_purge.py b/tests/components/recorder/test_purge.py index ec53a75e575..b58189c04c7 100644 --- a/tests/components/recorder/test_purge.py +++ b/tests/components/recorder/test_purge.py @@ -1427,7 +1427,7 @@ async def _add_test_statistics(hass: HomeAssistant): session.add( StatisticsShortTerm( - start=timestamp, + start_ts=timestamp.timestamp(), state=state, ) )