From f3c582815c21fd76ae1e85baa4d5d870d5cf2191 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 18 May 2022 09:22:21 -0500 Subject: [PATCH] Convert statistics to use lambda_stmt (#71903) * Convert stats to use lambda_stmt - Since baked queries are now [deprecated in 1.4](https://docs.sqlalchemy.org/en/14/orm/extensions/baked.html#module-sqlalchemy.ext.baked) the next step is to convert these to `lambda_stmt` https://docs.sqlalchemy.org/en/14/core/connections.html#quick-guidelines-for-lambdas * Update homeassistant/components/recorder/statistics.py Co-authored-by: Erik Montnemery --- .../components/recorder/statistics.py | 277 +++++++++--------- homeassistant/components/recorder/util.py | 5 +- 2 files changed, 143 insertions(+), 139 deletions(-) diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 732c042d9c2..77f56bc59fa 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -14,11 +14,12 @@ import re from statistics import mean from typing import TYPE_CHECKING, Any, Literal, overload -from sqlalchemy import bindparam, func +from sqlalchemy import bindparam, func, lambda_stmt, select +from sqlalchemy.engine.row import Row from sqlalchemy.exc import SQLAlchemyError, StatementError -from sqlalchemy.ext import baked from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal_column, true +from sqlalchemy.sql.lambdas import StatementLambdaElement import voluptuous as vol from homeassistant.const import ( @@ -50,7 +51,12 @@ from .models import ( process_timestamp, process_timestamp_to_utc_isoformat, ) -from .util import execute, retryable_database_job, session_scope +from .util import ( + execute, + execute_stmt_lambda_element, + retryable_database_job, + session_scope, +) if TYPE_CHECKING: from . import Recorder @@ -120,8 +126,6 @@ QUERY_STATISTIC_META_ID = [ StatisticsMeta.statistic_id, ] -STATISTICS_BAKERY = "recorder_statistics_bakery" - # Convert pressure, temperature and volume statistics from the normalized unit used for # statistics to the unit configured by the user @@ -203,7 +207,6 @@ class ValidationIssue: def async_setup(hass: HomeAssistant) -> None: """Set up the history hooks.""" - hass.data[STATISTICS_BAKERY] = baked.bakery() def _entity_id_changed(event: Event) -> None: """Handle entity_id changed.""" @@ -420,6 +423,36 @@ def delete_duplicates(hass: HomeAssistant, session: Session) -> None: ) +def _compile_hourly_statistics_summary_mean_stmt( + start_time: datetime, end_time: datetime +) -> StatementLambdaElement: + """Generate the summary mean statement for hourly statistics.""" + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)) + stmt += ( + lambda q: q.filter(StatisticsShortTerm.start >= start_time) + .filter(StatisticsShortTerm.start < end_time) + .group_by(StatisticsShortTerm.metadata_id) + .order_by(StatisticsShortTerm.metadata_id) + ) + return stmt + + +def _compile_hourly_statistics_summary_sum_legacy_stmt( + start_time: datetime, end_time: datetime +) -> StatementLambdaElement: + """Generate the legacy sum statement for hourly statistics. + + This is used for databases not supporting row number. + """ + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)) + stmt += ( + lambda q: q.filter(StatisticsShortTerm.start >= start_time) + .filter(StatisticsShortTerm.start < end_time) + .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()) + ) + return stmt + + def compile_hourly_statistics( instance: Recorder, session: Session, start: datetime ) -> None: @@ -434,20 +467,8 @@ def compile_hourly_statistics( # Compute last hour's average, min, max summary: dict[str, StatisticData] = {} - baked_query = instance.hass.data[STATISTICS_BAKERY]( - lambda session: session.query(*QUERY_STATISTICS_SUMMARY_MEAN) - ) - - baked_query += lambda q: q.filter( - StatisticsShortTerm.start >= bindparam("start_time") - ) - baked_query += lambda q: q.filter(StatisticsShortTerm.start < bindparam("end_time")) - baked_query += lambda q: q.group_by(StatisticsShortTerm.metadata_id) - baked_query += lambda q: q.order_by(StatisticsShortTerm.metadata_id) - - stats = execute( - baked_query(session).params(start_time=start_time, end_time=end_time) - ) + stmt = _compile_hourly_statistics_summary_mean_stmt(start_time, end_time) + stats = execute_stmt_lambda_element(session, stmt) if stats: for stat in stats: @@ -493,23 +514,8 @@ def compile_hourly_statistics( "sum": _sum, } else: - baked_query = instance.hass.data[STATISTICS_BAKERY]( - lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY) - ) - - baked_query += lambda q: q.filter( - StatisticsShortTerm.start >= bindparam("start_time") - ) - baked_query += lambda q: q.filter( - StatisticsShortTerm.start < bindparam("end_time") - ) - baked_query += lambda q: q.order_by( - StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc() - ) - - stats = execute( - baked_query(session).params(start_time=start_time, end_time=end_time) - ) + stmt = _compile_hourly_statistics_summary_sum_legacy_stmt(start_time, end_time) + stats = execute_stmt_lambda_element(session, stmt) if stats: for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]): # type: ignore[no-any-return] @@ -669,6 +675,24 @@ def _update_statistics( ) +def _generate_get_metadata_stmt( + statistic_ids: list[str] | tuple[str] | None = None, + statistic_type: Literal["mean"] | Literal["sum"] | None = None, + statistic_source: str | None = None, +) -> StatementLambdaElement: + """Generate a statement to fetch metadata.""" + stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META)) + if statistic_ids is not None: + stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids)) + if statistic_source is not None: + stmt += lambda q: q.where(StatisticsMeta.source == statistic_source) + if statistic_type == "mean": + stmt += lambda q: q.where(StatisticsMeta.has_mean == true()) + elif statistic_type == "sum": + stmt += lambda q: q.where(StatisticsMeta.has_sum == true()) + return stmt + + def get_metadata_with_session( hass: HomeAssistant, session: Session, @@ -686,26 +710,8 @@ def get_metadata_with_session( """ # Fetch metatadata from the database - baked_query = hass.data[STATISTICS_BAKERY]( - lambda session: session.query(*QUERY_STATISTIC_META) - ) - if statistic_ids is not None: - baked_query += lambda q: q.filter( - StatisticsMeta.statistic_id.in_(bindparam("statistic_ids")) - ) - if statistic_source is not None: - baked_query += lambda q: q.filter( - StatisticsMeta.source == bindparam("statistic_source") - ) - if statistic_type == "mean": - baked_query += lambda q: q.filter(StatisticsMeta.has_mean == true()) - elif statistic_type == "sum": - baked_query += lambda q: q.filter(StatisticsMeta.has_sum == true()) - result = execute( - baked_query(session).params( - statistic_ids=statistic_ids, statistic_source=statistic_source - ) - ) + stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source) + result = execute_stmt_lambda_element(session, stmt) if not result: return {} @@ -852,31 +858,6 @@ def list_statistic_ids( ] -def _statistics_during_period_query( - hass: HomeAssistant, - end_time: datetime | None, - statistic_ids: list[str] | None, - baked_query: baked.BakedQuery, - table: type[Statistics | StatisticsShortTerm], -) -> Callable: - """Prepare a database query for statistics during a given period. - - This prepares a baked query, so we don't insert the parameters yet. - """ - baked_query += lambda q: q.filter(table.start >= bindparam("start_time")) - - if end_time is not None: - baked_query += lambda q: q.filter(table.start < bindparam("end_time")) - - if statistic_ids is not None: - baked_query += lambda q: q.filter( - table.metadata_id.in_(bindparam("metadata_ids")) - ) - - baked_query += lambda q: q.order_by(table.metadata_id, table.start) - return baked_query # type: ignore[no-any-return] - - def _reduce_statistics( stats: dict[str, list[dict[str, Any]]], same_period: Callable[[datetime, datetime], bool], @@ -975,6 +956,34 @@ def _reduce_statistics_per_month( return _reduce_statistics(stats, same_month, month_start_end, timedelta(days=31)) +def _statistics_during_period_stmt( + start_time: datetime, + end_time: datetime | None, + statistic_ids: list[str] | None, + metadata_ids: list[int] | None, + table: type[Statistics | StatisticsShortTerm], +) -> StatementLambdaElement: + """Prepare a database query for statistics during a given period. + + This prepares a lambda_stmt query, so we don't insert the parameters yet. + """ + if table == StatisticsShortTerm: + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) + else: + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS)) + + stmt += lambda q: q.filter(table.start >= start_time) + + if end_time is not None: + stmt += lambda q: q.filter(table.start < end_time) + + if statistic_ids is not None: + stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids)) + + stmt += lambda q: q.order_by(table.metadata_id, table.start) + return stmt + + def statistics_during_period( hass: HomeAssistant, start_time: datetime, @@ -999,25 +1008,16 @@ def statistics_during_period( if statistic_ids is not None: metadata_ids = [metadata_id for metadata_id, _ in metadata.values()] - bakery = hass.data[STATISTICS_BAKERY] if period == "5minute": - baked_query = bakery( - lambda session: session.query(*QUERY_STATISTICS_SHORT_TERM) - ) table = StatisticsShortTerm else: - baked_query = bakery(lambda session: session.query(*QUERY_STATISTICS)) table = Statistics - baked_query = _statistics_during_period_query( - hass, end_time, statistic_ids, baked_query, table + stmt = _statistics_during_period_stmt( + start_time, end_time, statistic_ids, metadata_ids, table ) + stats = execute_stmt_lambda_element(session, stmt) - stats = execute( - baked_query(session).params( - start_time=start_time, end_time=end_time, metadata_ids=metadata_ids - ) - ) if not stats: return {} # Return statistics combined with metadata @@ -1044,6 +1044,24 @@ def statistics_during_period( return _reduce_statistics_per_month(result) +def _get_last_statistics_stmt( + metadata_id: int, + number_of_stats: int, + table: type[Statistics | StatisticsShortTerm], +) -> StatementLambdaElement: + """Generate a statement for number_of_stats statistics for a given statistic_id.""" + if table == StatisticsShortTerm: + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) + else: + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS)) + stmt += ( + lambda q: q.filter_by(metadata_id=metadata_id) + .order_by(table.metadata_id, table.start.desc()) + .limit(number_of_stats) + ) + return stmt + + def _get_last_statistics( hass: HomeAssistant, number_of_stats: int, @@ -1058,27 +1076,10 @@ def _get_last_statistics( metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids) if not metadata: return {} - - bakery = hass.data[STATISTICS_BAKERY] - if table == StatisticsShortTerm: - baked_query = bakery( - lambda session: session.query(*QUERY_STATISTICS_SHORT_TERM) - ) - else: - baked_query = bakery(lambda session: session.query(*QUERY_STATISTICS)) - - baked_query += lambda q: q.filter_by(metadata_id=bindparam("metadata_id")) metadata_id = metadata[statistic_id][0] + stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table) + stats = execute_stmt_lambda_element(session, stmt) - baked_query += lambda q: q.order_by(table.metadata_id, table.start.desc()) - - baked_query += lambda q: q.limit(bindparam("number_of_stats")) - - stats = execute( - baked_query(session).params( - number_of_stats=number_of_stats, metadata_id=metadata_id - ) - ) if not stats: return {} @@ -1113,14 +1114,36 @@ def get_last_short_term_statistics( ) +def _latest_short_term_statistics_stmt( + metadata_ids: list[int], +) -> StatementLambdaElement: + """Create the statement for finding the latest short term stat rows.""" + stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) + most_recent_statistic_row = ( + select( + StatisticsShortTerm.metadata_id, + func.max(StatisticsShortTerm.start).label("start_max"), + ) + .where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) + .group_by(StatisticsShortTerm.metadata_id) + ).subquery() + stmt += lambda s: s.join( + most_recent_statistic_row, + ( + StatisticsShortTerm.metadata_id # pylint: disable=comparison-with-callable + == most_recent_statistic_row.c.metadata_id + ) + & (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max), + ) + return stmt + + def get_latest_short_term_statistics( hass: HomeAssistant, statistic_ids: list[str], metadata: dict[str, tuple[int, StatisticMetaData]] | None = None, ) -> dict[str, list[dict]]: """Return the latest short term statistics for a list of statistic_ids.""" - # This function doesn't use a baked query, we instead rely on the - # "Transparent SQL Compilation Caching" feature introduced in SQLAlchemy 1.4 with session_scope(hass=hass) as session: # Fetch metadata for the given statistic_ids if not metadata: @@ -1134,24 +1157,8 @@ def get_latest_short_term_statistics( for statistic_id in statistic_ids if statistic_id in metadata ] - most_recent_statistic_row = ( - session.query( - StatisticsShortTerm.metadata_id, - func.max(StatisticsShortTerm.start).label("start_max"), - ) - .filter(StatisticsShortTerm.metadata_id.in_(metadata_ids)) - .group_by(StatisticsShortTerm.metadata_id) - ).subquery() - stats = execute( - session.query(*QUERY_STATISTICS_SHORT_TERM).join( - most_recent_statistic_row, - ( - StatisticsShortTerm.metadata_id # pylint: disable=comparison-with-callable - == most_recent_statistic_row.c.metadata_id - ) - & (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max), - ) - ) + stmt = _latest_short_term_statistics_stmt(metadata_ids) + stats = execute_stmt_lambda_element(session, stmt) if not stats: return {} @@ -1203,7 +1210,7 @@ def _statistics_at_time( def _sorted_statistics_to_dict( hass: HomeAssistant, session: Session, - stats: list, + stats: Iterable[Row], statistic_ids: list[str] | None, _metadata: dict[str, tuple[int, StatisticMetaData]], convert_units: bool, @@ -1215,7 +1222,7 @@ def _sorted_statistics_to_dict( result: dict = defaultdict(list) units = hass.config.units metadata = dict(_metadata.values()) - need_stat_at_start_time = set() + need_stat_at_start_time: set[int] = set() stats_at_start_time = {} def no_conversion(val: Any, _: Any) -> float | None: diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index e464ac4126b..f086119f7f9 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -20,7 +20,6 @@ from sqlalchemy import text from sqlalchemy.engine.cursor import CursorFetchStrategy from sqlalchemy.engine.row import Row from sqlalchemy.exc import OperationalError, SQLAlchemyError -from sqlalchemy.ext.baked import Result from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.lambdas import StatementLambdaElement @@ -123,9 +122,7 @@ def commit(session: Session, work: Any) -> bool: def execute( - qry: Query | Result, - to_native: bool = False, - validate_entity_ids: bool = True, + qry: Query, to_native: bool = False, validate_entity_ids: bool = True ) -> list[Row]: """Query the database and convert the objects to HA native form.