Convert history queries to use lambda_stmt (#71870)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
J. Nick Koston 2022-05-15 10:47:29 -05:00 committed by GitHub
parent 8ea5ec6f08
commit 98809675ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 313 additions and 262 deletions

View File

@ -20,7 +20,7 @@ from homeassistant.helpers.integration_platform import (
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from . import history, statistics, websocket_api from . import statistics, websocket_api
from .const import ( from .const import (
CONF_DB_INTEGRITY_CHECK, CONF_DB_INTEGRITY_CHECK,
DATA_INSTANCE, DATA_INSTANCE,
@ -166,7 +166,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
instance.async_register() instance.async_register()
instance.start() instance.start()
async_register_services(hass, instance) async_register_services(hass, instance)
history.async_setup(hass)
statistics.async_setup(hass) statistics.async_setup(hass)
websocket_api.async_setup(hass) websocket_api.async_setup(hass)
await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform) await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform)

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from sqlalchemy import not_, or_ from sqlalchemy import not_, or_
from sqlalchemy.ext.baked import BakedQuery
from sqlalchemy.sql.elements import ClauseList from sqlalchemy.sql.elements import ClauseList
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
@ -60,16 +59,6 @@ class Filters:
or self.included_entity_globs or self.included_entity_globs
) )
def bake(self, baked_query: BakedQuery) -> BakedQuery:
"""Update a baked query.
Works the same as apply on a baked_query.
"""
if not self.has_config:
return
baked_query += lambda q: q.filter(self.entity_filter())
def entity_filter(self) -> ClauseList: def entity_filter(self) -> ClauseList:
"""Generate the entity filter query.""" """Generate the entity filter query."""
includes = [] includes = []

View File

@ -9,13 +9,12 @@ import logging
import time import time
from typing import Any, cast from typing import Any, cast
from sqlalchemy import Column, Text, and_, bindparam, func, or_ from sqlalchemy import Column, Text, and_, func, lambda_stmt, or_, select
from sqlalchemy.engine.row import Row from sqlalchemy.engine.row import Row
from sqlalchemy.ext import baked
from sqlalchemy.ext.baked import BakedQuery
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.lambdas import StatementLambdaElement
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.websocket_api.const import ( from homeassistant.components.websocket_api.const import (
@ -36,7 +35,7 @@ from .models import (
process_timestamp_to_utc_isoformat, process_timestamp_to_utc_isoformat,
row_to_compressed_state, row_to_compressed_state,
) )
from .util import execute, session_scope from .util import execute_stmt_lambda_element, session_scope
# mypy: allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
@ -111,52 +110,48 @@ QUERY_STATES_NO_LAST_CHANGED = [
StateAttributes.shared_attrs, StateAttributes.shared_attrs,
] ]
HISTORY_BAKERY = "recorder_history_bakery"
def _schema_version(hass: HomeAssistant) -> int:
    """Return the schema version of the running recorder instance."""
    return recorder.get_instance(hass).schema_version
