mirror of
https://github.com/home-assistant/core.git
synced 2025-04-30 04:07:51 +00:00
Convert history queries to use lambda_stmt (#71870)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
8ea5ec6f08
commit
98809675ff
@ -20,7 +20,7 @@ from homeassistant.helpers.integration_platform import (
|
|||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
|
||||||
from . import history, statistics, websocket_api
|
from . import statistics, websocket_api
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_DB_INTEGRITY_CHECK,
|
CONF_DB_INTEGRITY_CHECK,
|
||||||
DATA_INSTANCE,
|
DATA_INSTANCE,
|
||||||
@ -166,7 +166,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
instance.async_register()
|
instance.async_register()
|
||||||
instance.start()
|
instance.start()
|
||||||
async_register_services(hass, instance)
|
async_register_services(hass, instance)
|
||||||
history.async_setup(hass)
|
|
||||||
statistics.async_setup(hass)
|
statistics.async_setup(hass)
|
||||||
websocket_api.async_setup(hass)
|
websocket_api.async_setup(hass)
|
||||||
await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform)
|
await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlalchemy import not_, or_
|
from sqlalchemy import not_, or_
|
||||||
from sqlalchemy.ext.baked import BakedQuery
|
|
||||||
from sqlalchemy.sql.elements import ClauseList
|
from sqlalchemy.sql.elements import ClauseList
|
||||||
|
|
||||||
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
|
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
|
||||||
@ -60,16 +59,6 @@ class Filters:
|
|||||||
or self.included_entity_globs
|
or self.included_entity_globs
|
||||||
)
|
)
|
||||||
|
|
||||||
def bake(self, baked_query: BakedQuery) -> BakedQuery:
|
|
||||||
"""Update a baked query.
|
|
||||||
|
|
||||||
Works the same as apply on a baked_query.
|
|
||||||
"""
|
|
||||||
if not self.has_config:
|
|
||||||
return
|
|
||||||
|
|
||||||
baked_query += lambda q: q.filter(self.entity_filter())
|
|
||||||
|
|
||||||
def entity_filter(self) -> ClauseList:
|
def entity_filter(self) -> ClauseList:
|
||||||
"""Generate the entity filter query."""
|
"""Generate the entity filter query."""
|
||||||
includes = []
|
includes = []
|
||||||
|
@ -9,13 +9,12 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import Column, Text, and_, bindparam, func, or_
|
from sqlalchemy import Column, Text, and_, func, lambda_stmt, or_, select
|
||||||
from sqlalchemy.engine.row import Row
|
from sqlalchemy.engine.row import Row
|
||||||
from sqlalchemy.ext import baked
|
|
||||||
from sqlalchemy.ext.baked import BakedQuery
|
|
||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
from sqlalchemy.sql.expression import literal
|
from sqlalchemy.sql.expression import literal
|
||||||
|
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||||
|
|
||||||
from homeassistant.components import recorder
|
from homeassistant.components import recorder
|
||||||
from homeassistant.components.websocket_api.const import (
|
from homeassistant.components.websocket_api.const import (
|
||||||
@ -36,7 +35,7 @@ from .models import (
|
|||||||
process_timestamp_to_utc_isoformat,
|
process_timestamp_to_utc_isoformat,
|
||||||
row_to_compressed_state,
|
row_to_compressed_state,
|
||||||
)
|
)
|
||||||
from .util import execute, session_scope
|
from .util import execute_stmt_lambda_element, session_scope
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||||
|
|
||||||
@ -111,52 +110,48 @@ QUERY_STATES_NO_LAST_CHANGED = [
|
|||||||
StateAttributes.shared_attrs,
|
StateAttributes.shared_attrs,
|
||||||
]
|
]
|
||||||
|
|
||||||
HISTORY_BAKERY = "recorder_history_bakery"
|
|
||||||
|
def _schema_version(hass: HomeAssistant) -> int:
|
||||||
|
return recorder.get_instance(hass).schema_version
|
||||||
|
|
||||||
|
|
||||||
def bake_query_and_join_attributes(
|
def lambda_stmt_and_join_attributes(
|
||||||
hass: HomeAssistant, no_attributes: bool, include_last_changed: bool = True
|
schema_version: int, no_attributes: bool, include_last_changed: bool = True
|
||||||
) -> tuple[Any, bool]:
|
) -> tuple[StatementLambdaElement, bool]:
|
||||||
"""Return the initial backed query and if StateAttributes should be joined.
|
"""Return the lambda_stmt and if StateAttributes should be joined.
|
||||||
|
|
||||||
Because these are baked queries the values inside the lambdas need
|
Because these are lambda_stmt the values inside the lambdas need
|
||||||
to be explicitly written out to avoid caching the wrong values.
|
to be explicitly written out to avoid caching the wrong values.
|
||||||
"""
|
"""
|
||||||
bakery: baked.bakery = hass.data[HISTORY_BAKERY]
|
|
||||||
# If no_attributes was requested we do the query
|
# If no_attributes was requested we do the query
|
||||||
# without the attributes fields and do not join the
|
# without the attributes fields and do not join the
|
||||||
# state_attributes table
|
# state_attributes table
|
||||||
if no_attributes:
|
if no_attributes:
|
||||||
if include_last_changed:
|
if include_last_changed:
|
||||||
return bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR)), False
|
return lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR)), False
|
||||||
return (
|
return (
|
||||||
bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
|
lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
|
||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
# If we in the process of migrating schema we do
|
# If we in the process of migrating schema we do
|
||||||
# not want to join the state_attributes table as we
|
# not want to join the state_attributes table as we
|
||||||
# do not know if it will be there yet
|
# do not know if it will be there yet
|
||||||
if recorder.get_instance(hass).schema_version < 25:
|
if schema_version < 25:
|
||||||
if include_last_changed:
|
if include_last_changed:
|
||||||
return (
|
return (
|
||||||
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25)),
|
lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25)),
|
||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
|
lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
|
||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
# Finally if no migration is in progress and no_attributes
|
# Finally if no migration is in progress and no_attributes
|
||||||
# was not requested, we query both attributes columns and
|
# was not requested, we query both attributes columns and
|
||||||
# join state_attributes
|
# join state_attributes
|
||||||
if include_last_changed:
|
if include_last_changed:
|
||||||
return bakery(lambda s: s.query(*QUERY_STATES)), True
|
return lambda_stmt(lambda: select(*QUERY_STATES)), True
|
||||||
return bakery(lambda s: s.query(*QUERY_STATES_NO_LAST_CHANGED)), True
|
return lambda_stmt(lambda: select(*QUERY_STATES_NO_LAST_CHANGED)), True
|
||||||
|
|
||||||
|
|
||||||
def async_setup(hass: HomeAssistant) -> None:
|
|
||||||
"""Set up the history hooks."""
|
|
||||||
hass.data[HISTORY_BAKERY] = baked.bakery()
|
|
||||||
|
|
||||||
|
|
||||||
def get_significant_states(
|
def get_significant_states(
|
||||||
@ -200,38 +195,30 @@ def _ignore_domains_filter(query: Query) -> Query:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _query_significant_states_with_session(
|
def _significant_states_stmt(
|
||||||
hass: HomeAssistant,
|
schema_version: int,
|
||||||
session: Session,
|
|
||||||
start_time: datetime,
|
start_time: datetime,
|
||||||
end_time: datetime | None = None,
|
end_time: datetime | None,
|
||||||
entity_ids: list[str] | None = None,
|
entity_ids: list[str] | None,
|
||||||
filters: Filters | None = None,
|
filters: Filters | None,
|
||||||
significant_changes_only: bool = True,
|
significant_changes_only: bool,
|
||||||
no_attributes: bool = False,
|
no_attributes: bool,
|
||||||
) -> list[Row]:
|
) -> StatementLambdaElement:
|
||||||
"""Query the database for significant state changes."""
|
"""Query the database for significant state changes."""
|
||||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||||
timer_start = time.perf_counter()
|
schema_version, no_attributes, include_last_changed=not significant_changes_only
|
||||||
|
|
||||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
|
||||||
hass, no_attributes, include_last_changed=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if entity_ids is not None and len(entity_ids) == 1:
|
|
||||||
if (
|
if (
|
||||||
significant_changes_only
|
entity_ids
|
||||||
|
and len(entity_ids) == 1
|
||||||
|
and significant_changes_only
|
||||||
and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
|
and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
|
||||||
):
|
):
|
||||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
stmt += lambda q: q.filter(
|
||||||
hass, no_attributes, include_last_changed=False
|
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
|
||||||
)
|
|
||||||
baked_query += lambda q: q.filter(
|
|
||||||
(States.last_changed == States.last_updated)
|
|
||||||
| States.last_changed.is_(None)
|
|
||||||
)
|
)
|
||||||
elif significant_changes_only:
|
elif significant_changes_only:
|
||||||
baked_query += lambda q: q.filter(
|
stmt += lambda q: q.filter(
|
||||||
or_(
|
or_(
|
||||||
*[
|
*[
|
||||||
States.entity_id.like(entity_domain)
|
States.entity_id.like(entity_domain)
|
||||||
@ -244,36 +231,24 @@ def _query_significant_states_with_session(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if entity_ids is not None:
|
if entity_ids:
|
||||||
baked_query += lambda q: q.filter(
|
stmt += lambda q: q.filter(States.entity_id.in_(entity_ids))
|
||||||
States.entity_id.in_(bindparam("entity_ids", expanding=True))
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
baked_query += _ignore_domains_filter
|
stmt += _ignore_domains_filter
|
||||||
if filters:
|
if filters and filters.has_config:
|
||||||
filters.bake(baked_query)
|
entity_filter = filters.entity_filter()
|
||||||
|
stmt += lambda q: q.filter(entity_filter)
|
||||||
|
|
||||||
baked_query += lambda q: q.filter(States.last_updated > bindparam("start_time"))
|
stmt += lambda q: q.filter(States.last_updated > start_time)
|
||||||
if end_time is not None:
|
if end_time:
|
||||||
baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time"))
|
stmt += lambda q: q.filter(States.last_updated < end_time)
|
||||||
|
|
||||||
if join_attributes:
|
if join_attributes:
|
||||||
baked_query += lambda q: q.outerjoin(
|
stmt += lambda q: q.outerjoin(
|
||||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||||
)
|
)
|
||||||
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
|
stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
|
||||||
|
return stmt
|
||||||
states = execute(
|
|
||||||
baked_query(session).params(
|
|
||||||
start_time=start_time, end_time=end_time, entity_ids=entity_ids
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
|
||||||
elapsed = time.perf_counter() - timer_start
|
|
||||||
_LOGGER.debug("get_significant_states took %fs", elapsed)
|
|
||||||
|
|
||||||
return states
|
|
||||||
|
|
||||||
|
|
||||||
def get_significant_states_with_session(
|
def get_significant_states_with_session(
|
||||||
@ -301,9 +276,8 @@ def get_significant_states_with_session(
|
|||||||
as well as all states from certain domains (for instance
|
as well as all states from certain domains (for instance
|
||||||
thermostat so that we get current temperature in our graphs).
|
thermostat so that we get current temperature in our graphs).
|
||||||
"""
|
"""
|
||||||
states = _query_significant_states_with_session(
|
stmt = _significant_states_stmt(
|
||||||
hass,
|
_schema_version(hass),
|
||||||
session,
|
|
||||||
start_time,
|
start_time,
|
||||||
end_time,
|
end_time,
|
||||||
entity_ids,
|
entity_ids,
|
||||||
@ -311,6 +285,9 @@ def get_significant_states_with_session(
|
|||||||
significant_changes_only,
|
significant_changes_only,
|
||||||
no_attributes,
|
no_attributes,
|
||||||
)
|
)
|
||||||
|
states = execute_stmt_lambda_element(
|
||||||
|
session, stmt, None if entity_ids else start_time, end_time
|
||||||
|
)
|
||||||
return _sorted_states_to_dict(
|
return _sorted_states_to_dict(
|
||||||
hass,
|
hass,
|
||||||
session,
|
session,
|
||||||
@ -354,6 +331,38 @@ def get_full_significant_states_with_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _state_changed_during_period_stmt(
|
||||||
|
schema_version: int,
|
||||||
|
start_time: datetime,
|
||||||
|
end_time: datetime | None,
|
||||||
|
entity_id: str | None,
|
||||||
|
no_attributes: bool,
|
||||||
|
descending: bool,
|
||||||
|
limit: int | None,
|
||||||
|
) -> StatementLambdaElement:
|
||||||
|
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||||
|
schema_version, no_attributes, include_last_changed=False
|
||||||
|
)
|
||||||
|
stmt += lambda q: q.filter(
|
||||||
|
((States.last_changed == States.last_updated) | States.last_changed.is_(None))
|
||||||
|
& (States.last_updated > start_time)
|
||||||
|
)
|
||||||
|
if end_time:
|
||||||
|
stmt += lambda q: q.filter(States.last_updated < end_time)
|
||||||
|
stmt += lambda q: q.filter(States.entity_id == entity_id)
|
||||||
|
if join_attributes:
|
||||||
|
stmt += lambda q: q.outerjoin(
|
||||||
|
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||||
|
)
|
||||||
|
if descending:
|
||||||
|
stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc())
|
||||||
|
else:
|
||||||
|
stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
|
||||||
|
if limit:
|
||||||
|
stmt += lambda q: q.limit(limit)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def state_changes_during_period(
|
def state_changes_during_period(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
start_time: datetime,
|
start_time: datetime,
|
||||||
@ -365,52 +374,21 @@ def state_changes_during_period(
|
|||||||
include_start_time_state: bool = True,
|
include_start_time_state: bool = True,
|
||||||
) -> MutableMapping[str, list[State]]:
|
) -> MutableMapping[str, list[State]]:
|
||||||
"""Return states changes during UTC period start_time - end_time."""
|
"""Return states changes during UTC period start_time - end_time."""
|
||||||
|
entity_id = entity_id.lower() if entity_id is not None else None
|
||||||
|
|
||||||
with session_scope(hass=hass) as session:
|
with session_scope(hass=hass) as session:
|
||||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
stmt = _state_changed_during_period_stmt(
|
||||||
hass, no_attributes, include_last_changed=False
|
_schema_version(hass),
|
||||||
|
start_time,
|
||||||
|
end_time,
|
||||||
|
entity_id,
|
||||||
|
no_attributes,
|
||||||
|
descending,
|
||||||
|
limit,
|
||||||
)
|
)
|
||||||
|
states = execute_stmt_lambda_element(
|
||||||
baked_query += lambda q: q.filter(
|
session, stmt, None if entity_id else start_time, end_time
|
||||||
(
|
|
||||||
(States.last_changed == States.last_updated)
|
|
||||||
| States.last_changed.is_(None)
|
|
||||||
)
|
)
|
||||||
& (States.last_updated > bindparam("start_time"))
|
|
||||||
)
|
|
||||||
|
|
||||||
if end_time is not None:
|
|
||||||
baked_query += lambda q: q.filter(
|
|
||||||
States.last_updated < bindparam("end_time")
|
|
||||||
)
|
|
||||||
|
|
||||||
if entity_id is not None:
|
|
||||||
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
|
|
||||||
entity_id = entity_id.lower()
|
|
||||||
|
|
||||||
if join_attributes:
|
|
||||||
baked_query += lambda q: q.outerjoin(
|
|
||||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if descending:
|
|
||||||
baked_query += lambda q: q.order_by(
|
|
||||||
States.entity_id, States.last_updated.desc()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
|
|
||||||
|
|
||||||
if limit:
|
|
||||||
baked_query += lambda q: q.limit(bindparam("limit"))
|
|
||||||
|
|
||||||
states = execute(
|
|
||||||
baked_query(session).params(
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
entity_id=entity_id,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
entity_ids = [entity_id] if entity_id is not None else None
|
entity_ids = [entity_id] if entity_id is not None else None
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
@ -426,41 +404,37 @@ def state_changes_during_period(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_last_state_changes_stmt(
|
||||||
|
schema_version: int, number_of_states: int, entity_id: str
|
||||||
|
) -> StatementLambdaElement:
|
||||||
|
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||||
|
schema_version, False, include_last_changed=False
|
||||||
|
)
|
||||||
|
stmt += lambda q: q.filter(
|
||||||
|
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
|
||||||
|
).filter(States.entity_id == entity_id)
|
||||||
|
if join_attributes:
|
||||||
|
stmt += lambda q: q.outerjoin(
|
||||||
|
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||||
|
)
|
||||||
|
stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()).limit(
|
||||||
|
number_of_states
|
||||||
|
)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def get_last_state_changes(
|
def get_last_state_changes(
|
||||||
hass: HomeAssistant, number_of_states: int, entity_id: str
|
hass: HomeAssistant, number_of_states: int, entity_id: str
|
||||||
) -> MutableMapping[str, list[State]]:
|
) -> MutableMapping[str, list[State]]:
|
||||||
"""Return the last number_of_states."""
|
"""Return the last number_of_states."""
|
||||||
start_time = dt_util.utcnow()
|
start_time = dt_util.utcnow()
|
||||||
|
entity_id = entity_id.lower() if entity_id is not None else None
|
||||||
|
|
||||||
with session_scope(hass=hass) as session:
|
with session_scope(hass=hass) as session:
|
||||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
stmt = _get_last_state_changes_stmt(
|
||||||
hass, False, include_last_changed=False
|
_schema_version(hass), number_of_states, entity_id
|
||||||
)
|
)
|
||||||
|
states = list(execute_stmt_lambda_element(session, stmt))
|
||||||
baked_query += lambda q: q.filter(
|
|
||||||
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
|
|
||||||
)
|
|
||||||
|
|
||||||
if entity_id is not None:
|
|
||||||
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
|
|
||||||
entity_id = entity_id.lower()
|
|
||||||
|
|
||||||
if join_attributes:
|
|
||||||
baked_query += lambda q: q.outerjoin(
|
|
||||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
|
||||||
)
|
|
||||||
baked_query += lambda q: q.order_by(
|
|
||||||
States.entity_id, States.last_updated.desc()
|
|
||||||
)
|
|
||||||
|
|
||||||
baked_query += lambda q: q.limit(bindparam("number_of_states"))
|
|
||||||
|
|
||||||
states = execute(
|
|
||||||
baked_query(session).params(
|
|
||||||
number_of_states=number_of_states, entity_id=entity_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
entity_ids = [entity_id] if entity_id is not None else None
|
entity_ids = [entity_id] if entity_id is not None else None
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
@ -476,96 +450,91 @@ def get_last_state_changes(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _most_recent_state_ids_entities_subquery(query: Query) -> Query:
|
def _get_states_for_entites_stmt(
|
||||||
"""Query to find the most recent state id for specific entities."""
|
schema_version: int,
|
||||||
|
run_start: datetime,
|
||||||
|
utc_point_in_time: datetime,
|
||||||
|
entity_ids: list[str],
|
||||||
|
no_attributes: bool,
|
||||||
|
) -> StatementLambdaElement:
|
||||||
|
"""Baked query to get states for specific entities."""
|
||||||
|
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||||
|
schema_version, no_attributes, include_last_changed=True
|
||||||
|
)
|
||||||
# We got an include-list of entities, accelerate the query by filtering already
|
# We got an include-list of entities, accelerate the query by filtering already
|
||||||
# in the inner query.
|
# in the inner query.
|
||||||
most_recent_state_ids = (
|
stmt += lambda q: q.where(
|
||||||
query.session.query(func.max(States.state_id).label("max_state_id"))
|
States.state_id
|
||||||
|
== (
|
||||||
|
select(func.max(States.state_id).label("max_state_id"))
|
||||||
.filter(
|
.filter(
|
||||||
(States.last_updated >= bindparam("run_start"))
|
(States.last_updated >= run_start)
|
||||||
& (States.last_updated < bindparam("utc_point_in_time"))
|
& (States.last_updated < utc_point_in_time)
|
||||||
)
|
)
|
||||||
.filter(States.entity_id.in_(bindparam("entity_ids", expanding=True)))
|
.filter(States.entity_id.in_(entity_ids))
|
||||||
.group_by(States.entity_id)
|
.group_by(States.entity_id)
|
||||||
.subquery()
|
.subquery()
|
||||||
|
).c.max_state_id
|
||||||
)
|
)
|
||||||
return query.join(
|
|
||||||
most_recent_state_ids,
|
|
||||||
States.state_id == most_recent_state_ids.c.max_state_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_states_baked_query_for_entites(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
no_attributes: bool = False,
|
|
||||||
) -> BakedQuery:
|
|
||||||
"""Baked query to get states for specific entities."""
|
|
||||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
|
||||||
hass, no_attributes, include_last_changed=True
|
|
||||||
)
|
|
||||||
baked_query += _most_recent_state_ids_entities_subquery
|
|
||||||
if join_attributes:
|
if join_attributes:
|
||||||
baked_query += lambda q: q.outerjoin(
|
stmt += lambda q: q.outerjoin(
|
||||||
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
||||||
)
|
)
|
||||||
return baked_query
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def _most_recent_state_ids_subquery(query: Query) -> Query:
|
def _get_states_for_all_stmt(
|
||||||
"""Find the most recent state ids for all entiites."""
|
schema_version: int,
|
||||||
|
run_start: datetime,
|
||||||
|
utc_point_in_time: datetime,
|
||||||
|
filters: Filters | None,
|
||||||
|
no_attributes: bool,
|
||||||
|
) -> StatementLambdaElement:
|
||||||
|
"""Baked query to get states for all entities."""
|
||||||
|
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||||
|
schema_version, no_attributes, include_last_changed=True
|
||||||
|
)
|
||||||
# We did not get an include-list of entities, query all states in the inner
|
# We did not get an include-list of entities, query all states in the inner
|
||||||
# query, then filter out unwanted domains as well as applying the custom filter.
|
# query, then filter out unwanted domains as well as applying the custom filter.
|
||||||
# This filtering can't be done in the inner query because the domain column is
|
# This filtering can't be done in the inner query because the domain column is
|
||||||
# not indexed and we can't control what's in the custom filter.
|
# not indexed and we can't control what's in the custom filter.
|
||||||
most_recent_states_by_date = (
|
most_recent_states_by_date = (
|
||||||
query.session.query(
|
select(
|
||||||
States.entity_id.label("max_entity_id"),
|
States.entity_id.label("max_entity_id"),
|
||||||
func.max(States.last_updated).label("max_last_updated"),
|
func.max(States.last_updated).label("max_last_updated"),
|
||||||
)
|
)
|
||||||
.filter(
|
.filter(
|
||||||
(States.last_updated >= bindparam("run_start"))
|
(States.last_updated >= run_start)
|
||||||
& (States.last_updated < bindparam("utc_point_in_time"))
|
& (States.last_updated < utc_point_in_time)
|
||||||
)
|
)
|
||||||
.group_by(States.entity_id)
|
.group_by(States.entity_id)
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
most_recent_state_ids = (
|
stmt += lambda q: q.where(
|
||||||
query.session.query(func.max(States.state_id).label("max_state_id"))
|
States.state_id
|
||||||
|
== (
|
||||||
|
select(func.max(States.state_id).label("max_state_id"))
|
||||||
.join(
|
.join(
|
||||||
most_recent_states_by_date,
|
most_recent_states_by_date,
|
||||||
and_(
|
and_(
|
||||||
States.entity_id == most_recent_states_by_date.c.max_entity_id,
|
States.entity_id == most_recent_states_by_date.c.max_entity_id,
|
||||||
States.last_updated == most_recent_states_by_date.c.max_last_updated,
|
States.last_updated
|
||||||
|
== most_recent_states_by_date.c.max_last_updated,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.group_by(States.entity_id)
|
.group_by(States.entity_id)
|
||||||
.subquery()
|
.subquery()
|
||||||
|
).c.max_state_id,
|
||||||
)
|
)
|
||||||
return query.join(
|
stmt += _ignore_domains_filter
|
||||||
most_recent_state_ids,
|
if filters and filters.has_config:
|
||||||
States.state_id == most_recent_state_ids.c.max_state_id,
|
entity_filter = filters.entity_filter()
|
||||||
)
|
stmt += lambda q: q.filter(entity_filter)
|
||||||
|
|
||||||
|
|
||||||
def _get_states_baked_query_for_all(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
filters: Filters | None = None,
|
|
||||||
no_attributes: bool = False,
|
|
||||||
) -> BakedQuery:
|
|
||||||
"""Baked query to get states for all entities."""
|
|
||||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
|
||||||
hass, no_attributes, include_last_changed=True
|
|
||||||
)
|
|
||||||
baked_query += _most_recent_state_ids_subquery
|
|
||||||
baked_query += _ignore_domains_filter
|
|
||||||
if filters:
|
|
||||||
filters.bake(baked_query)
|
|
||||||
if join_attributes:
|
if join_attributes:
|
||||||
baked_query += lambda q: q.outerjoin(
|
stmt += lambda q: q.outerjoin(
|
||||||
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
||||||
)
|
)
|
||||||
return baked_query
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def _get_rows_with_session(
|
def _get_rows_with_session(
|
||||||
@ -576,11 +545,15 @@ def _get_rows_with_session(
|
|||||||
run: RecorderRuns | None = None,
|
run: RecorderRuns | None = None,
|
||||||
filters: Filters | None = None,
|
filters: Filters | None = None,
|
||||||
no_attributes: bool = False,
|
no_attributes: bool = False,
|
||||||
) -> list[Row]:
|
) -> Iterable[Row]:
|
||||||
"""Return the states at a specific point in time."""
|
"""Return the states at a specific point in time."""
|
||||||
|
schema_version = _schema_version(hass)
|
||||||
if entity_ids and len(entity_ids) == 1:
|
if entity_ids and len(entity_ids) == 1:
|
||||||
return _get_single_entity_states_with_session(
|
return execute_stmt_lambda_element(
|
||||||
hass, session, utc_point_in_time, entity_ids[0], no_attributes
|
session,
|
||||||
|
_get_single_entity_states_stmt(
|
||||||
|
schema_version, utc_point_in_time, entity_ids[0], no_attributes
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if run is None:
|
if run is None:
|
||||||
@ -593,46 +566,41 @@ def _get_rows_with_session(
|
|||||||
# We have more than one entity to look at so we need to do a query on states
|
# We have more than one entity to look at so we need to do a query on states
|
||||||
# since the last recorder run started.
|
# since the last recorder run started.
|
||||||
if entity_ids:
|
if entity_ids:
|
||||||
baked_query = _get_states_baked_query_for_entites(hass, no_attributes)
|
stmt = _get_states_for_entites_stmt(
|
||||||
|
schema_version, run.start, utc_point_in_time, entity_ids, no_attributes
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
baked_query = _get_states_baked_query_for_all(hass, filters, no_attributes)
|
stmt = _get_states_for_all_stmt(
|
||||||
|
schema_version, run.start, utc_point_in_time, filters, no_attributes
|
||||||
return execute(
|
|
||||||
baked_query(session).params(
|
|
||||||
run_start=run.start,
|
|
||||||
utc_point_in_time=utc_point_in_time,
|
|
||||||
entity_ids=entity_ids,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return execute_stmt_lambda_element(session, stmt)
|
||||||
|
|
||||||
def _get_single_entity_states_with_session(
|
|
||||||
hass: HomeAssistant,
|
def _get_single_entity_states_stmt(
|
||||||
session: Session,
|
schema_version: int,
|
||||||
utc_point_in_time: datetime,
|
utc_point_in_time: datetime,
|
||||||
entity_id: str,
|
entity_id: str,
|
||||||
no_attributes: bool = False,
|
no_attributes: bool = False,
|
||||||
) -> list[Row]:
|
) -> StatementLambdaElement:
|
||||||
# Use an entirely different (and extremely fast) query if we only
|
# Use an entirely different (and extremely fast) query if we only
|
||||||
# have a single entity id
|
# have a single entity id
|
||||||
baked_query, join_attributes = bake_query_and_join_attributes(
|
stmt, join_attributes = lambda_stmt_and_join_attributes(
|
||||||
hass, no_attributes, include_last_changed=True
|
schema_version, no_attributes, include_last_changed=True
|
||||||
)
|
)
|
||||||
baked_query += lambda q: q.filter(
|
stmt += (
|
||||||
States.last_updated < bindparam("utc_point_in_time"),
|
lambda q: q.filter(
|
||||||
States.entity_id == bindparam("entity_id"),
|
States.last_updated < utc_point_in_time,
|
||||||
|
States.entity_id == entity_id,
|
||||||
|
)
|
||||||
|
.order_by(States.last_updated.desc())
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
if join_attributes:
|
if join_attributes:
|
||||||
baked_query += lambda q: q.outerjoin(
|
stmt += lambda q: q.outerjoin(
|
||||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||||
)
|
)
|
||||||
baked_query += lambda q: q.order_by(States.last_updated.desc()).limit(1)
|
return stmt
|
||||||
|
|
||||||
query = baked_query(session).params(
|
|
||||||
utc_point_in_time=utc_point_in_time, entity_id=entity_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return execute(query)
|
|
||||||
|
|
||||||
|
|
||||||
def _sorted_states_to_dict(
|
def _sorted_states_to_dict(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""SQLAlchemy util functions."""
|
"""SQLAlchemy util functions."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator, Iterable
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
import functools
|
import functools
|
||||||
@ -18,9 +18,12 @@ from awesomeversion import (
|
|||||||
import ciso8601
|
import ciso8601
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.engine.cursor import CursorFetchStrategy
|
from sqlalchemy.engine.cursor import CursorFetchStrategy
|
||||||
|
from sqlalchemy.engine.row import Row
|
||||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.baked import Result
|
||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||||
from typing_extensions import Concatenate, ParamSpec
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -46,6 +49,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
RETRIES = 3
|
RETRIES = 3
|
||||||
QUERY_RETRY_WAIT = 0.1
|
QUERY_RETRY_WAIT = 0.1
|
||||||
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
|
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
|
||||||
|
DEFAULT_YIELD_STATES_ROWS = 32768
|
||||||
|
|
||||||
MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
|
MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
|
||||||
MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
|
MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
|
||||||
@ -119,8 +123,10 @@ def commit(session: Session, work: Any) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def execute(
|
def execute(
|
||||||
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
|
qry: Query | Result,
|
||||||
) -> list:
|
to_native: bool = False,
|
||||||
|
validate_entity_ids: bool = True,
|
||||||
|
) -> list[Row]:
|
||||||
"""Query the database and convert the objects to HA native form.
|
"""Query the database and convert the objects to HA native form.
|
||||||
|
|
||||||
This method also retries a few times in the case of stale connections.
|
This method also retries a few times in the case of stale connections.
|
||||||
@ -163,7 +169,39 @@ def execute(
|
|||||||
raise
|
raise
|
||||||
time.sleep(QUERY_RETRY_WAIT)
|
time.sleep(QUERY_RETRY_WAIT)
|
||||||
|
|
||||||
assert False # unreachable
|
assert False # unreachable # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
def execute_stmt_lambda_element(
|
||||||
|
session: Session,
|
||||||
|
stmt: StatementLambdaElement,
|
||||||
|
start_time: datetime | None = None,
|
||||||
|
end_time: datetime | None = None,
|
||||||
|
yield_per: int | None = DEFAULT_YIELD_STATES_ROWS,
|
||||||
|
) -> Iterable[Row]:
|
||||||
|
"""Execute a StatementLambdaElement.
|
||||||
|
|
||||||
|
If the time window passed is greater than one day
|
||||||
|
the execution method will switch to yield_per to
|
||||||
|
reduce memory pressure.
|
||||||
|
|
||||||
|
It is not recommended to pass a time window
|
||||||
|
when selecting non-ranged rows (ie selecting
|
||||||
|
specific entities) since they are usually faster
|
||||||
|
with .all().
|
||||||
|
"""
|
||||||
|
executed = session.execute(stmt)
|
||||||
|
use_all = not start_time or ((end_time or dt_util.utcnow()) - start_time).days <= 1
|
||||||
|
for tryno in range(0, RETRIES):
|
||||||
|
try:
|
||||||
|
return executed.all() if use_all else executed.yield_per(yield_per) # type: ignore[no-any-return]
|
||||||
|
except SQLAlchemyError as err:
|
||||||
|
_LOGGER.error("Error executing query: %s", err)
|
||||||
|
if tryno == RETRIES - 1:
|
||||||
|
raise
|
||||||
|
time.sleep(QUERY_RETRY_WAIT)
|
||||||
|
|
||||||
|
assert False # unreachable # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
def validate_or_move_away_sqlite_database(dburl: str) -> bool:
|
def validate_or_move_away_sqlite_database(dburl: str) -> bool:
|
||||||
|
@ -6,10 +6,13 @@ from unittest.mock import MagicMock, Mock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine.result import ChunkedIteratorResult
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.sql.elements import TextClause
|
from sqlalchemy.sql.elements import TextClause
|
||||||
|
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||||
|
|
||||||
from homeassistant.components import recorder
|
from homeassistant.components import recorder
|
||||||
from homeassistant.components.recorder import util
|
from homeassistant.components.recorder import history, util
|
||||||
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
|
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
|
||||||
from homeassistant.components.recorder.models import RecorderRuns
|
from homeassistant.components.recorder.models import RecorderRuns
|
||||||
from homeassistant.components.recorder.util import (
|
from homeassistant.components.recorder.util import (
|
||||||
@ -24,6 +27,7 @@ from homeassistant.util import dt as dt_util
|
|||||||
from .common import corrupt_db_file, run_information_with_session
|
from .common import corrupt_db_file, run_information_with_session
|
||||||
|
|
||||||
from tests.common import SetupRecorderInstanceT, async_test_home_assistant
|
from tests.common import SetupRecorderInstanceT, async_test_home_assistant
|
||||||
|
from tests.components.recorder.common import wait_recording_done
|
||||||
|
|
||||||
|
|
||||||
def test_session_scope_not_setup(hass_recorder):
|
def test_session_scope_not_setup(hass_recorder):
|
||||||
@ -510,8 +514,10 @@ def test_basic_sanity_check(hass_recorder):
|
|||||||
def test_combined_checks(hass_recorder, caplog):
|
def test_combined_checks(hass_recorder, caplog):
|
||||||
"""Run Checks on the open database."""
|
"""Run Checks on the open database."""
|
||||||
hass = hass_recorder()
|
hass = hass_recorder()
|
||||||
|
instance = recorder.get_instance(hass)
|
||||||
|
instance.db_retry_wait = 0
|
||||||
|
|
||||||
cursor = hass.data[DATA_INSTANCE].engine.raw_connection().cursor()
|
cursor = instance.engine.raw_connection().cursor()
|
||||||
|
|
||||||
assert util.run_checks_on_open_db("fake_db_path", cursor) is None
|
assert util.run_checks_on_open_db("fake_db_path", cursor) is None
|
||||||
assert "could not validate that the sqlite3 database" in caplog.text
|
assert "could not validate that the sqlite3 database" in caplog.text
|
||||||
@ -658,3 +664,54 @@ def test_build_mysqldb_conv():
|
|||||||
assert conv["DATETIME"]("2022-05-13T22:33:12.741") == datetime(
|
assert conv["DATETIME"]("2022-05-13T22:33:12.741") == datetime(
|
||||||
2022, 5, 13, 22, 33, 12, 741000, tzinfo=None
|
2022, 5, 13, 22, 33, 12, 741000, tzinfo=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("homeassistant.components.recorder.util.QUERY_RETRY_WAIT", 0)
|
||||||
|
def test_execute_stmt_lambda_element(hass_recorder):
|
||||||
|
"""Test executing with execute_stmt_lambda_element."""
|
||||||
|
hass = hass_recorder()
|
||||||
|
instance = recorder.get_instance(hass)
|
||||||
|
hass.states.set("sensor.on", "on")
|
||||||
|
new_state = hass.states.get("sensor.on")
|
||||||
|
wait_recording_done(hass)
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
tomorrow = now + timedelta(days=1)
|
||||||
|
one_week_from_now = now + timedelta(days=7)
|
||||||
|
|
||||||
|
class MockExecutor:
|
||||||
|
def __init__(self, stmt):
|
||||||
|
assert isinstance(stmt, StatementLambdaElement)
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def all(self):
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 2:
|
||||||
|
return ["mock_row"]
|
||||||
|
raise SQLAlchemyError
|
||||||
|
|
||||||
|
with session_scope(hass=hass) as session:
|
||||||
|
# No time window, we always get a list
|
||||||
|
stmt = history._get_single_entity_states_stmt(
|
||||||
|
instance.schema_version, dt_util.utcnow(), "sensor.on", False
|
||||||
|
)
|
||||||
|
rows = util.execute_stmt_lambda_element(session, stmt)
|
||||||
|
assert isinstance(rows, list)
|
||||||
|
assert rows[0].state == new_state.state
|
||||||
|
assert rows[0].entity_id == new_state.entity_id
|
||||||
|
|
||||||
|
# Time window >= 2 days, we get a ChunkedIteratorResult
|
||||||
|
rows = util.execute_stmt_lambda_element(session, stmt, now, one_week_from_now)
|
||||||
|
assert isinstance(rows, ChunkedIteratorResult)
|
||||||
|
row = next(rows)
|
||||||
|
assert row.state == new_state.state
|
||||||
|
assert row.entity_id == new_state.entity_id
|
||||||
|
|
||||||
|
# Time window < 2 days, we get a list
|
||||||
|
rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
|
||||||
|
assert isinstance(rows, list)
|
||||||
|
assert rows[0].state == new_state.state
|
||||||
|
assert rows[0].entity_id == new_state.entity_id
|
||||||
|
|
||||||
|
with patch.object(session, "execute", MockExecutor):
|
||||||
|
rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
|
||||||
|
assert rows == ["mock_row"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user