Convert statistics to use lambda_stmt (#71903)

* Convert stats to use lambda_stmt

- Since baked queries are now [deprecated in 1.4](https://docs.sqlalchemy.org/en/14/orm/extensions/baked.html#module-sqlalchemy.ext.baked), the
  next step is to convert them to `lambda_stmt` (a general sketch of the conversion pattern follows the commit metadata below)

https://docs.sqlalchemy.org/en/14/core/connections.html#quick-guidelines-for-lambdas

* Update homeassistant/components/recorder/statistics.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Author: J. Nick Koston, 2022-05-18 09:22:21 -05:00 (committed via GitHub)
Commit: f3c582815c
Parent: 26ee289be3
2 changed files with 143 additions and 139 deletions
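
Before the diff, a minimal, self-contained sketch of the baked-query-to-`lambda_stmt` conversion pattern, following the SQLAlchemy guidelines linked above. The table and helper names below are illustrative stand-ins, not the recorder's actual models:

```python
from __future__ import annotations

from datetime import datetime

from sqlalchemy import (
    Column,
    DateTime,
    Float,
    Integer,
    MetaData,
    Table,
    lambda_stmt,
    select,
)
from sqlalchemy.sql.lambdas import StatementLambdaElement

metadata_obj = MetaData()

# Stand-in table; the real code targets the recorder's Statistics models.
stats = Table(
    "stats",
    metadata_obj,
    Column("id", Integer, primary_key=True),
    Column("start", DateTime),
    Column("value", Float),
)


def build_stats_stmt(start_time: datetime, end_time: datetime) -> StatementLambdaElement:
    """Build a cached SELECT for rows with start in [start_time, end_time)."""
    # The outer lambda is analyzed once and cached by its code location,
    # which is roughly what the baked-query bakery used to provide.
    stmt = lambda_stmt(lambda: select(stats.c.id, stats.c.value))
    # Closure variables (start_time, end_time) become bound parameters.
    stmt += lambda s: s.where(stats.c.start >= start_time)
    stmt += lambda s: s.where(stats.c.start < end_time)
    stmt += lambda s: s.order_by(stats.c.id)
    return stmt


# Usage: rows = session.execute(build_stats_stmt(t0, t1)).all()
```

The point the diff relies on throughout is that closure variables referenced inside the lambdas are extracted as bound parameters automatically, which is why the explicit `bindparam()` plumbing disappears.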

homeassistant/components/recorder/statistics.py

@@ -14,11 +14,12 @@ import re
from statistics import mean
from typing import TYPE_CHECKING, Any, Literal, overload
from sqlalchemy import bindparam, func
from sqlalchemy import bindparam, func, lambda_stmt, select
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.ext import baked
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal_column, true
from sqlalchemy.sql.lambdas import StatementLambdaElement
import voluptuous as vol
from homeassistant.const import (
@@ -50,7 +51,12 @@ from .models import (
process_timestamp,
process_timestamp_to_utc_isoformat,
)
from .util import execute, retryable_database_job, session_scope
from .util import (
execute,
execute_stmt_lambda_element,
retryable_database_job,
session_scope,
)
if TYPE_CHECKING:
from . import Recorder
@@ -120,8 +126,6 @@ QUERY_STATISTIC_META_ID = [
StatisticsMeta.statistic_id,
]
STATISTICS_BAKERY = "recorder_statistics_bakery"
# Convert pressure, temperature and volume statistics from the normalized unit used for
# statistics to the unit configured by the user
@@ -203,7 +207,6 @@ class ValidationIssue:
def async_setup(hass: HomeAssistant) -> None:
"""Set up the history hooks."""
hass.data[STATISTICS_BAKERY] = baked.bakery()
def _entity_id_changed(event: Event) -> None:
"""Handle entity_id changed."""
@@ -420,6 +423,36 @@ def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
)
def _compile_hourly_statistics_summary_mean_stmt(
start_time: datetime, end_time: datetime
) -> StatementLambdaElement:
"""Generate the summary mean statement for hourly statistics."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN))
stmt += (
lambda q: q.filter(StatisticsShortTerm.start >= start_time)
.filter(StatisticsShortTerm.start < end_time)
.group_by(StatisticsShortTerm.metadata_id)
.order_by(StatisticsShortTerm.metadata_id)
)
return stmt
def _compile_hourly_statistics_summary_sum_legacy_stmt(
start_time: datetime, end_time: datetime
) -> StatementLambdaElement:
"""Generate the legacy sum statement for hourly statistics.
This is used for databases not supporting row number.
"""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY))
stmt += (
lambda q: q.filter(StatisticsShortTerm.start >= start_time)
.filter(StatisticsShortTerm.start < end_time)
.order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
)
return stmt
def compile_hourly_statistics(
instance: Recorder, session: Session, start: datetime
) -> None:
@@ -434,20 +467,8 @@ def compile_hourly_statistics(
# Compute last hour's average, min, max
summary: dict[str, StatisticData] = {}
baked_query = instance.hass.data[STATISTICS_BAKERY](
lambda session: session.query(*QUERY_STATISTICS_SUMMARY_MEAN)
)
baked_query += lambda q: q.filter(
StatisticsShortTerm.start >= bindparam("start_time")
)
baked_query += lambda q: q.filter(StatisticsShortTerm.start < bindparam("end_time"))
baked_query += lambda q: q.group_by(StatisticsShortTerm.metadata_id)
baked_query += lambda q: q.order_by(StatisticsShortTerm.metadata_id)
stats = execute(
baked_query(session).params(start_time=start_time, end_time=end_time)
)
stmt = _compile_hourly_statistics_summary_mean_stmt(start_time, end_time)
stats = execute_stmt_lambda_element(session, stmt)
if stats:
for stat in stats:
@@ -493,23 +514,8 @@ def compile_hourly_statistics(
"sum": _sum,
}
else:
baked_query = instance.hass.data[STATISTICS_BAKERY](
lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
)
baked_query += lambda q: q.filter(
StatisticsShortTerm.start >= bindparam("start_time")
)
baked_query += lambda q: q.filter(
StatisticsShortTerm.start < bindparam("end_time")
)
baked_query += lambda q: q.order_by(
StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
)
stats = execute(
baked_query(session).params(start_time=start_time, end_time=end_time)
)
stmt = _compile_hourly_statistics_summary_sum_legacy_stmt(start_time, end_time)
stats = execute_stmt_lambda_element(session, stmt)
if stats:
for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]): # type: ignore[no-any-return]
@@ -669,6 +675,24 @@ def _update_statistics(
)
def _generate_get_metadata_stmt(
statistic_ids: list[str] | tuple[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> StatementLambdaElement:
"""Generate a statement to fetch metadata."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
if statistic_ids is not None:
stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids))
if statistic_source is not None:
stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
if statistic_type == "mean":
stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
elif statistic_type == "sum":
stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
return stmt
def get_metadata_with_session(
hass: HomeAssistant,
session: Session,
@@ -686,26 +710,8 @@ def get_metadata_with_session(
"""
# Fetch metadata from the database
baked_query = hass.data[STATISTICS_BAKERY](
lambda session: session.query(*QUERY_STATISTIC_META)
)
if statistic_ids is not None:
baked_query += lambda q: q.filter(
StatisticsMeta.statistic_id.in_(bindparam("statistic_ids"))
)
if statistic_source is not None:
baked_query += lambda q: q.filter(
StatisticsMeta.source == bindparam("statistic_source")
)
if statistic_type == "mean":
baked_query += lambda q: q.filter(StatisticsMeta.has_mean == true())
elif statistic_type == "sum":
baked_query += lambda q: q.filter(StatisticsMeta.has_sum == true())
result = execute(
baked_query(session).params(
statistic_ids=statistic_ids, statistic_source=statistic_source
)
)
stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source)
result = execute_stmt_lambda_element(session, stmt)
if not result:
return {}
@@ -852,31 +858,6 @@ def list_statistic_ids(
]
def _statistics_during_period_query(
hass: HomeAssistant,
end_time: datetime | None,
statistic_ids: list[str] | None,
baked_query: baked.BakedQuery,
table: type[Statistics | StatisticsShortTerm],
) -> Callable:
"""Prepare a database query for statistics during a given period.
This prepares a baked query, so we don't insert the parameters yet.
"""
baked_query += lambda q: q.filter(table.start >= bindparam("start_time"))
if end_time is not None:
baked_query += lambda q: q.filter(table.start < bindparam("end_time"))
if statistic_ids is not None:
baked_query += lambda q: q.filter(
table.metadata_id.in_(bindparam("metadata_ids"))
)
baked_query += lambda q: q.order_by(table.metadata_id, table.start)
return baked_query # type: ignore[no-any-return]
def _reduce_statistics(
stats: dict[str, list[dict[str, Any]]],
same_period: Callable[[datetime, datetime], bool],
@@ -975,6 +956,34 @@ def _reduce_statistics_per_month(
return _reduce_statistics(stats, same_month, month_start_end, timedelta(days=31))
def _statistics_during_period_stmt(
start_time: datetime,
end_time: datetime | None,
statistic_ids: list[str] | None,
metadata_ids: list[int] | None,
table: type[Statistics | StatisticsShortTerm],
) -> StatementLambdaElement:
"""Prepare a database query for statistics during a given period.
This prepares a lambda_stmt query, so we don't insert the parameters yet.
"""
if table == StatisticsShortTerm:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
else:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
stmt += lambda q: q.filter(table.start >= start_time)
if end_time is not None:
stmt += lambda q: q.filter(table.start < end_time)
if statistic_ids is not None:
stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids))
stmt += lambda q: q.order_by(table.metadata_id, table.start)
return stmt
def statistics_during_period(
hass: HomeAssistant,
start_time: datetime,
@@ -999,25 +1008,16 @@ def statistics_during_period(
if statistic_ids is not None:
metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]
bakery = hass.data[STATISTICS_BAKERY]
if period == "5minute":
baked_query = bakery(
lambda session: session.query(*QUERY_STATISTICS_SHORT_TERM)
)
table = StatisticsShortTerm
else:
baked_query = bakery(lambda session: session.query(*QUERY_STATISTICS))
table = Statistics
baked_query = _statistics_during_period_query(
hass, end_time, statistic_ids, baked_query, table
stmt = _statistics_during_period_stmt(
start_time, end_time, statistic_ids, metadata_ids, table
)
stats = execute_stmt_lambda_element(session, stmt)
stats = execute(
baked_query(session).params(
start_time=start_time, end_time=end_time, metadata_ids=metadata_ids
)
)
if not stats:
return {}
# Return statistics combined with metadata
@@ -1044,6 +1044,24 @@ def statistics_during_period(
return _reduce_statistics_per_month(result)
def _get_last_statistics_stmt(
metadata_id: int,
number_of_stats: int,
table: type[Statistics | StatisticsShortTerm],
) -> StatementLambdaElement:
"""Generate a statement for number_of_stats statistics for a given statistic_id."""
if table == StatisticsShortTerm:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
else:
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
stmt += (
lambda q: q.filter_by(metadata_id=metadata_id)
.order_by(table.metadata_id, table.start.desc())
.limit(number_of_stats)
)
return stmt
def _get_last_statistics(
hass: HomeAssistant,
number_of_stats: int,
@@ -1058,27 +1076,10 @@ def _get_last_statistics(
metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids)
if not metadata:
return {}
bakery = hass.data[STATISTICS_BAKERY]
if table == StatisticsShortTerm:
baked_query = bakery(
lambda session: session.query(*QUERY_STATISTICS_SHORT_TERM)
)
else:
baked_query = bakery(lambda session: session.query(*QUERY_STATISTICS))
baked_query += lambda q: q.filter_by(metadata_id=bindparam("metadata_id"))
metadata_id = metadata[statistic_id][0]
stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table)
stats = execute_stmt_lambda_element(session, stmt)
baked_query += lambda q: q.order_by(table.metadata_id, table.start.desc())
baked_query += lambda q: q.limit(bindparam("number_of_stats"))
stats = execute(
baked_query(session).params(
number_of_stats=number_of_stats, metadata_id=metadata_id
)
)
if not stats:
return {}
@@ -1113,14 +1114,36 @@ def get_last_short_term_statistics(
)
def _latest_short_term_statistics_stmt(
metadata_ids: list[int],
) -> StatementLambdaElement:
"""Create the statement for finding the latest short term stat rows."""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
most_recent_statistic_row = (
select(
StatisticsShortTerm.metadata_id,
func.max(StatisticsShortTerm.start).label("start_max"),
)
.where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
.group_by(StatisticsShortTerm.metadata_id)
).subquery()
stmt += lambda s: s.join(
most_recent_statistic_row,
(
StatisticsShortTerm.metadata_id # pylint: disable=comparison-with-callable
== most_recent_statistic_row.c.metadata_id
)
& (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max),
)
return stmt
def get_latest_short_term_statistics(
hass: HomeAssistant,
statistic_ids: list[str],
metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
) -> dict[str, list[dict]]:
"""Return the latest short term statistics for a list of statistic_ids."""
# This function doesn't use a baked query, we instead rely on the
# "Transparent SQL Compilation Caching" feature introduced in SQLAlchemy 1.4
with session_scope(hass=hass) as session:
# Fetch metadata for the given statistic_ids
if not metadata:
@@ -1134,24 +1157,8 @@ def get_latest_short_term_statistics(
for statistic_id in statistic_ids
if statistic_id in metadata
]
most_recent_statistic_row = (
session.query(
StatisticsShortTerm.metadata_id,
func.max(StatisticsShortTerm.start).label("start_max"),
)
.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids))
.group_by(StatisticsShortTerm.metadata_id)
).subquery()
stats = execute(
session.query(*QUERY_STATISTICS_SHORT_TERM).join(
most_recent_statistic_row,
(
StatisticsShortTerm.metadata_id # pylint: disable=comparison-with-callable
== most_recent_statistic_row.c.metadata_id
)
& (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max),
)
)
stmt = _latest_short_term_statistics_stmt(metadata_ids)
stats = execute_stmt_lambda_element(session, stmt)
if not stats:
return {}
@@ -1203,7 +1210,7 @@ def _statistics_at_time(
def _sorted_statistics_to_dict(
hass: HomeAssistant,
session: Session,
stats: list,
stats: Iterable[Row],
statistic_ids: list[str] | None,
_metadata: dict[str, tuple[int, StatisticMetaData]],
convert_units: bool,
@@ -1215,7 +1222,7 @@ def _sorted_statistics_to_dict(
result: dict = defaultdict(list)
units = hass.config.units
metadata = dict(_metadata.values())
need_stat_at_start_time = set()
need_stat_at_start_time: set[int] = set()
stats_at_start_time = {}
def no_conversion(val: Any, _: Any) -> float | None:

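A note on the helpers above (`_generate_get_metadata_stmt`, `_statistics_during_period_stmt`, `_get_last_statistics_stmt`): criteria are appended conditionally with `stmt += lambda q: ...`, and the branching stays outside the lambda bodies because each lambda's source location contributes to the statement cache key, per the guidelines linked in the commit message. A minimal sketch of that pattern against a stand-in model (names are illustrative, not the recorder's):

```python
from __future__ import annotations

from sqlalchemy import Boolean, Column, Integer, String, lambda_stmt, select, true
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.lambdas import StatementLambdaElement

Base = declarative_base()


class Meta(Base):
    """Stand-in for StatisticsMeta; columns are illustrative."""

    __tablename__ = "meta"
    id = Column(Integer, primary_key=True)
    statistic_id = Column(String)
    has_mean = Column(Boolean)


def build_metadata_stmt(
    statistic_ids: list[str] | None = None,
    statistic_type: str | None = None,
) -> StatementLambdaElement:
    """Compose WHERE criteria conditionally, branching outside the lambdas."""
    stmt = lambda_stmt(lambda: select(Meta.id, Meta.statistic_id))
    if statistic_ids is not None:
        # The list becomes an expanding IN parameter; no bindparam() needed.
        stmt += lambda q: q.where(Meta.statistic_id.in_(statistic_ids))
    if statistic_type == "mean":
        stmt += lambda q: q.where(Meta.has_mean == true())
    return stmt
```
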
homeassistant/components/recorder/util.py

@@ -20,7 +20,6 @@ from sqlalchemy import text
from sqlalchemy.engine.cursor import CursorFetchStrategy
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.baked import Result
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.lambdas import StatementLambdaElement
@@ -123,9 +122,7 @@ def commit(session: Session, work: Any) -> bool:
def execute(
qry: Query | Result,
to_native: bool = False,
validate_entity_ids: bool = True,
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
) -> list[Row]:
"""Query the database and convert the objects to HA native form.