mirror of
https://github.com/home-assistant/core.git
synced 2025-04-29 03:37:51 +00:00
Convert history queries to use lambda_stmt (#71870)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
8ea5ec6f08
commit
98809675ff
@ -20,7 +20,7 @@ from homeassistant.helpers.integration_platform import (
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from . import history, statistics, websocket_api
|
||||
from . import statistics, websocket_api
|
||||
from .const import (
|
||||
CONF_DB_INTEGRITY_CHECK,
|
||||
DATA_INSTANCE,
|
||||
@ -166,7 +166,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
instance.async_register()
|
||||
instance.start()
|
||||
async_register_services(hass, instance)
|
||||
history.async_setup(hass)
|
||||
statistics.async_setup(hass)
|
||||
websocket_api.async_setup(hass)
|
||||
await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform)
|
||||
|
@ -2,7 +2,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import not_, or_
|
||||
from sqlalchemy.ext.baked import BakedQuery
|
||||
from sqlalchemy.sql.elements import ClauseList
|
||||
|
||||
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
|
||||
@ -60,16 +59,6 @@ class Filters:
|
||||
or self.included_entity_globs
|
||||
)
|
||||
|
||||
def bake(self, baked_query: BakedQuery) -> BakedQuery:
|
||||
"""Update a baked query.
|
||||
|
||||
Works the same as apply on a baked_query.
|
||||
"""
|
||||
if not self.has_config:
|
||||
return
|
||||
|
||||
baked_query += lambda q: q.filter(self.entity_filter())
|
||||
|
||||
def entity_filter(self) -> ClauseList:
|
||||
"""Generate the entity filter query."""
|
||||
includes = []
|
||||
|
@ -9,13 +9,12 @@ import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import Column, Text, and_, bindparam, func, or_
|
||||
from sqlalchemy import Column, Text, and_, func, lambda_stmt, or_, select
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.ext import baked
|
||||
from sqlalchemy.ext.baked import BakedQuery
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql.expression import literal
|
||||
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||
|
||||
from homeassistant.components import recorder
|
||||
from homeassistant.components.websocket_api.const import (
|
||||
@ -36,7 +35,7 @@ from .models import (
|
||||
process_timestamp_to_utc_isoformat,
|
||||
row_to_compressed_state,
|
||||
)
|
||||
from .util import execute, session_scope
|
||||
from .util import execute_stmt_lambda_element, session_scope
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
@ -111,52 +110,48 @@ QUERY_STATES_NO_LAST_CHANGED = [
|
||||
StateAttributes.shared_attrs,
|
||||
]
|
||||
|
||||
HISTORY_BAKERY = "recorder_history_bakery"
|
||||
|
||||
def _schema_version(hass: HomeAssistant) -> int:
|
||||
return recorder.get_instance(hass).schema_version
|
||||
|
||||
|
||||
def bake_query_and_join_attributes(
|
||||
hass: HomeAssistant, no_attributes: bool, include_last_changed: bool = True
|
||||
) -> tuple[Any, bool]:
|
||||
"""Return the initial backed query and if StateAttributes should be joined.
|
||||
def lambda_stmt_and_join_attributes(
|
||||
schema_version: int, no_attributes: bool, include_last_changed: bool = True
|
||||
) -> tuple[StatementLambdaElement, bool]:
|
||||
"""Return the lambda_stmt and if StateAttributes should be joined.
|
||||
|
||||
Because these are baked queries the values inside the lambdas need
|
||||
Because these are lambda_stmt the values inside the lambdas need
|
||||
to be explicitly written out to avoid caching the wrong values.
|
||||
"""
|
||||
bakery: baked.bakery = hass.data[HISTORY_BAKERY]
|
||||
# If no_attributes was requested we do the query
|
||||
# without the attributes fields and do not join the
|
||||
# state_attributes table
|
||||
if no_attributes:
|
||||
if include_last_changed:
|
||||
return bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR)), False
|
||||
return lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR)), False
|
||||
return (
|
||||
bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
|
||||
lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
|
||||
False,
|
||||
)
|
||||
# If we in the process of migrating schema we do
|
||||
# not want to join the state_attributes table as we
|
||||
# do not know if it will be there yet
|
||||
if recorder.get_instance(hass).schema_version < 25:
|
||||
if schema_version < 25:
|
||||
if include_last_changed:
|
||||
return (
|
||||
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25)),
|
||||
lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25)),
|
||||
False,
|
||||
)
|
||||
return (
|
||||
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
|
||||
lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
|
||||
False,
|
||||
)
|
||||
# Finally if no migration is in progress and no_attributes
|
||||
# was not requested, we query both attributes columns and
|
||||
# join state_attributes
|
||||
if include_last_changed:
|
||||
return bakery(lambda s: s.query(*QUERY_STATES)), True
|
||||
return bakery(lambda s: s.query(*QUERY_STATES_NO_LAST_CHANGED)), True
|
||||
|
||||
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the history hooks."""
|
||||
hass.data[HISTORY_BAKERY] = baked.bakery()
|
||||
return lambda_stmt(lambda: select(*QUERY_STATES)), True
|
||||
return lambda_stmt(lambda: select(*QUERY_STATES_NO_LAST_CHANGED)), True
|
||||
|
||||
|
||||
def get_significant_states(
|
||||
@ -200,38 +195,30 @@ def _ignore_domains_filter(query: Query) -> Query:
|
||||
)
|
||||
|
||||
|
||||
def _query_significant_states_with_session(
|
||||
hass: HomeAssistant,
|
||||
session: Session,
|
||||
def _significant_states_stmt(
|
||||
schema_version: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime | None = None,
|
||||
entity_ids: list[str] | None = None,
|
||||
filters: Filters | None = None,
|
||||
significant_changes_only: bool = True,
|
||||
no_attributes: bool = False,
|
||||
) -> list[Row]:
|
||||
end_time: datetime | None,
|
||||
entity_ids: list[str] | None,
|
||||
filters: Filters | None,
|
||||
significant_changes_only: bool,
|
||||
no_attributes: bool,
|
||||
) -> StatementLambdaElement:
|
||||
"""Query the database for significant state changes."""
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
timer_start = time.perf_counter()
|
||||
|
||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
||||
hass, no_attributes, include_last_changed=True
|
||||
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||
schema_version, no_attributes, include_last_changed=not significant_changes_only
|
||||
)
|
||||
|
||||
if entity_ids is not None and len(entity_ids) == 1:
|
||||
if (
|
||||
significant_changes_only
|
||||
and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
|
||||
):
|
||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
||||
hass, no_attributes, include_last_changed=False
|
||||
)
|
||||
baked_query += lambda q: q.filter(
|
||||
(States.last_changed == States.last_updated)
|
||||
| States.last_changed.is_(None)
|
||||
)
|
||||
if (
|
||||
entity_ids
|
||||
and len(entity_ids) == 1
|
||||
and significant_changes_only
|
||||
and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
|
||||
):
|
||||
stmt += lambda q: q.filter(
|
||||
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
|
||||
)
|
||||
elif significant_changes_only:
|
||||
baked_query += lambda q: q.filter(
|
||||
stmt += lambda q: q.filter(
|
||||
or_(
|
||||
*[
|
||||
States.entity_id.like(entity_domain)
|
||||
@ -244,36 +231,24 @@ def _query_significant_states_with_session(
|
||||
)
|
||||
)
|
||||
|
||||
if entity_ids is not None:
|
||||
baked_query += lambda q: q.filter(
|
||||
States.entity_id.in_(bindparam("entity_ids", expanding=True))
|
||||
)
|
||||
if entity_ids:
|
||||
stmt += lambda q: q.filter(States.entity_id.in_(entity_ids))
|
||||
else:
|
||||
baked_query += _ignore_domains_filter
|
||||
if filters:
|
||||
filters.bake(baked_query)
|
||||
stmt += _ignore_domains_filter
|
||||
if filters and filters.has_config:
|
||||
entity_filter = filters.entity_filter()
|
||||
stmt += lambda q: q.filter(entity_filter)
|
||||
|
||||
baked_query += lambda q: q.filter(States.last_updated > bindparam("start_time"))
|
||||
if end_time is not None:
|
||||
baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time"))
|
||||
stmt += lambda q: q.filter(States.last_updated > start_time)
|
||||
if end_time:
|
||||
stmt += lambda q: q.filter(States.last_updated < end_time)
|
||||
|
||||
if join_attributes:
|
||||
baked_query += lambda q: q.outerjoin(
|
||||
stmt += lambda q: q.outerjoin(
|
||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||
)
|
||||
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
|
||||
|
||||
states = execute(
|
||||
baked_query(session).params(
|
||||
start_time=start_time, end_time=end_time, entity_ids=entity_ids
|
||||
)
|
||||
)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
elapsed = time.perf_counter() - timer_start
|
||||
_LOGGER.debug("get_significant_states took %fs", elapsed)
|
||||
|
||||
return states
|
||||
stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
|
||||
return stmt
|
||||
|
||||
|
||||
def get_significant_states_with_session(
|
||||
@ -301,9 +276,8 @@ def get_significant_states_with_session(
|
||||
as well as all states from certain domains (for instance
|
||||
thermostat so that we get current temperature in our graphs).
|
||||
"""
|
||||
states = _query_significant_states_with_session(
|
||||
hass,
|
||||
session,
|
||||
stmt = _significant_states_stmt(
|
||||
_schema_version(hass),
|
||||
start_time,
|
||||
end_time,
|
||||
entity_ids,
|
||||
@ -311,6 +285,9 @@ def get_significant_states_with_session(
|
||||
significant_changes_only,
|
||||
no_attributes,
|
||||
)
|
||||
states = execute_stmt_lambda_element(
|
||||
session, stmt, None if entity_ids else start_time, end_time
|
||||
)
|
||||
return _sorted_states_to_dict(
|
||||
hass,
|
||||
session,
|
||||
@ -354,6 +331,38 @@ def get_full_significant_states_with_session(
|
||||
)
|
||||
|
||||
|
||||
def _state_changed_during_period_stmt(
|
||||
schema_version: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime | None,
|
||||
entity_id: str | None,
|
||||
no_attributes: bool,
|
||||
descending: bool,
|
||||
limit: int | None,
|
||||
) -> StatementLambdaElement:
|
||||
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||
schema_version, no_attributes, include_last_changed=False
|
||||
)
|
||||
stmt += lambda q: q.filter(
|
||||
((States.last_changed == States.last_updated) | States.last_changed.is_(None))
|
||||
& (States.last_updated > start_time)
|
||||
)
|
||||
if end_time:
|
||||
stmt += lambda q: q.filter(States.last_updated < end_time)
|
||||
stmt += lambda q: q.filter(States.entity_id == entity_id)
|
||||
if join_attributes:
|
||||
stmt += lambda q: q.outerjoin(
|
||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||
)
|
||||
if descending:
|
||||
stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc())
|
||||
else:
|
||||
stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
|
||||
if limit:
|
||||
stmt += lambda q: q.limit(limit)
|
||||
return stmt
|
||||
|
||||
|
||||
def state_changes_during_period(
|
||||
hass: HomeAssistant,
|
||||
start_time: datetime,
|
||||
@ -365,52 +374,21 @@ def state_changes_during_period(
|
||||
include_start_time_state: bool = True,
|
||||
) -> MutableMapping[str, list[State]]:
|
||||
"""Return states changes during UTC period start_time - end_time."""
|
||||
entity_id = entity_id.lower() if entity_id is not None else None
|
||||
|
||||
with session_scope(hass=hass) as session:
|
||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
||||
hass, no_attributes, include_last_changed=False
|
||||
stmt = _state_changed_during_period_stmt(
|
||||
_schema_version(hass),
|
||||
start_time,
|
||||
end_time,
|
||||
entity_id,
|
||||
no_attributes,
|
||||
descending,
|
||||
limit,
|
||||
)
|
||||
|
||||
baked_query += lambda q: q.filter(
|
||||
(
|
||||
(States.last_changed == States.last_updated)
|
||||
| States.last_changed.is_(None)
|
||||
)
|
||||
& (States.last_updated > bindparam("start_time"))
|
||||
states = execute_stmt_lambda_element(
|
||||
session, stmt, None if entity_id else start_time, end_time
|
||||
)
|
||||
|
||||
if end_time is not None:
|
||||
baked_query += lambda q: q.filter(
|
||||
States.last_updated < bindparam("end_time")
|
||||
)
|
||||
|
||||
if entity_id is not None:
|
||||
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
|
||||
entity_id = entity_id.lower()
|
||||
|
||||
if join_attributes:
|
||||
baked_query += lambda q: q.outerjoin(
|
||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||
)
|
||||
|
||||
if descending:
|
||||
baked_query += lambda q: q.order_by(
|
||||
States.entity_id, States.last_updated.desc()
|
||||
)
|
||||
else:
|
||||
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
|
||||
|
||||
if limit:
|
||||
baked_query += lambda q: q.limit(bindparam("limit"))
|
||||
|
||||
states = execute(
|
||||
baked_query(session).params(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
entity_id=entity_id,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
|
||||
entity_ids = [entity_id] if entity_id is not None else None
|
||||
|
||||
return cast(
|
||||
@ -426,41 +404,37 @@ def state_changes_during_period(
|
||||
)
|
||||
|
||||
|
||||
def _get_last_state_changes_stmt(
|
||||
schema_version: int, number_of_states: int, entity_id: str
|
||||
) -> StatementLambdaElement:
|
||||
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||
schema_version, False, include_last_changed=False
|
||||
)
|
||||
stmt += lambda q: q.filter(
|
||||
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
|
||||
).filter(States.entity_id == entity_id)
|
||||
if join_attributes:
|
||||
stmt += lambda q: q.outerjoin(
|
||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||
)
|
||||
stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()).limit(
|
||||
number_of_states
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def get_last_state_changes(
|
||||
hass: HomeAssistant, number_of_states: int, entity_id: str
|
||||
) -> MutableMapping[str, list[State]]:
|
||||
"""Return the last number_of_states."""
|
||||
start_time = dt_util.utcnow()
|
||||
entity_id = entity_id.lower() if entity_id is not None else None
|
||||
|
||||
with session_scope(hass=hass) as session:
|
||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
||||
hass, False, include_last_changed=False
|
||||
stmt = _get_last_state_changes_stmt(
|
||||
_schema_version(hass), number_of_states, entity_id
|
||||
)
|
||||
|
||||
baked_query += lambda q: q.filter(
|
||||
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
|
||||
)
|
||||
|
||||
if entity_id is not None:
|
||||
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
|
||||
entity_id = entity_id.lower()
|
||||
|
||||
if join_attributes:
|
||||
baked_query += lambda q: q.outerjoin(
|
||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||
)
|
||||
baked_query += lambda q: q.order_by(
|
||||
States.entity_id, States.last_updated.desc()
|
||||
)
|
||||
|
||||
baked_query += lambda q: q.limit(bindparam("number_of_states"))
|
||||
|
||||
states = execute(
|
||||
baked_query(session).params(
|
||||
number_of_states=number_of_states, entity_id=entity_id
|
||||
)
|
||||
)
|
||||
|
||||
states = list(execute_stmt_lambda_element(session, stmt))
|
||||
entity_ids = [entity_id] if entity_id is not None else None
|
||||
|
||||
return cast(
|
||||
@ -476,96 +450,91 @@ def get_last_state_changes(
|
||||
)
|
||||
|
||||
|
||||
def _most_recent_state_ids_entities_subquery(query: Query) -> Query:
|
||||
"""Query to find the most recent state id for specific entities."""
|
||||
def _get_states_for_entites_stmt(
|
||||
schema_version: int,
|
||||
run_start: datetime,
|
||||
utc_point_in_time: datetime,
|
||||
entity_ids: list[str],
|
||||
no_attributes: bool,
|
||||
) -> StatementLambdaElement:
|
||||
"""Baked query to get states for specific entities."""
|
||||
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||
schema_version, no_attributes, include_last_changed=True
|
||||
)
|
||||
# We got an include-list of entities, accelerate the query by filtering already
|
||||
# in the inner query.
|
||||
most_recent_state_ids = (
|
||||
query.session.query(func.max(States.state_id).label("max_state_id"))
|
||||
.filter(
|
||||
(States.last_updated >= bindparam("run_start"))
|
||||
& (States.last_updated < bindparam("utc_point_in_time"))
|
||||
)
|
||||
.filter(States.entity_id.in_(bindparam("entity_ids", expanding=True)))
|
||||
.group_by(States.entity_id)
|
||||
.subquery()
|
||||
stmt += lambda q: q.where(
|
||||
States.state_id
|
||||
== (
|
||||
select(func.max(States.state_id).label("max_state_id"))
|
||||
.filter(
|
||||
(States.last_updated >= run_start)
|
||||
& (States.last_updated < utc_point_in_time)
|
||||
)
|
||||
.filter(States.entity_id.in_(entity_ids))
|
||||
.group_by(States.entity_id)
|
||||
.subquery()
|
||||
).c.max_state_id
|
||||
)
|
||||
return query.join(
|
||||
most_recent_state_ids,
|
||||
States.state_id == most_recent_state_ids.c.max_state_id,
|
||||
)
|
||||
|
||||
|
||||
def _get_states_baked_query_for_entites(
|
||||
hass: HomeAssistant,
|
||||
no_attributes: bool = False,
|
||||
) -> BakedQuery:
|
||||
"""Baked query to get states for specific entities."""
|
||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
||||
hass, no_attributes, include_last_changed=True
|
||||
)
|
||||
baked_query += _most_recent_state_ids_entities_subquery
|
||||
if join_attributes:
|
||||
baked_query += lambda q: q.outerjoin(
|
||||
stmt += lambda q: q.outerjoin(
|
||||
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
||||
)
|
||||
return baked_query
|
||||
return stmt
|
||||
|
||||
|
||||
def _most_recent_state_ids_subquery(query: Query) -> Query:
|
||||
"""Find the most recent state ids for all entiites."""
|
||||
def _get_states_for_all_stmt(
|
||||
schema_version: int,
|
||||
run_start: datetime,
|
||||
utc_point_in_time: datetime,
|
||||
filters: Filters | None,
|
||||
no_attributes: bool,
|
||||
) -> StatementLambdaElement:
|
||||
"""Baked query to get states for all entities."""
|
||||
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||
schema_version, no_attributes, include_last_changed=True
|
||||
)
|
||||
# We did not get an include-list of entities, query all states in the inner
|
||||
# query, then filter out unwanted domains as well as applying the custom filter.
|
||||
# This filtering can't be done in the inner query because the domain column is
|
||||
# not indexed and we can't control what's in the custom filter.
|
||||
most_recent_states_by_date = (
|
||||
query.session.query(
|
||||
select(
|
||||
States.entity_id.label("max_entity_id"),
|
||||
func.max(States.last_updated).label("max_last_updated"),
|
||||
)
|
||||
.filter(
|
||||
(States.last_updated >= bindparam("run_start"))
|
||||
& (States.last_updated < bindparam("utc_point_in_time"))
|
||||
(States.last_updated >= run_start)
|
||||
& (States.last_updated < utc_point_in_time)
|
||||
)
|
||||
.group_by(States.entity_id)
|
||||
.subquery()
|
||||
)
|
||||
most_recent_state_ids = (
|
||||
query.session.query(func.max(States.state_id).label("max_state_id"))
|
||||
.join(
|
||||
most_recent_states_by_date,
|
||||
and_(
|
||||
States.entity_id == most_recent_states_by_date.c.max_entity_id,
|
||||
States.last_updated == most_recent_states_by_date.c.max_last_updated,
|
||||
),
|
||||
)
|
||||
.group_by(States.entity_id)
|
||||
.subquery()
|
||||
stmt += lambda q: q.where(
|
||||
States.state_id
|
||||
== (
|
||||
select(func.max(States.state_id).label("max_state_id"))
|
||||
.join(
|
||||
most_recent_states_by_date,
|
||||
and_(
|
||||
States.entity_id == most_recent_states_by_date.c.max_entity_id,
|
||||
States.last_updated
|
||||
== most_recent_states_by_date.c.max_last_updated,
|
||||
),
|
||||
)
|
||||
.group_by(States.entity_id)
|
||||
.subquery()
|
||||
).c.max_state_id,
|
||||
)
|
||||
return query.join(
|
||||
most_recent_state_ids,
|
||||
States.state_id == most_recent_state_ids.c.max_state_id,
|
||||
)
|
||||
|
||||
|
||||
def _get_states_baked_query_for_all(
|
||||
hass: HomeAssistant,
|
||||
filters: Filters | None = None,
|
||||
no_attributes: bool = False,
|
||||
) -> BakedQuery:
|
||||
"""Baked query to get states for all entities."""
|
||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
||||
hass, no_attributes, include_last_changed=True
|
||||
)
|
||||
baked_query += _most_recent_state_ids_subquery
|
||||
baked_query += _ignore_domains_filter
|
||||
if filters:
|
||||
filters.bake(baked_query)
|
||||
stmt += _ignore_domains_filter
|
||||
if filters and filters.has_config:
|
||||
entity_filter = filters.entity_filter()
|
||||
stmt += lambda q: q.filter(entity_filter)
|
||||
if join_attributes:
|
||||
baked_query += lambda q: q.outerjoin(
|
||||
stmt += lambda q: q.outerjoin(
|
||||
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
||||
)
|
||||
return baked_query
|
||||
return stmt
|
||||
|
||||
|
||||
def _get_rows_with_session(
|
||||
@ -576,11 +545,15 @@ def _get_rows_with_session(
|
||||
run: RecorderRuns | None = None,
|
||||
filters: Filters | None = None,
|
||||
no_attributes: bool = False,
|
||||
) -> list[Row]:
|
||||
) -> Iterable[Row]:
|
||||
"""Return the states at a specific point in time."""
|
||||
schema_version = _schema_version(hass)
|
||||
if entity_ids and len(entity_ids) == 1:
|
||||
return _get_single_entity_states_with_session(
|
||||
hass, session, utc_point_in_time, entity_ids[0], no_attributes
|
||||
return execute_stmt_lambda_element(
|
||||
session,
|
||||
_get_single_entity_states_stmt(
|
||||
schema_version, utc_point_in_time, entity_ids[0], no_attributes
|
||||
),
|
||||
)
|
||||
|
||||
if run is None:
|
||||
@ -593,46 +566,41 @@ def _get_rows_with_session(
|
||||
# We have more than one entity to look at so we need to do a query on states
|
||||
# since the last recorder run started.
|
||||
if entity_ids:
|
||||
baked_query = _get_states_baked_query_for_entites(hass, no_attributes)
|
||||
else:
|
||||
baked_query = _get_states_baked_query_for_all(hass, filters, no_attributes)
|
||||
|
||||
return execute(
|
||||
baked_query(session).params(
|
||||
run_start=run.start,
|
||||
utc_point_in_time=utc_point_in_time,
|
||||
entity_ids=entity_ids,
|
||||
stmt = _get_states_for_entites_stmt(
|
||||
schema_version, run.start, utc_point_in_time, entity_ids, no_attributes
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = _get_states_for_all_stmt(
|
||||
schema_version, run.start, utc_point_in_time, filters, no_attributes
|
||||
)
|
||||
|
||||
return execute_stmt_lambda_element(session, stmt)
|
||||
|
||||
|
||||
def _get_single_entity_states_with_session(
|
||||
hass: HomeAssistant,
|
||||
session: Session,
|
||||
def _get_single_entity_states_stmt(
|
||||
schema_version: int,
|
||||
utc_point_in_time: datetime,
|
||||
entity_id: str,
|
||||
no_attributes: bool = False,
|
||||
) -> list[Row]:
|
||||
) -> StatementLambdaElement:
|
||||
# Use an entirely different (and extremely fast) query if we only
|
||||
# have a single entity id
|
||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
||||
hass, no_attributes, include_last_changed=True
|
||||
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||
schema_version, no_attributes, include_last_changed=True
|
||||
)
|
||||
baked_query += lambda q: q.filter(
|
||||
States.last_updated < bindparam("utc_point_in_time"),
|
||||
States.entity_id == bindparam("entity_id"),
|
||||
stmt += (
|
||||
lambda q: q.filter(
|
||||
States.last_updated < utc_point_in_time,
|
||||
States.entity_id == entity_id,
|
||||
)
|
||||
.order_by(States.last_updated.desc())
|
||||
.limit(1)
|
||||
)
|
||||
if join_attributes:
|
||||
baked_query += lambda q: q.outerjoin(
|
||||
stmt += lambda q: q.outerjoin(
|
||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||
)
|
||||
baked_query += lambda q: q.order_by(States.last_updated.desc()).limit(1)
|
||||
|
||||
query = baked_query(session).params(
|
||||
utc_point_in_time=utc_point_in_time, entity_id=entity_id
|
||||
)
|
||||
|
||||
return execute(query)
|
||||
return stmt
|
||||
|
||||
|
||||
def _sorted_states_to_dict(
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""SQLAlchemy util functions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
from collections.abc import Callable, Generator, Iterable
|
||||
from contextlib import contextmanager
|
||||
from datetime import date, datetime, timedelta
|
||||
import functools
|
||||
@ -18,9 +18,12 @@ from awesomeversion import (
|
||||
import ciso8601
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.cursor import CursorFetchStrategy
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.ext.baked import Result
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -46,6 +49,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
RETRIES = 3
|
||||
QUERY_RETRY_WAIT = 0.1
|
||||
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
|
||||
DEFAULT_YIELD_STATES_ROWS = 32768
|
||||
|
||||
MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
|
||||
MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
|
||||
@ -119,8 +123,10 @@ def commit(session: Session, work: Any) -> bool:
|
||||
|
||||
|
||||
def execute(
|
||||
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
|
||||
) -> list:
|
||||
qry: Query | Result,
|
||||
to_native: bool = False,
|
||||
validate_entity_ids: bool = True,
|
||||
) -> list[Row]:
|
||||
"""Query the database and convert the objects to HA native form.
|
||||
|
||||
This method also retries a few times in the case of stale connections.
|
||||
@ -163,7 +169,39 @@ def execute(
|
||||
raise
|
||||
time.sleep(QUERY_RETRY_WAIT)
|
||||
|
||||
assert False # unreachable
|
||||
assert False # unreachable # pragma: no cover
|
||||
|
||||
|
||||
def execute_stmt_lambda_element(
|
||||
session: Session,
|
||||
stmt: StatementLambdaElement,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
yield_per: int | None = DEFAULT_YIELD_STATES_ROWS,
|
||||
) -> Iterable[Row]:
|
||||
"""Execute a StatementLambdaElement.
|
||||
|
||||
If the time window passed is greater than one day
|
||||
the execution method will switch to yield_per to
|
||||
reduce memory pressure.
|
||||
|
||||
It is not recommended to pass a time window
|
||||
when selecting non-ranged rows (ie selecting
|
||||
specific entities) since they are usually faster
|
||||
with .all().
|
||||
"""
|
||||
executed = session.execute(stmt)
|
||||
use_all = not start_time or ((end_time or dt_util.utcnow()) - start_time).days <= 1
|
||||
for tryno in range(0, RETRIES):
|
||||
try:
|
||||
return executed.all() if use_all else executed.yield_per(yield_per) # type: ignore[no-any-return]
|
||||
except SQLAlchemyError as err:
|
||||
_LOGGER.error("Error executing query: %s", err)
|
||||
if tryno == RETRIES - 1:
|
||||
raise
|
||||
time.sleep(QUERY_RETRY_WAIT)
|
||||
|
||||
assert False # unreachable # pragma: no cover
|
||||
|
||||
|
||||
def validate_or_move_away_sqlite_database(dburl: str) -> bool:
|
||||
|
@ -6,10 +6,13 @@ from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.result import ChunkedIteratorResult
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||
|
||||
from homeassistant.components import recorder
|
||||
from homeassistant.components.recorder import util
|
||||
from homeassistant.components.recorder import history, util
|
||||
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
|
||||
from homeassistant.components.recorder.models import RecorderRuns
|
||||
from homeassistant.components.recorder.util import (
|
||||
@ -24,6 +27,7 @@ from homeassistant.util import dt as dt_util
|
||||
from .common import corrupt_db_file, run_information_with_session
|
||||
|
||||
from tests.common import SetupRecorderInstanceT, async_test_home_assistant
|
||||
from tests.components.recorder.common import wait_recording_done
|
||||
|
||||
|
||||
def test_session_scope_not_setup(hass_recorder):
|
||||
@ -510,8 +514,10 @@ def test_basic_sanity_check(hass_recorder):
|
||||
def test_combined_checks(hass_recorder, caplog):
|
||||
"""Run Checks on the open database."""
|
||||
hass = hass_recorder()
|
||||
instance = recorder.get_instance(hass)
|
||||
instance.db_retry_wait = 0
|
||||
|
||||
cursor = hass.data[DATA_INSTANCE].engine.raw_connection().cursor()
|
||||
cursor = instance.engine.raw_connection().cursor()
|
||||
|
||||
assert util.run_checks_on_open_db("fake_db_path", cursor) is None
|
||||
assert "could not validate that the sqlite3 database" in caplog.text
|
||||
@ -658,3 +664,54 @@ def test_build_mysqldb_conv():
|
||||
assert conv["DATETIME"]("2022-05-13T22:33:12.741") == datetime(
|
||||
2022, 5, 13, 22, 33, 12, 741000, tzinfo=None
|
||||
)
|
||||
|
||||
|
||||
@patch("homeassistant.components.recorder.util.QUERY_RETRY_WAIT", 0)
|
||||
def test_execute_stmt_lambda_element(hass_recorder):
|
||||
"""Test executing with execute_stmt_lambda_element."""
|
||||
hass = hass_recorder()
|
||||
instance = recorder.get_instance(hass)
|
||||
hass.states.set("sensor.on", "on")
|
||||
new_state = hass.states.get("sensor.on")
|
||||
wait_recording_done(hass)
|
||||
now = dt_util.utcnow()
|
||||
tomorrow = now + timedelta(days=1)
|
||||
one_week_from_now = now + timedelta(days=7)
|
||||
|
||||
class MockExecutor:
|
||||
def __init__(self, stmt):
|
||||
assert isinstance(stmt, StatementLambdaElement)
|
||||
self.calls = 0
|
||||
|
||||
def all(self):
|
||||
self.calls += 1
|
||||
if self.calls == 2:
|
||||
return ["mock_row"]
|
||||
raise SQLAlchemyError
|
||||
|
||||
with session_scope(hass=hass) as session:
|
||||
# No time window, we always get a list
|
||||
stmt = history._get_single_entity_states_stmt(
|
||||
instance.schema_version, dt_util.utcnow(), "sensor.on", False
|
||||
)
|
||||
rows = util.execute_stmt_lambda_element(session, stmt)
|
||||
assert isinstance(rows, list)
|
||||
assert rows[0].state == new_state.state
|
||||
assert rows[0].entity_id == new_state.entity_id
|
||||
|
||||
# Time window >= 2 days, we get a ChunkedIteratorResult
|
||||
rows = util.execute_stmt_lambda_element(session, stmt, now, one_week_from_now)
|
||||
assert isinstance(rows, ChunkedIteratorResult)
|
||||
row = next(rows)
|
||||
assert row.state == new_state.state
|
||||
assert row.entity_id == new_state.entity_id
|
||||
|
||||
# Time window < 2 days, we get a list
|
||||
rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
|
||||
assert isinstance(rows, list)
|
||||
assert rows[0].state == new_state.state
|
||||
assert rows[0].entity_id == new_state.entity_id
|
||||
|
||||
with patch.object(session, "execute", MockExecutor):
|
||||
rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
|
||||
assert rows == ["mock_row"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user