Mirror of https://github.com/home-assistant/core.git
Convert statistics to use lambda_stmt (#71903)
* Convert stats to use lambda_stmt

  Since baked queries are now [deprecated in 1.4](https://docs.sqlalchemy.org/en/14/orm/extensions/baked.html#module-sqlalchemy.ext.baked),
  the next step is to convert these to `lambda_stmt`:
  https://docs.sqlalchemy.org/en/14/core/connections.html#quick-guidelines-for-lambdas

* Update homeassistant/components/recorder/statistics.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent 26ee289be3
commit f3c582815c
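The guidelines linked in the commit message describe the pattern used throughout this diff: build the base statement once with `lambda_stmt`, then compose additional criteria onto it with `+=`, so SQLAlchemy can cache the compiled SQL while closure variables become bound parameters. Below is a minimal, self-contained sketch of that pattern, not Home Assistant code; the `StatRow` model and `build_stmt` helper are hypothetical stand-ins for the recorder's tables and the new `*_stmt` helpers.

```python
# Minimal sketch of the lambda_stmt pattern this commit adopts: not the
# Home Assistant code. StatRow and build_stmt are hypothetical stand-ins.
from __future__ import annotations

from datetime import datetime, timedelta

from sqlalchemy import Column, DateTime, Integer, create_engine, lambda_stmt, select
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.sql.lambdas import StatementLambdaElement

Base = declarative_base()


class StatRow(Base):
    """Hypothetical table standing in for StatisticsShortTerm."""

    __tablename__ = "stat_row"
    id = Column(Integer, primary_key=True)
    metadata_id = Column(Integer)
    start = Column(DateTime)


def build_stmt(
    start_time: datetime, end_time: datetime | None
) -> StatementLambdaElement:
    """Compose a cached statement incrementally, as the new *_stmt helpers do."""
    # The statement construction is wrapped in lambdas so SQLAlchemy can cache
    # the compiled SQL; closure variables such as start_time are extracted as
    # bound parameters on every call.
    stmt = lambda_stmt(lambda: select(StatRow.metadata_id, StatRow.start))
    stmt += lambda s: s.where(StatRow.start >= start_time)
    if end_time is not None:
        # Branching happens outside the lambdas; each += site is its own cache key.
        stmt += lambda s: s.where(StatRow.start < end_time)
    stmt += lambda s: s.order_by(StatRow.metadata_id, StatRow.start)
    return stmt


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    start = datetime(2022, 5, 1)
    rows = session.execute(build_stmt(start, start + timedelta(hours=1))).all()
```

Unlike the removed baked queries, there is no bakery object to stash in `hass.data`; the statement itself carries the cache key, which is why the diff below drops `STATISTICS_BAKERY` and `baked.bakery()`.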
homeassistant/components/recorder/statistics.py

@@ -14,11 +14,12 @@ import re
 from statistics import mean
 from typing import TYPE_CHECKING, Any, Literal, overload
 
-from sqlalchemy import bindparam, func
+from sqlalchemy import bindparam, func, lambda_stmt, select
 from sqlalchemy.engine.row import Row
 from sqlalchemy.exc import SQLAlchemyError, StatementError
 from sqlalchemy.ext import baked
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
+from sqlalchemy.sql.lambdas import StatementLambdaElement
 import voluptuous as vol
 
 from homeassistant.const import (
@@ -50,7 +51,12 @@ from .models import (
     process_timestamp,
     process_timestamp_to_utc_isoformat,
 )
-from .util import execute, retryable_database_job, session_scope
+from .util import (
+    execute,
+    execute_stmt_lambda_element,
+    retryable_database_job,
+    session_scope,
+)
 
 if TYPE_CHECKING:
     from . import Recorder
@@ -120,8 +126,6 @@ QUERY_STATISTIC_META_ID = [
     StatisticsMeta.statistic_id,
 ]
 
-STATISTICS_BAKERY = "recorder_statistics_bakery"
-
 
 # Convert pressure, temperature and volume statistics from the normalized unit used for
 # statistics to the unit configured by the user
@@ -203,7 +207,6 @@ class ValidationIssue:
 
 def async_setup(hass: HomeAssistant) -> None:
     """Set up the history hooks."""
-    hass.data[STATISTICS_BAKERY] = baked.bakery()
 
     def _entity_id_changed(event: Event) -> None:
         """Handle entity_id changed."""
@@ -420,6 +423,36 @@ def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
         )
 
 
+def _compile_hourly_statistics_summary_mean_stmt(
+    start_time: datetime, end_time: datetime
+) -> StatementLambdaElement:
+    """Generate the summary mean statement for hourly statistics."""
+    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN))
+    stmt += (
+        lambda q: q.filter(StatisticsShortTerm.start >= start_time)
+        .filter(StatisticsShortTerm.start < end_time)
+        .group_by(StatisticsShortTerm.metadata_id)
+        .order_by(StatisticsShortTerm.metadata_id)
+    )
+    return stmt
+
+
+def _compile_hourly_statistics_summary_sum_legacy_stmt(
+    start_time: datetime, end_time: datetime
+) -> StatementLambdaElement:
+    """Generate the legacy sum statement for hourly statistics.
+
+    This is used for databases not supporting row number.
+    """
+    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY))
+    stmt += (
+        lambda q: q.filter(StatisticsShortTerm.start >= start_time)
+        .filter(StatisticsShortTerm.start < end_time)
+        .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
+    )
+    return stmt
+
+
 def compile_hourly_statistics(
     instance: Recorder, session: Session, start: datetime
 ) -> None:
@@ -434,20 +467,8 @@ def compile_hourly_statistics(
 
     # Compute last hour's average, min, max
     summary: dict[str, StatisticData] = {}
-    baked_query = instance.hass.data[STATISTICS_BAKERY](
-        lambda session: session.query(*QUERY_STATISTICS_SUMMARY_MEAN)
-    )
-
-    baked_query += lambda q: q.filter(
-        StatisticsShortTerm.start >= bindparam("start_time")
-    )
-    baked_query += lambda q: q.filter(StatisticsShortTerm.start < bindparam("end_time"))
-    baked_query += lambda q: q.group_by(StatisticsShortTerm.metadata_id)
-    baked_query += lambda q: q.order_by(StatisticsShortTerm.metadata_id)
-
-    stats = execute(
-        baked_query(session).params(start_time=start_time, end_time=end_time)
-    )
+    stmt = _compile_hourly_statistics_summary_mean_stmt(start_time, end_time)
+    stats = execute_stmt_lambda_element(session, stmt)
 
     if stats:
         for stat in stats:
@@ -493,23 +514,8 @@ def compile_hourly_statistics(
                         "sum": _sum,
                     }
     else:
-        baked_query = instance.hass.data[STATISTICS_BAKERY](
-            lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
-        )
-
-        baked_query += lambda q: q.filter(
-            StatisticsShortTerm.start >= bindparam("start_time")
-        )
-        baked_query += lambda q: q.filter(
-            StatisticsShortTerm.start < bindparam("end_time")
-        )
-        baked_query += lambda q: q.order_by(
-            StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
-        )
-
-        stats = execute(
-            baked_query(session).params(start_time=start_time, end_time=end_time)
-        )
+        stmt = _compile_hourly_statistics_summary_sum_legacy_stmt(start_time, end_time)
+        stats = execute_stmt_lambda_element(session, stmt)
 
         if stats:
             for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]):  # type: ignore[no-any-return]
@@ -669,6 +675,24 @@ def _update_statistics(
         )
 
 
+def _generate_get_metadata_stmt(
+    statistic_ids: list[str] | tuple[str] | None = None,
+    statistic_type: Literal["mean"] | Literal["sum"] | None = None,
+    statistic_source: str | None = None,
+) -> StatementLambdaElement:
+    """Generate a statement to fetch metadata."""
+    stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
+    if statistic_ids is not None:
+        stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids))
+    if statistic_source is not None:
+        stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
+    if statistic_type == "mean":
+        stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
+    elif statistic_type == "sum":
+        stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
+    return stmt
+
+
 def get_metadata_with_session(
     hass: HomeAssistant,
     session: Session,
@@ -686,26 +710,8 @@ def get_metadata_with_session(
     """
 
     # Fetch metatadata from the database
-    baked_query = hass.data[STATISTICS_BAKERY](
-        lambda session: session.query(*QUERY_STATISTIC_META)
-    )
-    if statistic_ids is not None:
-        baked_query += lambda q: q.filter(
-            StatisticsMeta.statistic_id.in_(bindparam("statistic_ids"))
-        )
-    if statistic_source is not None:
-        baked_query += lambda q: q.filter(
-            StatisticsMeta.source == bindparam("statistic_source")
-        )
-    if statistic_type == "mean":
-        baked_query += lambda q: q.filter(StatisticsMeta.has_mean == true())
-    elif statistic_type == "sum":
-        baked_query += lambda q: q.filter(StatisticsMeta.has_sum == true())
-    result = execute(
-        baked_query(session).params(
-            statistic_ids=statistic_ids, statistic_source=statistic_source
-        )
-    )
+    stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source)
+    result = execute_stmt_lambda_element(session, stmt)
     if not result:
         return {}
 
@@ -852,31 +858,6 @@ def list_statistic_ids(
     ]
 
 
-def _statistics_during_period_query(
-    hass: HomeAssistant,
-    end_time: datetime | None,
-    statistic_ids: list[str] | None,
-    baked_query: baked.BakedQuery,
-    table: type[Statistics | StatisticsShortTerm],
-) -> Callable:
-    """Prepare a database query for statistics during a given period.
-
-    This prepares a baked query, so we don't insert the parameters yet.
-    """
-    baked_query += lambda q: q.filter(table.start >= bindparam("start_time"))
-
-    if end_time is not None:
-        baked_query += lambda q: q.filter(table.start < bindparam("end_time"))
-
-    if statistic_ids is not None:
-        baked_query += lambda q: q.filter(
-            table.metadata_id.in_(bindparam("metadata_ids"))
-        )
-
-    baked_query += lambda q: q.order_by(table.metadata_id, table.start)
-    return baked_query  # type: ignore[no-any-return]
-
-
 def _reduce_statistics(
     stats: dict[str, list[dict[str, Any]]],
     same_period: Callable[[datetime, datetime], bool],
@@ -975,6 +956,34 @@ def _reduce_statistics_per_month(
     return _reduce_statistics(stats, same_month, month_start_end, timedelta(days=31))
 
 
+def _statistics_during_period_stmt(
+    start_time: datetime,
+    end_time: datetime | None,
+    statistic_ids: list[str] | None,
+    metadata_ids: list[int] | None,
+    table: type[Statistics | StatisticsShortTerm],
+) -> StatementLambdaElement:
+    """Prepare a database query for statistics during a given period.
+
+    This prepares a lambda_stmt query, so we don't insert the parameters yet.
+    """
+    if table == StatisticsShortTerm:
+        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
+    else:
+        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
+
+    stmt += lambda q: q.filter(table.start >= start_time)
+
+    if end_time is not None:
+        stmt += lambda q: q.filter(table.start < end_time)
+
+    if statistic_ids is not None:
+        stmt += lambda q: q.filter(table.metadata_id.in_(metadata_ids))
+
+    stmt += lambda q: q.order_by(table.metadata_id, table.start)
+    return stmt
+
+
 def statistics_during_period(
     hass: HomeAssistant,
     start_time: datetime,
@@ -999,25 +1008,16 @@ def statistics_during_period(
         if statistic_ids is not None:
             metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]
 
-        bakery = hass.data[STATISTICS_BAKERY]
         if period == "5minute":
-            baked_query = bakery(
-                lambda session: session.query(*QUERY_STATISTICS_SHORT_TERM)
-            )
             table = StatisticsShortTerm
         else:
-            baked_query = bakery(lambda session: session.query(*QUERY_STATISTICS))
             table = Statistics
 
-        baked_query = _statistics_during_period_query(
-            hass, end_time, statistic_ids, baked_query, table
+        stmt = _statistics_during_period_stmt(
+            start_time, end_time, statistic_ids, metadata_ids, table
         )
+        stats = execute_stmt_lambda_element(session, stmt)
 
-        stats = execute(
-            baked_query(session).params(
-                start_time=start_time, end_time=end_time, metadata_ids=metadata_ids
-            )
-        )
         if not stats:
             return {}
         # Return statistics combined with metadata
@@ -1044,6 +1044,24 @@ def statistics_during_period(
         return _reduce_statistics_per_month(result)
 
 
+def _get_last_statistics_stmt(
+    metadata_id: int,
+    number_of_stats: int,
+    table: type[Statistics | StatisticsShortTerm],
+) -> StatementLambdaElement:
+    """Generate a statement for number_of_stats statistics for a given statistic_id."""
+    if table == StatisticsShortTerm:
+        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
+    else:
+        stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS))
+    stmt += (
+        lambda q: q.filter_by(metadata_id=metadata_id)
+        .order_by(table.metadata_id, table.start.desc())
+        .limit(number_of_stats)
+    )
+    return stmt
+
+
 def _get_last_statistics(
     hass: HomeAssistant,
     number_of_stats: int,
@@ -1058,27 +1076,10 @@ def _get_last_statistics(
         metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids)
         if not metadata:
            return {}
 
-        bakery = hass.data[STATISTICS_BAKERY]
-        if table == StatisticsShortTerm:
-            baked_query = bakery(
-                lambda session: session.query(*QUERY_STATISTICS_SHORT_TERM)
-            )
-        else:
-            baked_query = bakery(lambda session: session.query(*QUERY_STATISTICS))
-
-        baked_query += lambda q: q.filter_by(metadata_id=bindparam("metadata_id"))
         metadata_id = metadata[statistic_id][0]
+        stmt = _get_last_statistics_stmt(metadata_id, number_of_stats, table)
+        stats = execute_stmt_lambda_element(session, stmt)
 
-        baked_query += lambda q: q.order_by(table.metadata_id, table.start.desc())
-
-        baked_query += lambda q: q.limit(bindparam("number_of_stats"))
-
-        stats = execute(
-            baked_query(session).params(
-                number_of_stats=number_of_stats, metadata_id=metadata_id
-            )
-        )
         if not stats:
             return {}
 
@@ -1113,14 +1114,36 @@ def get_last_short_term_statistics(
     )
 
 
+def _latest_short_term_statistics_stmt(
+    metadata_ids: list[int],
+) -> StatementLambdaElement:
+    """Create the statement for finding the latest short term stat rows."""
+    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
+    most_recent_statistic_row = (
+        select(
+            StatisticsShortTerm.metadata_id,
+            func.max(StatisticsShortTerm.start).label("start_max"),
+        )
+        .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
+        .group_by(StatisticsShortTerm.metadata_id)
+    ).subquery()
+    stmt += lambda s: s.join(
+        most_recent_statistic_row,
+        (
+            StatisticsShortTerm.metadata_id  # pylint: disable=comparison-with-callable
+            == most_recent_statistic_row.c.metadata_id
+        )
+        & (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max),
+    )
+    return stmt
+
+
 def get_latest_short_term_statistics(
     hass: HomeAssistant,
     statistic_ids: list[str],
     metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
 ) -> dict[str, list[dict]]:
     """Return the latest short term statistics for a list of statistic_ids."""
-    # This function doesn't use a baked query, we instead rely on the
-    # "Transparent SQL Compilation Caching" feature introduced in SQLAlchemy 1.4
     with session_scope(hass=hass) as session:
         # Fetch metadata for the given statistic_ids
         if not metadata:
@@ -1134,24 +1157,8 @@ def get_latest_short_term_statistics(
             for statistic_id in statistic_ids
             if statistic_id in metadata
         ]
-        most_recent_statistic_row = (
-            session.query(
-                StatisticsShortTerm.metadata_id,
-                func.max(StatisticsShortTerm.start).label("start_max"),
-            )
-            .filter(StatisticsShortTerm.metadata_id.in_(metadata_ids))
-            .group_by(StatisticsShortTerm.metadata_id)
-        ).subquery()
-        stats = execute(
-            session.query(*QUERY_STATISTICS_SHORT_TERM).join(
-                most_recent_statistic_row,
-                (
-                    StatisticsShortTerm.metadata_id  # pylint: disable=comparison-with-callable
-                    == most_recent_statistic_row.c.metadata_id
-                )
-                & (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max),
-            )
-        )
+        stmt = _latest_short_term_statistics_stmt(metadata_ids)
+        stats = execute_stmt_lambda_element(session, stmt)
         if not stats:
             return {}
 
@@ -1203,7 +1210,7 @@ def _statistics_at_time(
 def _sorted_statistics_to_dict(
     hass: HomeAssistant,
     session: Session,
-    stats: list,
+    stats: Iterable[Row],
     statistic_ids: list[str] | None,
     _metadata: dict[str, tuple[int, StatisticMetaData]],
     convert_units: bool,
@@ -1215,7 +1222,7 @@ def _sorted_statistics_to_dict(
     result: dict = defaultdict(list)
     units = hass.config.units
     metadata = dict(_metadata.values())
-    need_stat_at_start_time = set()
+    need_stat_at_start_time: set[int] = set()
     stats_at_start_time = {}
 
     def no_conversion(val: Any, _: Any) -> float | None:
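All of the statistics.py call sites above now hand their `StatementLambdaElement` to `execute_stmt_lambda_element`, whose implementation is not part of this diff. In plain SQLAlchemy such a statement executes like any other `Select`; continuing the hypothetical sketch from the top of the page, that looks roughly like the snippet below, with the recorder helper presumably adding the retry and error handling that `execute()` provided for the baked queries.

```python
# Assumed usage only, reusing the hypothetical engine, Session and build_stmt
# from the earlier sketch: a StatementLambdaElement is passed to
# Session.execute() exactly like a regular Select.
with Session(engine) as session:
    stmt = build_stmt(datetime(2022, 5, 1), None)
    rows = session.execute(stmt).all()
```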
homeassistant/components/recorder/util.py

@@ -20,7 +20,6 @@ from sqlalchemy import text
 from sqlalchemy.engine.cursor import CursorFetchStrategy
 from sqlalchemy.engine.row import Row
 from sqlalchemy.exc import OperationalError, SQLAlchemyError
-from sqlalchemy.ext.baked import Result
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.lambdas import StatementLambdaElement
@@ -123,9 +122,7 @@ def commit(session: Session, work: Any) -> bool:
 
 
 def execute(
-    qry: Query | Result,
-    to_native: bool = False,
-    validate_entity_ids: bool = True,
+    qry: Query, to_native: bool = False, validate_entity_ids: bool = True
 ) -> list[Row]:
     """Query the database and convert the objects to HA native form.
 