From 84292d4797367e1a153c90874386eefed67f035a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 1 Apr 2023 15:40:14 -1000 Subject: [PATCH] Cleanup some duplicate code in recorder statistics (#90549) * Cleanup some duplicate code in recorder statistics * more cleanup * reduce * reduce --- .../components/recorder/statistics.py | 106 +++++++++--------- tests/components/recorder/test_statistics.py | 26 ++--- 2 files changed, 60 insertions(+), 72 deletions(-) diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 0122ba4464b..70e82fad5d7 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -1034,18 +1034,19 @@ def _reduce_statistics_per_month( def _generate_statistics_during_period_stmt( - columns: Select, start_time: datetime, end_time: datetime | None, metadata_ids: list[int] | None, table: type[StatisticsBase], + types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> StatementLambdaElement: """Prepare a database query for statistics during a given period. This prepares a lambda_stmt query, so we don't insert the parameters yet. """ start_time_ts = start_time.timestamp() - stmt = lambda_stmt(lambda: columns.filter(table.start_ts >= start_time_ts)) + stmt = _generate_select_columns_for_types_stmt(table, types) + stmt += lambda q: q.filter(table.start_ts >= start_time_ts) if end_time is not None: end_time_ts = end_time.timestamp() stmt += lambda q: q.filter(table.start_ts < end_time_ts) @@ -1491,6 +1492,33 @@ def statistic_during_period( return {key: convert(value) if convert else value for key, value in result.items()} +_type_column_mapping = { + "last_reset": "last_reset_ts", + "max": "max", + "mean": "mean", + "min": "min", + "state": "state", + "sum": "sum", +} + + +def _generate_select_columns_for_types_stmt( + table: type[StatisticsBase], + types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], +) -> StatementLambdaElement: + columns = select(table.metadata_id, table.start_ts) + track_on: list[str | None] = [ + table.__tablename__, # type: ignore[attr-defined] + ] + for key, column in _type_column_mapping.items(): + if key in types: + columns = columns.add_columns(getattr(table, column)) + track_on.append(column) + else: + track_on.append(None) + return lambda_stmt(lambda: columns, track_on=track_on) + + def _statistics_during_period_with_session( hass: HomeAssistant, session: Session, @@ -1525,21 +1553,8 @@ def _statistics_during_period_with_session( table: type[Statistics | StatisticsShortTerm] = ( Statistics if period != "5minute" else StatisticsShortTerm ) - columns = select(table.metadata_id, table.start_ts) # type: ignore[call-overload] - if "last_reset" in types: - columns = columns.add_columns(table.last_reset_ts) - if "max" in types: - columns = columns.add_columns(table.max) - if "mean" in types: - columns = columns.add_columns(table.mean) - if "min" in types: - columns = columns.add_columns(table.min) - if "state" in types: - columns = columns.add_columns(table.state) - if "sum" in types: - columns = columns.add_columns(table.sum) stmt = _generate_statistics_during_period_stmt( - columns, start_time, end_time, metadata_ids, table + start_time, end_time, metadata_ids, table, types ) stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) @@ -1771,34 +1786,34 @@ def get_latest_short_term_statistics( def _generate_statistics_at_time_stmt( - columns: Select, table: type[StatisticsBase], metadata_ids: set[int], start_time_ts: float, + types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> StatementLambdaElement: """Create the statement for finding the statistics for a given time.""" - return lambda_stmt( - lambda: columns.join( - ( - most_recent_statistic_ids := ( - select( - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(table.start_ts).label("max_start_ts"), - table.metadata_id.label("max_metadata_id"), - ) - .filter(table.start_ts < start_time_ts) - .filter(table.metadata_id.in_(metadata_ids)) - .group_by(table.metadata_id) - .subquery() + stmt = _generate_select_columns_for_types_stmt(table, types) + stmt += lambda q: q.join( + ( + most_recent_statistic_ids := ( + select( + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(table.start_ts).label("max_start_ts"), + table.metadata_id.label("max_metadata_id"), ) - ), - and_( - table.start_ts == most_recent_statistic_ids.c.max_start_ts, - table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, - ), - ) + .filter(table.start_ts < start_time_ts) + .filter(table.metadata_id.in_(metadata_ids)) + .group_by(table.metadata_id) + .subquery() + ) + ), + and_( + table.start_ts == most_recent_statistic_ids.c.max_start_ts, + table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, + ), ) + return stmt def _statistics_at_time( @@ -1809,23 +1824,8 @@ def _statistics_at_time( types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> Sequence[Row] | None: """Return last known statistics, earlier than start_time, for the metadata_ids.""" - columns = select(table.metadata_id, table.start_ts) - if "last_reset" in types: - columns = columns.add_columns(table.last_reset_ts) - if "max" in types: - columns = columns.add_columns(table.max) - if "mean" in types: - columns = columns.add_columns(table.mean) - if "min" in types: - columns = columns.add_columns(table.min) - if "state" in types: - columns = columns.add_columns(table.state) - if "sum" in types: - columns = columns.add_columns(table.sum) start_time_ts = start_time.timestamp() - stmt = _generate_statistics_at_time_stmt( - columns, table, metadata_ids, start_time_ts - ) + stmt = _generate_statistics_at_time_stmt(table, metadata_ids, start_time_ts, types) return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index ff429794315..25890fe475b 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -1244,28 +1244,21 @@ def test_monthly_statistics( def test_cache_key_for_generate_statistics_during_period_stmt() -> None: """Test cache key for _generate_statistics_during_period_stmt.""" - columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) stmt = _generate_statistics_during_period_stmt( - columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm + dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, set() ) cache_key_1 = stmt._generate_cache_key() stmt2 = _generate_statistics_during_period_stmt( - columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm + dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, set() ) cache_key_2 = stmt2._generate_cache_key() assert cache_key_1 == cache_key_2 - columns2 = select( - StatisticsShortTerm.metadata_id, - StatisticsShortTerm.start_ts, - StatisticsShortTerm.sum, - StatisticsShortTerm.mean, - ) stmt3 = _generate_statistics_during_period_stmt( - columns2, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, + {"sum", "mean"}, ) cache_key_3 = stmt3._generate_cache_key() assert cache_key_1 != cache_key_3 @@ -1321,18 +1314,13 @@ def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt() -> N def test_cache_key_for_generate_statistics_at_time_stmt() -> None: """Test cache key for _generate_statistics_at_time_stmt.""" - columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) - stmt = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0) + stmt = _generate_statistics_at_time_stmt(StatisticsShortTerm, {0}, 0.0, set()) cache_key_1 = stmt._generate_cache_key() - stmt2 = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0) + stmt2 = _generate_statistics_at_time_stmt(StatisticsShortTerm, {0}, 0.0, set()) cache_key_2 = stmt2._generate_cache_key() assert cache_key_1 == cache_key_2 - columns2 = select( - StatisticsShortTerm.metadata_id, - StatisticsShortTerm.start_ts, - StatisticsShortTerm.sum, - StatisticsShortTerm.mean, + stmt3 = _generate_statistics_at_time_stmt( + StatisticsShortTerm, {0}, 0.0, {"sum", "mean"} ) - stmt3 = _generate_statistics_at_time_stmt(columns2, StatisticsShortTerm, {0}, 0.0) cache_key_3 = stmt3._generate_cache_key() assert cache_key_1 != cache_key_3