def bake_query_and_join_attributes( def lambda_stmt_and_join_attributes(
hass: HomeAssistant, no_attributes: bool, include_last_changed: bool = True schema_version: int, no_attributes: bool, include_last_changed: bool = True
) -> tuple[Any, bool]: ) -> tuple[StatementLambdaElement, bool]:
"""Return the initial backed query and if StateAttributes should be joined. """Return the lambda_stmt and if StateAttributes should be joined.
Because these are baked queries the values inside the lambdas need Because these are lambda_stmt the values inside the lambdas need
to be explicitly written out to avoid caching the wrong values. to be explicitly written out to avoid caching the wrong values.
""" """
bakery: baked.bakery = hass.data[HISTORY_BAKERY]
# If no_attributes was requested we do the query # If no_attributes was requested we do the query
# without the attributes fields and do not join the # without the attributes fields and do not join the
# state_attributes table # state_attributes table
if no_attributes: if no_attributes:
if include_last_changed: if include_last_changed:
return bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR)), False return lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR)), False
return ( return (
bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)), lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
False, False,
) )
# If we in the process of migrating schema we do # If we in the process of migrating schema we do
# not want to join the state_attributes table as we # not want to join the state_attributes table as we
# do not know if it will be there yet # do not know if it will be there yet
if recorder.get_instance(hass).schema_version < 25: if schema_version < 25:
if include_last_changed: if include_last_changed:
return ( return (
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25)), lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25)),
False, False,
) )
return ( return (
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)), lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
False, False,
) )
# Finally if no migration is in progress and no_attributes # Finally if no migration is in progress and no_attributes
# was not requested, we query both attributes columns and # was not requested, we query both attributes columns and
# join state_attributes # join state_attributes
if include_last_changed: if include_last_changed:
return bakery(lambda s: s.query(*QUERY_STATES)), True return lambda_stmt(lambda: select(*QUERY_STATES)), True
return bakery(lambda s: s.query(*QUERY_STATES_NO_LAST_CHANGED)), True return lambda_stmt(lambda: select(*QUERY_STATES_NO_LAST_CHANGED)), True
def async_setup(hass: HomeAssistant) -> None:
"""Set up the history hooks."""
hass.data[HISTORY_BAKERY] = baked.bakery()
def get_significant_states( def get_significant_states(
@ -200,38 +195,30 @@ def _ignore_domains_filter(query: Query) -> Query:
) )
def _query_significant_states_with_session( def _significant_states_stmt(
hass: HomeAssistant, schema_version: int,
session: Session,
start_time: datetime, start_time: datetime,
end_time: datetime | None = None, end_time: datetime | None,
entity_ids: list[str] | None = None, entity_ids: list[str] | None,
filters: Filters | None = None, filters: Filters | None,
significant_changes_only: bool = True, significant_changes_only: bool,
no_attributes: bool = False, no_attributes: bool,
) -> list[Row]: ) -> StatementLambdaElement:
"""Query the database for significant state changes.""" """Query the database for significant state changes."""
if _LOGGER.isEnabledFor(logging.DEBUG): stmt, join_attributes = lambda_stmt_and_join_attributes(
timer_start = time.perf_counter() schema_version, no_attributes, include_last_changed=not significant_changes_only
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=True
) )
if (
if entity_ids is not None and len(entity_ids) == 1: entity_ids
if ( and len(entity_ids) == 1
significant_changes_only and significant_changes_only
and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
): ):
baked_query, join_attributes = bake_query_and_join_attributes( stmt += lambda q: q.filter(
hass, no_attributes, include_last_changed=False (States.last_changed == States.last_updated) | States.last_changed.is_(None)
) )
baked_query += lambda q: q.filter(
(States.last_changed == States.last_updated)
| States.last_changed.is_(None)
)
elif significant_changes_only: elif significant_changes_only:
baked_query += lambda q: q.filter( stmt += lambda q: q.filter(
or_( or_(
*[ *[
States.entity_id.like(entity_domain) States.entity_id.like(entity_domain)
@ -244,36 +231,24 @@ def _query_significant_states_with_session(
) )
) )
if entity_ids is not None: if entity_ids:
baked_query += lambda q: q.filter( stmt += lambda q: q.filter(States.entity_id.in_(entity_ids))
States.entity_id.in_(bindparam("entity_ids", expanding=True))
)
else: else:
baked_query += _ignore_domains_filter stmt += _ignore_domains_filter
if filters: if filters and filters.has_config:
filters.bake(baked_query) entity_filter = filters.entity_filter()
stmt += lambda q: q.filter(entity_filter)
baked_query += lambda q: q.filter(States.last_updated > bindparam("start_time")) stmt += lambda q: q.filter(States.last_updated > start_time)
if end_time is not None: if end_time:
baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time")) stmt += lambda q: q.filter(States.last_updated < end_time)
if join_attributes: if join_attributes:
baked_query += lambda q: q.outerjoin( stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id StateAttributes, States.attributes_id == StateAttributes.attributes_id
) )
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated) stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
return stmt
states = execute(
baked_query(session).params(
start_time=start_time, end_time=end_time, entity_ids=entity_ids
)
)
if _LOGGER.isEnabledFor(logging.DEBUG):
elapsed = time.perf_counter() - timer_start
_LOGGER.debug("get_significant_states took %fs", elapsed)
return states
def get_significant_states_with_session( def get_significant_states_with_session(
@ -301,9 +276,8 @@ def get_significant_states_with_session(
as well as all states from certain domains (for instance as well as all states from certain domains (for instance
thermostat so that we get current temperature in our graphs). thermostat so that we get current temperature in our graphs).
""" """
states = _query_significant_states_with_session( stmt = _significant_states_stmt(
hass, _schema_version(hass),
session,
start_time, start_time,
end_time, end_time,
entity_ids, entity_ids,
@ -311,6 +285,9 @@ def get_significant_states_with_session(
significant_changes_only, significant_changes_only,
no_attributes, no_attributes,
) )
states = execute_stmt_lambda_element(
session, stmt, None if entity_ids else start_time, end_time
)
return _sorted_states_to_dict( return _sorted_states_to_dict(
hass, hass,
session, session,
@ -354,6 +331,38 @@ def get_full_significant_states_with_session(
) )
def _state_changed_during_period_stmt(
    schema_version: int,
    start_time: datetime,
    end_time: datetime | None,
    entity_id: str | None,
    no_attributes: bool,
    descending: bool,
    limit: int | None,
) -> StatementLambdaElement:
    """Build the lambda statement for state changes during a time period.

    The statement is assembled incrementally with ``stmt +=`` lambdas so
    SQLAlchemy can cache the compiled SQL; captured variables (start_time,
    end_time, entity_id, limit) become bound parameters.
    """
    stmt, join_attributes = lambda_stmt_and_join_attributes(
        schema_version, no_attributes, include_last_changed=False
    )
    # Keep only rows that represent an actual state change (last_changed
    # equals last_updated, or is NULL) inside the requested window.
    stmt += lambda q: q.filter(
        ((States.last_changed == States.last_updated) | States.last_changed.is_(None))
        & (States.last_updated > start_time)
    )
    if end_time:
        stmt += lambda q: q.filter(States.last_updated < end_time)
    stmt += lambda q: q.filter(States.entity_id == entity_id)
    if join_attributes:
        # Schema >= 25 stores shared attributes in a separate table.
        stmt += lambda q: q.outerjoin(
            StateAttributes, States.attributes_id == StateAttributes.attributes_id
        )
    if descending:
        stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc())
    else:
        stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
    if limit:
        stmt += lambda q: q.limit(limit)
    return stmt
def state_changes_during_period( def state_changes_during_period(
hass: HomeAssistant, hass: HomeAssistant,
start_time: datetime, start_time: datetime,
@ -365,52 +374,21 @@ def state_changes_during_period(
include_start_time_state: bool = True, include_start_time_state: bool = True,
) -> MutableMapping[str, list[State]]: ) -> MutableMapping[str, list[State]]:
"""Return states changes during UTC period start_time - end_time.""" """Return states changes during UTC period start_time - end_time."""
entity_id = entity_id.lower() if entity_id is not None else None
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
baked_query, join_attributes = bake_query_and_join_attributes( stmt = _state_changed_during_period_stmt(
hass, no_attributes, include_last_changed=False _schema_version(hass),
start_time,
end_time,
entity_id,
no_attributes,
descending,
limit,
) )
states = execute_stmt_lambda_element(
baked_query += lambda q: q.filter( session, stmt, None if entity_id else start_time, end_time
(
(States.last_changed == States.last_updated)
| States.last_changed.is_(None)
)
& (States.last_updated > bindparam("start_time"))
) )
if end_time is not None:
baked_query += lambda q: q.filter(
States.last_updated < bindparam("end_time")
)
if entity_id is not None:
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
entity_id = entity_id.lower()
if join_attributes:
baked_query += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
if descending:
baked_query += lambda q: q.order_by(
States.entity_id, States.last_updated.desc()
)
else:
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
if limit:
baked_query += lambda q: q.limit(bindparam("limit"))
states = execute(
baked_query(session).params(
start_time=start_time,
end_time=end_time,
entity_id=entity_id,
limit=limit,
)
)
entity_ids = [entity_id] if entity_id is not None else None entity_ids = [entity_id] if entity_id is not None else None
return cast( return cast(
@ -426,41 +404,37 @@ def state_changes_during_period(
) )
def _get_last_state_changes_stmt(
    schema_version: int, number_of_states: int, entity_id: str
) -> StatementLambdaElement:
    """Build the lambda statement for the most recent states of one entity.

    Attributes are always selected (``no_attributes=False``) and, when the
    schema supports it, the state_attributes table is outer-joined.
    """
    stmt, join_attributes = lambda_stmt_and_join_attributes(
        schema_version, False, include_last_changed=False
    )
    # Restrict to real state changes (last_changed == last_updated, or NULL)
    # for the requested entity.
    stmt += lambda q: q.filter(
        (States.last_changed == States.last_updated) | States.last_changed.is_(None)
    ).filter(States.entity_id == entity_id)
    if join_attributes:
        stmt += lambda q: q.outerjoin(
            StateAttributes, States.attributes_id == StateAttributes.attributes_id
        )
    # Newest first, capped at the requested number of rows.
    stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()).limit(
        number_of_states
    )
    return stmt
def get_last_state_changes( def get_last_state_changes(
hass: HomeAssistant, number_of_states: int, entity_id: str hass: HomeAssistant, number_of_states: int, entity_id: str
) -> MutableMapping[str, list[State]]: ) -> MutableMapping[str, list[State]]:
"""Return the last number_of_states.""" """Return the last number_of_states."""
start_time = dt_util.utcnow() start_time = dt_util.utcnow()
entity_id = entity_id.lower() if entity_id is not None else None
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
baked_query, join_attributes = bake_query_and_join_attributes( stmt = _get_last_state_changes_stmt(
hass, False, include_last_changed=False _schema_version(hass), number_of_states, entity_id
) )
states = list(execute_stmt_lambda_element(session, stmt))
baked_query += lambda q: q.filter(
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
)
if entity_id is not None:
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
entity_id = entity_id.lower()
if join_attributes:
baked_query += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
baked_query += lambda q: q.order_by(
States.entity_id, States.last_updated.desc()
)
baked_query += lambda q: q.limit(bindparam("number_of_states"))
states = execute(
baked_query(session).params(
number_of_states=number_of_states, entity_id=entity_id
)
)
entity_ids = [entity_id] if entity_id is not None else None entity_ids = [entity_id] if entity_id is not None else None
return cast( return cast(
@ -476,96 +450,91 @@ def get_last_state_changes(
) )
def _most_recent_state_ids_entities_subquery(query: Query) -> Query: def _get_states_for_entites_stmt(
"""Query to find the most recent state id for specific entities.""" schema_version: int,
run_start: datetime,
utc_point_in_time: datetime,
entity_ids: list[str],
no_attributes: bool,
) -> StatementLambdaElement:
"""Baked query to get states for specific entities."""
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=True
)
# We got an include-list of entities, accelerate the query by filtering already # We got an include-list of entities, accelerate the query by filtering already
# in the inner query. # in the inner query.
most_recent_state_ids = ( stmt += lambda q: q.where(
query.session.query(func.max(States.state_id).label("max_state_id")) States.state_id
.filter( == (
(States.last_updated >= bindparam("run_start")) select(func.max(States.state_id).label("max_state_id"))
& (States.last_updated < bindparam("utc_point_in_time")) .filter(
) (States.last_updated >= run_start)
.filter(States.entity_id.in_(bindparam("entity_ids", expanding=True))) & (States.last_updated < utc_point_in_time)
.group_by(States.entity_id) )
.subquery() .filter(States.entity_id.in_(entity_ids))
.group_by(States.entity_id)
.subquery()
).c.max_state_id
) )
return query.join(
most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id,
)
def _get_states_baked_query_for_entites(
hass: HomeAssistant,
no_attributes: bool = False,
) -> BakedQuery:
"""Baked query to get states for specific entities."""
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=True
)
baked_query += _most_recent_state_ids_entities_subquery
if join_attributes: if join_attributes:
baked_query += lambda q: q.outerjoin( stmt += lambda q: q.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id) StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
) )
return baked_query return stmt
def _most_recent_state_ids_subquery(query: Query) -> Query: def _get_states_for_all_stmt(
"""Find the most recent state ids for all entiites.""" schema_version: int,
run_start: datetime,
utc_point_in_time: datetime,
filters: Filters | None,
no_attributes: bool,
) -> StatementLambdaElement:
"""Baked query to get states for all entities."""
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=True
)
# We did not get an include-list of entities, query all states in the inner # We did not get an include-list of entities, query all states in the inner
# query, then filter out unwanted domains as well as applying the custom filter. # query, then filter out unwanted domains as well as applying the custom filter.
# This filtering can't be done in the inner query because the domain column is # This filtering can't be done in the inner query because the domain column is
# not indexed and we can't control what's in the custom filter. # not indexed and we can't control what's in the custom filter.
most_recent_states_by_date = ( most_recent_states_by_date = (
query.session.query( select(
States.entity_id.label("max_entity_id"), States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"), func.max(States.last_updated).label("max_last_updated"),
) )
.filter( .filter(
(States.last_updated >= bindparam("run_start")) (States.last_updated >= run_start)
& (States.last_updated < bindparam("utc_point_in_time")) & (States.last_updated < utc_point_in_time)
) )
.group_by(States.entity_id) .group_by(States.entity_id)
.subquery() .subquery()
) )
most_recent_state_ids = ( stmt += lambda q: q.where(
query.session.query(func.max(States.state_id).label("max_state_id")) States.state_id
.join( == (
most_recent_states_by_date, select(func.max(States.state_id).label("max_state_id"))
and_( .join(
States.entity_id == most_recent_states_by_date.c.max_entity_id, most_recent_states_by_date,
States.last_updated == most_recent_states_by_date.c.max_last_updated, and_(
), States.entity_id == most_recent_states_by_date.c.max_entity_id,
) States.last_updated
.group_by(States.entity_id) == most_recent_states_by_date.c.max_last_updated,
.subquery() ),
)
.group_by(States.entity_id)
.subquery()
).c.max_state_id,
) )
return query.join( stmt += _ignore_domains_filter
most_recent_state_ids, if filters and filters.has_config:
States.state_id == most_recent_state_ids.c.max_state_id, entity_filter = filters.entity_filter()
) stmt += lambda q: q.filter(entity_filter)
def _get_states_baked_query_for_all(
hass: HomeAssistant,
filters: Filters | None = None,
no_attributes: bool = False,
) -> BakedQuery:
"""Baked query to get states for all entities."""
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=True
)
baked_query += _most_recent_state_ids_subquery
baked_query += _ignore_domains_filter
if filters:
filters.bake(baked_query)
if join_attributes: if join_attributes:
baked_query += lambda q: q.outerjoin( stmt += lambda q: q.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id) StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
) )
return baked_query return stmt
def _get_rows_with_session( def _get_rows_with_session(
@ -576,11 +545,15 @@ def _get_rows_with_session(
run: RecorderRuns | None = None, run: RecorderRuns | None = None,
filters: Filters | None = None, filters: Filters | None = None,
no_attributes: bool = False, no_attributes: bool = False,
) -> list[Row]: ) -> Iterable[Row]:
"""Return the states at a specific point in time.""" """Return the states at a specific point in time."""
schema_version = _schema_version(hass)
if entity_ids and len(entity_ids) == 1: if entity_ids and len(entity_ids) == 1:
return _get_single_entity_states_with_session( return execute_stmt_lambda_element(
hass, session, utc_point_in_time, entity_ids[0], no_attributes session,
_get_single_entity_states_stmt(
schema_version, utc_point_in_time, entity_ids[0], no_attributes
),
) )
if run is None: if run is None:
@ -593,46 +566,41 @@ def _get_rows_with_session(
# We have more than one entity to look at so we need to do a query on states # We have more than one entity to look at so we need to do a query on states
# since the last recorder run started. # since the last recorder run started.
if entity_ids: if entity_ids:
baked_query = _get_states_baked_query_for_entites(hass, no_attributes) stmt = _get_states_for_entites_stmt(
else: schema_version, run.start, utc_point_in_time, entity_ids, no_attributes
baked_query = _get_states_baked_query_for_all(hass, filters, no_attributes)
return execute(
baked_query(session).params(
run_start=run.start,
utc_point_in_time=utc_point_in_time,
entity_ids=entity_ids,
) )
) else:
stmt = _get_states_for_all_stmt(
schema_version, run.start, utc_point_in_time, filters, no_attributes
)
return execute_stmt_lambda_element(session, stmt)
def _get_single_entity_states_with_session( def _get_single_entity_states_stmt(
hass: HomeAssistant, schema_version: int,
session: Session,
utc_point_in_time: datetime, utc_point_in_time: datetime,
entity_id: str, entity_id: str,
no_attributes: bool = False, no_attributes: bool = False,
) -> list[Row]: ) -> StatementLambdaElement:
# Use an entirely different (and extremely fast) query if we only # Use an entirely different (and extremely fast) query if we only
# have a single entity id # have a single entity id
baked_query, join_attributes = bake_query_and_join_attributes( stmt, join_attributes = lambda_stmt_and_join_attributes(
hass, no_attributes, include_last_changed=True schema_version, no_attributes, include_last_changed=True
) )
baked_query += lambda q: q.filter( stmt += (
States.last_updated < bindparam("utc_point_in_time"), lambda q: q.filter(
States.entity_id == bindparam("entity_id"), States.last_updated < utc_point_in_time,
States.entity_id == entity_id,
)
.order_by(States.last_updated.desc())
.limit(1)
) )
if join_attributes: if join_attributes:
baked_query += lambda q: q.outerjoin( stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id StateAttributes, States.attributes_id == StateAttributes.attributes_id
) )
baked_query += lambda q: q.order_by(States.last_updated.desc()).limit(1) return stmt
query = baked_query(session).params(
utc_point_in_time=utc_point_in_time, entity_id=entity_id
)
return execute(query)
def _sorted_states_to_dict( def _sorted_states_to_dict(

View File

@ -1,7 +1,7 @@
"""SQLAlchemy util functions.""" """SQLAlchemy util functions."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Generator from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager from contextlib import contextmanager
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
import functools import functools
@ -18,9 +18,12 @@ from awesomeversion import (
import ciso8601 import ciso8601
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.engine.cursor import CursorFetchStrategy from sqlalchemy.engine.cursor import CursorFetchStrategy
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.baked import Result
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql.lambdas import StatementLambdaElement
from typing_extensions import Concatenate, ParamSpec from typing_extensions import Concatenate, ParamSpec
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -46,6 +49,7 @@ _LOGGER = logging.getLogger(__name__)
RETRIES = 3 RETRIES = 3
QUERY_RETRY_WAIT = 0.1 QUERY_RETRY_WAIT = 0.1
SQLITE3_POSTFIXES = ["", "-wal", "-shm"] SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
DEFAULT_YIELD_STATES_ROWS = 32768
MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER) MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER) MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
@ -119,8 +123,10 @@ def commit(session: Session, work: Any) -> bool:
def execute( def execute(
qry: Query, to_native: bool = False, validate_entity_ids: bool = True qry: Query | Result,
) -> list: to_native: bool = False,
validate_entity_ids: bool = True,
) -> list[Row]:
"""Query the database and convert the objects to HA native form. """Query the database and convert the objects to HA native form.
This method also retries a few times in the case of stale connections. This method also retries a few times in the case of stale connections.
@ -163,7 +169,39 @@ def execute(
raise raise
time.sleep(QUERY_RETRY_WAIT) time.sleep(QUERY_RETRY_WAIT)
assert False # unreachable assert False # unreachable # pragma: no cover
def execute_stmt_lambda_element(
    session: Session,
    stmt: StatementLambdaElement,
    start_time: datetime | None = None,
    end_time: datetime | None = None,
    yield_per: int | None = DEFAULT_YIELD_STATES_ROWS,
) -> Iterable[Row]:
    """Execute a StatementLambdaElement.

    If the time window passed is greater than one day
    the execution method will switch to yield_per to
    reduce memory pressure.

    It is not recommended to pass a time window
    when selecting non-ranged rows (ie selecting
    specific entities) since they are usually faster
    with .all().
    """
    executed = session.execute(stmt)
    # Materialize with .all() unless the requested window spans more than
    # one day; a missing end_time is treated as "now".
    use_all = not start_time or ((end_time or dt_util.utcnow()) - start_time).days <= 1
    # NOTE(review): the statement is executed once above; retries only re-read
    # from the same Result object rather than re-running the query — confirm
    # this is the intended retry semantics.
    for tryno in range(0, RETRIES):
        try:
            return executed.all() if use_all else executed.yield_per(yield_per)  # type: ignore[no-any-return]
        except SQLAlchemyError as err:
            _LOGGER.error("Error executing query: %s", err)
            if tryno == RETRIES - 1:
                raise
            time.sleep(QUERY_RETRY_WAIT)
    assert False  # unreachable  # pragma: no cover
def validate_or_move_away_sqlite_database(dburl: str) -> bool: def validate_or_move_away_sqlite_database(dburl: str) -> bool:

View File

@ -6,10 +6,13 @@ from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.engine.result import ChunkedIteratorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.lambdas import StatementLambdaElement
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import util from homeassistant.components.recorder import history, util
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from homeassistant.components.recorder.models import RecorderRuns from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.components.recorder.util import ( from homeassistant.components.recorder.util import (
@ -24,6 +27,7 @@ from homeassistant.util import dt as dt_util
from .common import corrupt_db_file, run_information_with_session from .common import corrupt_db_file, run_information_with_session
from tests.common import SetupRecorderInstanceT, async_test_home_assistant from tests.common import SetupRecorderInstanceT, async_test_home_assistant
from tests.components.recorder.common import wait_recording_done
def test_session_scope_not_setup(hass_recorder): def test_session_scope_not_setup(hass_recorder):
@ -510,8 +514,10 @@ def test_basic_sanity_check(hass_recorder):
def test_combined_checks(hass_recorder, caplog): def test_combined_checks(hass_recorder, caplog):
"""Run Checks on the open database.""" """Run Checks on the open database."""
hass = hass_recorder() hass = hass_recorder()
instance = recorder.get_instance(hass)
instance.db_retry_wait = 0
cursor = hass.data[DATA_INSTANCE].engine.raw_connection().cursor() cursor = instance.engine.raw_connection().cursor()
assert util.run_checks_on_open_db("fake_db_path", cursor) is None assert util.run_checks_on_open_db("fake_db_path", cursor) is None
assert "could not validate that the sqlite3 database" in caplog.text assert "could not validate that the sqlite3 database" in caplog.text
@ -658,3 +664,54 @@ def test_build_mysqldb_conv():
assert conv["DATETIME"]("2022-05-13T22:33:12.741") == datetime( assert conv["DATETIME"]("2022-05-13T22:33:12.741") == datetime(
2022, 5, 13, 22, 33, 12, 741000, tzinfo=None 2022, 5, 13, 22, 33, 12, 741000, tzinfo=None
) )
@patch("homeassistant.components.recorder.util.QUERY_RETRY_WAIT", 0)
def test_execute_stmt_lambda_element(hass_recorder):
    """Test executing with execute_stmt_lambda_element."""
    hass = hass_recorder()
    instance = recorder.get_instance(hass)
    hass.states.set("sensor.on", "on")
    new_state = hass.states.get("sensor.on")
    wait_recording_done(hass)
    now = dt_util.utcnow()
    tomorrow = now + timedelta(days=1)
    one_week_from_now = now + timedelta(days=7)

    class MockExecutor:
        """Stand-in for session.execute that fails once, then succeeds.

        Used to exercise the retry path in execute_stmt_lambda_element.
        """

        def __init__(self, stmt):
            assert isinstance(stmt, StatementLambdaElement)
            self.calls = 0

        def all(self):
            # First call raises to trigger a retry; second call succeeds.
            self.calls += 1
            if self.calls == 2:
                return ["mock_row"]
            raise SQLAlchemyError

    with session_scope(hass=hass) as session:
        # No time window, we always get a list
        stmt = history._get_single_entity_states_stmt(
            instance.schema_version, dt_util.utcnow(), "sensor.on", False
        )
        rows = util.execute_stmt_lambda_element(session, stmt)
        assert isinstance(rows, list)
        assert rows[0].state == new_state.state
        assert rows[0].entity_id == new_state.entity_id

        # Time window >= 2 days, we get a ChunkedIteratorResult
        rows = util.execute_stmt_lambda_element(session, stmt, now, one_week_from_now)
        assert isinstance(rows, ChunkedIteratorResult)
        row = next(rows)
        assert row.state == new_state.state
        assert row.entity_id == new_state.entity_id

        # Time window < 2 days, we get a list
        rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
        assert isinstance(rows, list)
        assert rows[0].state == new_state.state
        assert rows[0].entity_id == new_state.entity_id

        # Patched executor: first .all() raises, retry returns the mock row.
        with patch.object(session, "execute", MockExecutor):
            rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
            assert rows == ["mock_row"]