diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index fb1a55cebfb..b67790f9a42 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -17,7 +17,6 @@ from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import Subquery from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE from homeassistant.core import HomeAssistant, State, split_entity_id @@ -592,48 +591,6 @@ def get_last_state_changes( ) -def _generate_most_recent_states_for_entities_by_date( - schema_version: int, - run_start: datetime, - utc_point_in_time: datetime, - entity_ids: list[str], -) -> Subquery: - """Generate the sub query for the most recent states for specific entities by date.""" - if schema_version >= 31: - run_start_ts = process_timestamp(run_start).timestamp() - utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated_ts).label("max_last_updated"), - ) - .filter( - (States.last_updated_ts >= run_start_ts) - & (States.last_updated_ts < utc_point_in_time_ts) - ) - .filter(States.entity_id.in_(entity_ids)) - .group_by(States.entity_id) - .subquery() - ) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated).label("max_last_updated"), - ) - .filter( - (States.last_updated >= run_start) - & (States.last_updated < utc_point_in_time) - ) - .filter(States.entity_id.in_(entity_ids)) - .group_by(States.entity_id) - .subquery() - ) - - def _get_states_for_entities_stmt( schema_version: int, run_start: datetime, @@ -645,16 +602,29 @@ def _get_states_for_entities_stmt( stmt, join_attributes = lambda_stmt_and_join_attributes( schema_version, no_attributes, include_last_changed=True ) - most_recent_states_for_entities_by_date = ( - _generate_most_recent_states_for_entities_by_date( - schema_version, run_start, utc_point_in_time, entity_ids - ) - ) # We got an include-list of entities, accelerate the query by filtering already # in the inner query. if schema_version >= 31: + run_start_ts = process_timestamp(run_start).timestamp() + utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) stmt += lambda q: q.join( - most_recent_states_for_entities_by_date, + ( + most_recent_states_for_entities_by_date := ( + select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated_ts).label("max_last_updated"), + ) + .filter( + (States.last_updated_ts >= run_start_ts) + & (States.last_updated_ts < utc_point_in_time_ts) + ) + .filter(States.entity_id.in_(entity_ids)) + .group_by(States.entity_id) + .subquery() + ) + ), and_( States.entity_id == most_recent_states_for_entities_by_date.c.max_entity_id, @@ -664,7 +634,21 @@ def _get_states_for_entities_stmt( ) else: stmt += lambda q: q.join( - most_recent_states_for_entities_by_date, + ( + most_recent_states_for_entities_by_date := select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated).label("max_last_updated"), + ) + .filter( + (States.last_updated >= run_start) + & (States.last_updated < utc_point_in_time) + ) + .filter(States.entity_id.in_(entity_ids)) + .group_by(States.entity_id) + .subquery() + ), and_( States.entity_id == most_recent_states_for_entities_by_date.c.max_entity_id, @@ -679,45 +663,6 @@ def _get_states_for_entities_stmt( return stmt -def _generate_most_recent_states_by_date( - schema_version: int, - run_start: datetime, - utc_point_in_time: datetime, -) -> Subquery: - """Generate the sub query for the most recent states by date.""" - if schema_version >= 31: - run_start_ts = process_timestamp(run_start).timestamp() - utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated_ts).label("max_last_updated"), - ) - .filter( - (States.last_updated_ts >= run_start_ts) - & (States.last_updated_ts < utc_point_in_time_ts) - ) - .group_by(States.entity_id) - .subquery() - ) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated).label("max_last_updated"), - ) - .filter( - (States.last_updated >= run_start) - & (States.last_updated < utc_point_in_time) - ) - .group_by(States.entity_id) - .subquery() - ) - - def _get_states_for_all_stmt( schema_version: int, run_start: datetime, @@ -733,12 +678,26 @@ def _get_states_for_all_stmt( # query, then filter out unwanted domains as well as applying the custom filter. # This filtering can't be done in the inner query because the domain column is # not indexed and we can't control what's in the custom filter. - most_recent_states_by_date = _generate_most_recent_states_by_date( - schema_version, run_start, utc_point_in_time - ) if schema_version >= 31: + run_start_ts = process_timestamp(run_start).timestamp() + utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) stmt += lambda q: q.join( - most_recent_states_by_date, + ( + most_recent_states_by_date := ( + select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated_ts).label("max_last_updated"), + ) + .filter( + (States.last_updated_ts >= run_start_ts) + & (States.last_updated_ts < utc_point_in_time_ts) + ) + .group_by(States.entity_id) + .subquery() + ) + ), and_( States.entity_id == most_recent_states_by_date.c.max_entity_id, States.last_updated_ts == most_recent_states_by_date.c.max_last_updated, @@ -746,7 +705,22 @@ def _get_states_for_all_stmt( ) else: stmt += lambda q: q.join( - most_recent_states_by_date, + ( + most_recent_states_by_date := ( + select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated).label("max_last_updated"), + ) + .filter( + (States.last_updated >= run_start) + & (States.last_updated < utc_point_in_time) + ) + .group_by(States.entity_id) + .subquery() + ) + ), and_( States.entity_id == most_recent_states_by_date.c.max_entity_id, States.last_updated == most_recent_states_by_date.c.max_last_updated, diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index c90447f1c99..bd11744ab09 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -16,14 +16,13 @@ import re from statistics import mean from typing import TYPE_CHECKING, Any, Literal, cast -from sqlalchemy import and_, bindparam, func, lambda_stmt, select, text +from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text from sqlalchemy.engine import Engine from sqlalchemy.engine.row import Row from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal_column, true from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import Subquery import voluptuous as vol from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT @@ -650,27 +649,19 @@ def _compile_hourly_statistics_summary_mean_stmt( ) -def _compile_hourly_statistics_last_sum_stmt_subquery( - start_time_ts: float, end_time_ts: float -) -> Subquery: - """Generate the summary mean statement for hourly statistics.""" - return ( - select(*QUERY_STATISTICS_SUMMARY_SUM) - .filter(StatisticsShortTerm.start_ts >= start_time_ts) - .filter(StatisticsShortTerm.start_ts < end_time_ts) - .subquery() - ) - - def _compile_hourly_statistics_last_sum_stmt( start_time_ts: float, end_time_ts: float ) -> StatementLambdaElement: """Generate the summary mean statement for hourly statistics.""" - subquery = _compile_hourly_statistics_last_sum_stmt_subquery( - start_time_ts, end_time_ts - ) return lambda_stmt( - lambda: select(subquery) + lambda: select( + subquery := ( + select(*QUERY_STATISTICS_SUMMARY_SUM) + .filter(StatisticsShortTerm.start_ts >= start_time_ts) + .filter(StatisticsShortTerm.start_ts < end_time_ts) + .subquery() + ) + ) .filter(subquery.c.rownum == 1) .order_by(subquery.c.metadata_id) ) @@ -1267,7 +1258,8 @@ def _reduce_statistics_per_month( ) -def _statistics_during_period_stmt( +def _generate_statistics_during_period_stmt( + columns: Select, start_time: datetime, end_time: datetime | None, metadata_ids: list[int] | None, @@ -1279,21 +1271,6 @@ def _statistics_during_period_stmt( This prepares a lambda_stmt query, so we don't insert the parameters yet. """ start_time_ts = start_time.timestamp() - - columns = select(table.metadata_id, table.start_ts) - if "last_reset" in types: - columns = columns.add_columns(table.last_reset_ts) - if "max" in types: - columns = columns.add_columns(table.max) - if "mean" in types: - columns = columns.add_columns(table.mean) - if "min" in types: - columns = columns.add_columns(table.min) - if "state" in types: - columns = columns.add_columns(table.state) - if "sum" in types: - columns = columns.add_columns(table.sum) - stmt = lambda_stmt(lambda: columns.filter(table.start_ts >= start_time_ts)) if end_time is not None: end_time_ts = end_time.timestamp() @@ -1307,6 +1284,23 @@ def _statistics_during_period_stmt( return stmt +def _generate_max_mean_min_statistic_in_sub_period_stmt( + columns: Select, + start_time: datetime | None, + end_time: datetime | None, + table: type[StatisticsBase], + metadata_id: int, +) -> StatementLambdaElement: + stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id)) + if start_time is not None: + start_time_ts = start_time.timestamp() + stmt += lambda q: q.filter(table.start_ts >= start_time_ts) + if end_time is not None: + end_time_ts = end_time.timestamp() + stmt += lambda q: q.filter(table.start_ts < end_time_ts) + return stmt + + def _get_max_mean_min_statistic_in_sub_period( session: Session, result: dict[str, float], @@ -1332,13 +1326,9 @@ def _get_max_mean_min_statistic_in_sub_period( # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable columns = columns.add_columns(func.min(table.min)) - stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id)) - if start_time is not None: - start_time_ts = start_time.timestamp() - stmt += lambda q: q.filter(table.start_ts >= start_time_ts) - if end_time is not None: - end_time_ts = end_time.timestamp() - stmt += lambda q: q.filter(table.start_ts < end_time_ts) + stmt = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns, start_time, end_time, table, metadata_id + ) stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt)) if not stats: return @@ -1753,8 +1743,21 @@ def _statistics_during_period_with_session( table: type[Statistics | StatisticsShortTerm] = ( Statistics if period != "5minute" else StatisticsShortTerm ) - stmt = _statistics_during_period_stmt( - start_time, end_time, metadata_ids, table, types + columns = select(table.metadata_id, table.start_ts) # type: ignore[call-overload] + if "last_reset" in types: + columns = columns.add_columns(table.last_reset_ts) + if "max" in types: + columns = columns.add_columns(table.max) + if "mean" in types: + columns = columns.add_columns(table.mean) + if "min" in types: + columns = columns.add_columns(table.min) + if "state" in types: + columns = columns.add_columns(table.state) + if "sum" in types: + columns = columns.add_columns(table.sum) + stmt = _generate_statistics_during_period_stmt( + columns, start_time, end_time, metadata_ids, table, types ) stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) @@ -1919,28 +1922,24 @@ def get_last_short_term_statistics( ) -def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery: - """Generate the subquery to find the most recent statistic row.""" - return ( - select( - StatisticsShortTerm.metadata_id, - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(StatisticsShortTerm.start_ts).label("start_max"), - ) - .where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) - .group_by(StatisticsShortTerm.metadata_id) - ).subquery() - - def _latest_short_term_statistics_stmt( metadata_ids: list[int], ) -> StatementLambdaElement: """Create the statement for finding the latest short term stat rows.""" stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) - most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids) stmt += lambda s: s.join( - most_recent_statistic_row, + ( + most_recent_statistic_row := ( + select( + StatisticsShortTerm.metadata_id, + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(StatisticsShortTerm.start_ts).label("start_max"), + ) + .where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) + .group_by(StatisticsShortTerm.metadata_id) + ).subquery() + ), ( StatisticsShortTerm.metadata_id # pylint: disable=comparison-with-callable == most_recent_statistic_row.c.metadata_id @@ -1988,21 +1987,34 @@ def get_latest_short_term_statistics( ) -def _get_most_recent_statistics_subquery( - metadata_ids: set[int], table: type[StatisticsBase], start_time_ts: float -) -> Subquery: - """Generate the subquery to find the most recent statistic row.""" - return ( - select( - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(table.start_ts).label("max_start_ts"), - table.metadata_id.label("max_metadata_id"), +def _generate_statistics_at_time_stmt( + columns: Select, + table: type[StatisticsBase], + metadata_ids: set[int], + start_time_ts: float, +) -> StatementLambdaElement: + """Create the statement for finding the statistics for a given time.""" + return lambda_stmt( + lambda: columns.join( + ( + most_recent_statistic_ids := ( + select( + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(table.start_ts).label("max_start_ts"), + table.metadata_id.label("max_metadata_id"), + ) + .filter(table.start_ts < start_time_ts) + .filter(table.metadata_id.in_(metadata_ids)) + .group_by(table.metadata_id) + .subquery() + ) + ), + and_( + table.start_ts == most_recent_statistic_ids.c.max_start_ts, + table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, + ), ) - .filter(table.start_ts < start_time_ts) - .filter(table.metadata_id.in_(metadata_ids)) - .group_by(table.metadata_id) - .subquery() ) @@ -2027,19 +2039,10 @@ def _statistics_at_time( columns = columns.add_columns(table.state) if "sum" in types: columns = columns.add_columns(table.sum) - start_time_ts = start_time.timestamp() - most_recent_statistic_ids = _get_most_recent_statistics_subquery( - metadata_ids, table, start_time_ts + stmt = _generate_statistics_at_time_stmt( + columns, table, metadata_ids, start_time_ts ) - stmt = lambda_stmt(lambda: columns).join( - most_recent_statistic_ids, - and_( - table.start_ts == most_recent_statistic_ids.c.max_start_ts, - table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, - ), - ) - return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index dd51946c86f..e6ae291264f 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -8,7 +8,7 @@ import sys from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, select from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session @@ -22,6 +22,10 @@ from homeassistant.components.recorder.models import ( ) from homeassistant.components.recorder.statistics import ( STATISTIC_UNIT_TO_UNIT_CONVERTER, + _generate_get_metadata_stmt, + _generate_max_mean_min_statistic_in_sub_period_stmt, + _generate_statistics_at_time_stmt, + _generate_statistics_during_period_stmt, _statistics_during_period_with_session, _update_or_add_metadata, async_add_external_statistics, @@ -1799,3 +1803,100 @@ def record_states(hass): states[sns4].append(set_state(sns4, "20", attributes=sns4_attr)) return zero, four, states + + +def test_cache_key_for_generate_statistics_during_period_stmt(): + """Test cache key for _generate_statistics_during_period_stmt.""" + columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) + stmt = _generate_statistics_during_period_stmt( + columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {} + ) + cache_key_1 = stmt._generate_cache_key() + stmt2 = _generate_statistics_during_period_stmt( + columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {} + ) + cache_key_2 = stmt2._generate_cache_key() + assert cache_key_1 == cache_key_2 + columns2 = select( + StatisticsShortTerm.metadata_id, + StatisticsShortTerm.start_ts, + StatisticsShortTerm.sum, + StatisticsShortTerm.mean, + ) + stmt3 = _generate_statistics_during_period_stmt( + columns2, + dt_util.utcnow(), + dt_util.utcnow(), + [0], + StatisticsShortTerm, + {"max", "mean"}, + ) + cache_key_3 = stmt3._generate_cache_key() + assert cache_key_1 != cache_key_3 + + +def test_cache_key_for_generate_get_metadata_stmt(): + """Test cache key for _generate_get_metadata_stmt.""" + stmt_mean = _generate_get_metadata_stmt([0], "mean") + stmt_mean2 = _generate_get_metadata_stmt([1], "mean") + stmt_sum = _generate_get_metadata_stmt([0], "sum") + stmt_none = _generate_get_metadata_stmt() + assert stmt_mean._generate_cache_key() == stmt_mean2._generate_cache_key() + assert stmt_mean._generate_cache_key() != stmt_sum._generate_cache_key() + assert stmt_mean._generate_cache_key() != stmt_none._generate_cache_key() + + +def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt(): + """Test cache key for _generate_max_mean_min_statistic_in_sub_period_stmt.""" + columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) + stmt = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns, + dt_util.utcnow(), + dt_util.utcnow(), + StatisticsShortTerm, + [0], + ) + cache_key_1 = stmt._generate_cache_key() + stmt2 = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns, + dt_util.utcnow(), + dt_util.utcnow(), + StatisticsShortTerm, + [0], + ) + cache_key_2 = stmt2._generate_cache_key() + assert cache_key_1 == cache_key_2 + columns2 = select( + StatisticsShortTerm.metadata_id, + StatisticsShortTerm.start_ts, + StatisticsShortTerm.sum, + StatisticsShortTerm.mean, + ) + stmt3 = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns2, + dt_util.utcnow(), + dt_util.utcnow(), + StatisticsShortTerm, + [0], + ) + cache_key_3 = stmt3._generate_cache_key() + assert cache_key_1 != cache_key_3 + + +def test_cache_key_for_generate_statistics_at_time_stmt(): + """Test cache key for _generate_statistics_at_time_stmt.""" + columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) + stmt = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0) + cache_key_1 = stmt._generate_cache_key() + stmt2 = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0) + cache_key_2 = stmt2._generate_cache_key() + assert cache_key_1 == cache_key_2 + columns2 = select( + StatisticsShortTerm.metadata_id, + StatisticsShortTerm.start_ts, + StatisticsShortTerm.sum, + StatisticsShortTerm.mean, + ) + stmt3 = _generate_statistics_at_time_stmt(columns2, StatisticsShortTerm, {0}, 0.0) + cache_key_3 = stmt3._generate_cache_key() + assert cache_key_1 != cache_key_3