Convert history queries to use lambda_stmt (#71870)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
J. Nick Koston, 2022-05-15 10:47:29 -05:00 (committed via GitHub)
parent 8ea5ec6f08
commit 98809675ff
5 changed files with 313 additions and 262 deletions
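
Review note: this commit replaces SQLAlchemy's legacy baked-query extension
(sqlalchemy.ext.baked) with the 1.4 lambda statement cache. A minimal sketch
of the pattern change, assuming the recorder's States model is imported; the
helper name is illustrative, not part of this diff:

    from sqlalchemy import lambda_stmt, select

    def states_for_entity_stmt(entity_id: str):
        # The SQL is compiled once, keyed on each lambda's code location;
        # closure variables such as `entity_id` are extracted as bound
        # parameters, so later calls reuse the cached statement with new values.
        stmt = lambda_stmt(lambda: select(States))
        stmt += lambda q: q.where(States.entity_id == entity_id)
        return stmt

This removes the bakery bookkeeping (hass.data[HISTORY_BAKERY]) and the
bindparam()/params() round-trip the baked queries required.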

homeassistant/components/recorder/__init__.py

@@ -20,7 +20,7 @@ from homeassistant.helpers.integration_platform import (
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from . import history, statistics, websocket_api
from . import statistics, websocket_api
from .const import (
CONF_DB_INTEGRITY_CHECK,
DATA_INSTANCE,
@@ -166,7 +166,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
instance.async_register()
instance.start()
async_register_services(hass, instance)
history.async_setup(hass)
statistics.async_setup(hass)
websocket_api.async_setup(hass)
await async_process_integration_platforms(hass, DOMAIN, _process_recorder_platform)

homeassistant/components/recorder/filters.py

@@ -2,7 +2,6 @@
from __future__ import annotations
from sqlalchemy import not_, or_
from sqlalchemy.ext.baked import BakedQuery
from sqlalchemy.sql.elements import ClauseList
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
@@ -60,16 +59,6 @@ class Filters:
or self.included_entity_globs
)
def bake(self, baked_query: BakedQuery) -> BakedQuery:
"""Update a baked query.
Works the same as apply on a baked_query.
"""
if not self.has_config:
return
baked_query += lambda q: q.filter(self.entity_filter())
def entity_filter(self) -> ClauseList:
"""Generate the entity filter query."""
includes = []

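Filters.bake() is dropped rather than ported: lambda_stmt lambdas must be
analyzable and side-effect free, so callers now build the filter clause
outside the lambda and let the closure capture it. The replacement pattern,
as it appears in history.py below:

    if filters and filters.has_config:
        entity_filter = filters.entity_filter()  # build the clause once, outside the lambda
        stmt += lambda q: q.filter(entity_filter)
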
homeassistant/components/recorder/history.py

@@ -9,13 +9,12 @@ import logging
import time
from typing import Any, cast
from sqlalchemy import Column, Text, and_, bindparam, func, or_
from sqlalchemy import Column, Text, and_, func, lambda_stmt, or_, select
from sqlalchemy.engine.row import Row
from sqlalchemy.ext import baked
from sqlalchemy.ext.baked import BakedQuery
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.lambdas import StatementLambdaElement
from homeassistant.components import recorder
from homeassistant.components.websocket_api.const import (
@@ -36,7 +35,7 @@ from .models import (
process_timestamp_to_utc_isoformat,
row_to_compressed_state,
)
from .util import execute, session_scope
from .util import execute_stmt_lambda_element, session_scope
# mypy: allow-untyped-defs, no-check-untyped-defs
@@ -111,52 +110,48 @@ QUERY_STATES_NO_LAST_CHANGED = [
StateAttributes.shared_attrs,
]
HISTORY_BAKERY = "recorder_history_bakery"
def _schema_version(hass: HomeAssistant) -> int:
return recorder.get_instance(hass).schema_version
def bake_query_and_join_attributes(
hass: HomeAssistant, no_attributes: bool, include_last_changed: bool = True
) -> tuple[Any, bool]:
"""Return the initial backed query and if StateAttributes should be joined.
def lambda_stmt_and_join_attributes(
schema_version: int, no_attributes: bool, include_last_changed: bool = True
) -> tuple[StatementLambdaElement, bool]:
"""Return the lambda_stmt and if StateAttributes should be joined.
Because these are baked queries the values inside the lambdas need
Because these are lambda_stmt the values inside the lambdas need
to be explicitly written out to avoid caching the wrong values.
"""
bakery: baked.bakery = hass.data[HISTORY_BAKERY]
# If no_attributes was requested we do the query
# without the attributes fields and do not join the
# state_attributes table
if no_attributes:
if include_last_changed:
return bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR)), False
return lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR)), False
return (
bakery(lambda s: s.query(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
False,
)
# If we in the process of migrating schema we do
# not want to join the state_attributes table as we
# do not know if it will be there yet
if recorder.get_instance(hass).schema_version < 25:
if schema_version < 25:
if include_last_changed:
return (
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25)),
lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25)),
False,
)
return (
bakery(lambda s: s.query(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
False,
)
# Finally if no migration is in progress and no_attributes
# was not requested, we query both attributes columns and
# join state_attributes
if include_last_changed:
return bakery(lambda s: s.query(*QUERY_STATES)), True
return bakery(lambda s: s.query(*QUERY_STATES_NO_LAST_CHANGED)), True
def async_setup(hass: HomeAssistant) -> None:
"""Set up the history hooks."""
hass.data[HISTORY_BAKERY] = baked.bakery()
return lambda_stmt(lambda: select(*QUERY_STATES)), True
return lambda_stmt(lambda: select(*QUERY_STATES_NO_LAST_CHANGED)), True
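
The docstring's caveat is the crux of the whole conversion: lambda_stmt
caches compiled SQL per lambda code location, so anything that changes the
SQL structure (as opposed to a bound value) needs its own lambda, hence one
lambda_stmt(...) call per column-set branch above. A sketch of the pitfall
the branches avoid, using a hypothetical helper:

    def bad_stmt(no_attributes: bool):
        cols = QUERY_STATE_NO_ATTR if no_attributes else QUERY_STATES
        # Wrong: `cols` alters the SQL itself, not a parameter value, so the
        # first call's column set could be cached and reused for both branches
        # (or SQLAlchemy may reject the untrackable closure, depending on type).
        return lambda_stmt(lambda: select(*cols))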
def get_significant_states(
@@ -200,38 +195,30 @@ def _ignore_domains_filter(query: Query) -> Query:
)
def _query_significant_states_with_session(
hass: HomeAssistant,
session: Session,
def _significant_states_stmt(
schema_version: int,
start_time: datetime,
end_time: datetime | None = None,
entity_ids: list[str] | None = None,
filters: Filters | None = None,
significant_changes_only: bool = True,
no_attributes: bool = False,
) -> list[Row]:
end_time: datetime | None,
entity_ids: list[str] | None,
filters: Filters | None,
significant_changes_only: bool,
no_attributes: bool,
) -> StatementLambdaElement:
"""Query the database for significant state changes."""
if _LOGGER.isEnabledFor(logging.DEBUG):
timer_start = time.perf_counter()
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=True
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=not significant_changes_only
)
if entity_ids is not None and len(entity_ids) == 1:
if (
significant_changes_only
and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
):
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=False
)
baked_query += lambda q: q.filter(
(States.last_changed == States.last_updated)
| States.last_changed.is_(None)
)
if (
entity_ids
and len(entity_ids) == 1
and significant_changes_only
and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
):
stmt += lambda q: q.filter(
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
)
elif significant_changes_only:
baked_query += lambda q: q.filter(
stmt += lambda q: q.filter(
or_(
*[
States.entity_id.like(entity_domain)
@@ -244,36 +231,24 @@ def _query_significant_states_with_session(
)
)
if entity_ids is not None:
baked_query += lambda q: q.filter(
States.entity_id.in_(bindparam("entity_ids", expanding=True))
)
if entity_ids:
stmt += lambda q: q.filter(States.entity_id.in_(entity_ids))
else:
baked_query += _ignore_domains_filter
if filters:
filters.bake(baked_query)
stmt += _ignore_domains_filter
if filters and filters.has_config:
entity_filter = filters.entity_filter()
stmt += lambda q: q.filter(entity_filter)
baked_query += lambda q: q.filter(States.last_updated > bindparam("start_time"))
if end_time is not None:
baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time"))
stmt += lambda q: q.filter(States.last_updated > start_time)
if end_time:
stmt += lambda q: q.filter(States.last_updated < end_time)
if join_attributes:
baked_query += lambda q: q.outerjoin(
stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
states = execute(
baked_query(session).params(
start_time=start_time, end_time=end_time, entity_ids=entity_ids
)
)
if _LOGGER.isEnabledFor(logging.DEBUG):
elapsed = time.perf_counter() - timer_start
_LOGGER.debug("get_significant_states took %fs", elapsed)
return states
stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
return stmt
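
`stmt += lambda q: ...` is the lambda-statement counterpart of the old
`baked_query +=`: each appended lambda is cached at its own code location,
while plain closure variables (start_time, end_time, the entity_ids list)
become fresh bound parameters per execution, replacing the explicit
bindparam()/params() pairs of the baked version. A compact composition
sketch under those assumptions:

    def example_stmt(schema_version: int, threshold: datetime):
        stmt, _ = lambda_stmt_and_join_attributes(
            schema_version, no_attributes=True, include_last_changed=False
        )
        stmt += lambda q: q.filter(States.last_updated > threshold)  # bound param
        stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
        return stmt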
def get_significant_states_with_session(
@@ -301,9 +276,8 @@ def get_significant_states_with_session(
as well as all states from certain domains (for instance
thermostat so that we get current temperature in our graphs).
"""
states = _query_significant_states_with_session(
hass,
session,
stmt = _significant_states_stmt(
_schema_version(hass),
start_time,
end_time,
entity_ids,
@@ -311,6 +285,9 @@ def get_significant_states_with_session(
significant_changes_only,
no_attributes,
)
states = execute_stmt_lambda_element(
session, stmt, None if entity_ids else start_time, end_time
)
return _sorted_states_to_dict(
hass,
session,
@@ -354,6 +331,38 @@ def get_full_significant_states_with_session(
)
def _state_changed_during_period_stmt(
schema_version: int,
start_time: datetime,
end_time: datetime | None,
entity_id: str | None,
no_attributes: bool,
descending: bool,
limit: int | None,
) -> StatementLambdaElement:
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=False
)
stmt += lambda q: q.filter(
((States.last_changed == States.last_updated) | States.last_changed.is_(None))
& (States.last_updated > start_time)
)
if end_time:
stmt += lambda q: q.filter(States.last_updated < end_time)
stmt += lambda q: q.filter(States.entity_id == entity_id)
if join_attributes:
stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
if descending:
stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc())
else:
stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
if limit:
stmt += lambda q: q.limit(limit)
return stmt
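
For orientation, a hedged example of calling the public wrapper that
follows; the entity id and datetimes are hypothetical:

    states = state_changes_during_period(
        hass,
        start_time=start,  # aware UTC datetime
        end_time=end,
        entity_id="sensor.outdoor_temperature",
        descending=True,
        limit=10,
    )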
def state_changes_during_period(
hass: HomeAssistant,
start_time: datetime,
@@ -365,52 +374,21 @@ def state_changes_during_period(
include_start_time_state: bool = True,
) -> MutableMapping[str, list[State]]:
"""Return states changes during UTC period start_time - end_time."""
entity_id = entity_id.lower() if entity_id is not None else None
with session_scope(hass=hass) as session:
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=False
stmt = _state_changed_during_period_stmt(
_schema_version(hass),
start_time,
end_time,
entity_id,
no_attributes,
descending,
limit,
)
baked_query += lambda q: q.filter(
(
(States.last_changed == States.last_updated)
| States.last_changed.is_(None)
)
& (States.last_updated > bindparam("start_time"))
states = execute_stmt_lambda_element(
session, stmt, None if entity_id else start_time, end_time
)
if end_time is not None:
baked_query += lambda q: q.filter(
States.last_updated < bindparam("end_time")
)
if entity_id is not None:
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
entity_id = entity_id.lower()
if join_attributes:
baked_query += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
if descending:
baked_query += lambda q: q.order_by(
States.entity_id, States.last_updated.desc()
)
else:
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
if limit:
baked_query += lambda q: q.limit(bindparam("limit"))
states = execute(
baked_query(session).params(
start_time=start_time,
end_time=end_time,
entity_id=entity_id,
limit=limit,
)
)
entity_ids = [entity_id] if entity_id is not None else None
return cast(
@@ -426,41 +404,37 @@ def state_changes_during_period(
)
def _get_last_state_changes_stmt(
schema_version: int, number_of_states: int, entity_id: str
) -> StatementLambdaElement:
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, False, include_last_changed=False
)
stmt += lambda q: q.filter(
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
).filter(States.entity_id == entity_id)
if join_attributes:
stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()).limit(
number_of_states
)
return stmt
def get_last_state_changes(
hass: HomeAssistant, number_of_states: int, entity_id: str
) -> MutableMapping[str, list[State]]:
"""Return the last number_of_states."""
start_time = dt_util.utcnow()
entity_id = entity_id.lower() if entity_id is not None else None
with session_scope(hass=hass) as session:
baked_query, join_attributes = bake_query_and_join_attributes(
hass, False, include_last_changed=False
stmt = _get_last_state_changes_stmt(
_schema_version(hass), number_of_states, entity_id
)
baked_query += lambda q: q.filter(
(States.last_changed == States.last_updated) | States.last_changed.is_(None)
)
if entity_id is not None:
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
entity_id = entity_id.lower()
if join_attributes:
baked_query += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
baked_query += lambda q: q.order_by(
States.entity_id, States.last_updated.desc()
)
baked_query += lambda q: q.limit(bindparam("number_of_states"))
states = execute(
baked_query(session).params(
number_of_states=number_of_states, entity_id=entity_id
)
)
states = list(execute_stmt_lambda_element(session, stmt))
entity_ids = [entity_id] if entity_id is not None else None
return cast(
@@ -476,96 +450,91 @@ def get_last_state_changes(
)
def _most_recent_state_ids_entities_subquery(query: Query) -> Query:
"""Query to find the most recent state id for specific entities."""
def _get_states_for_entites_stmt(
schema_version: int,
run_start: datetime,
utc_point_in_time: datetime,
entity_ids: list[str],
no_attributes: bool,
) -> StatementLambdaElement:
"""Baked query to get states for specific entities."""
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=True
)
# We got an include-list of entities, accelerate the query by filtering already
# in the inner query.
most_recent_state_ids = (
query.session.query(func.max(States.state_id).label("max_state_id"))
.filter(
(States.last_updated >= bindparam("run_start"))
& (States.last_updated < bindparam("utc_point_in_time"))
)
.filter(States.entity_id.in_(bindparam("entity_ids", expanding=True)))
.group_by(States.entity_id)
.subquery()
stmt += lambda q: q.where(
States.state_id
== (
select(func.max(States.state_id).label("max_state_id"))
.filter(
(States.last_updated >= run_start)
& (States.last_updated < utc_point_in_time)
)
.filter(States.entity_id.in_(entity_ids))
.group_by(States.entity_id)
.subquery()
).c.max_state_id
)
return query.join(
most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id,
)
def _get_states_baked_query_for_entites(
hass: HomeAssistant,
no_attributes: bool = False,
) -> BakedQuery:
"""Baked query to get states for specific entities."""
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=True
)
baked_query += _most_recent_state_ids_entities_subquery
if join_attributes:
baked_query += lambda q: q.outerjoin(
stmt += lambda q: q.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
)
return baked_query
return stmt
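
Two things are easy to miss here: max(state_id) identifies the most recent
row per entity only because state_id is a monotonically increasing primary
key, and the subquery is constructed inside the lambda so run_start,
utc_point_in_time and entity_ids are tracked closure values rather than
frozen into the cached SQL. A minimal demonstration of that mechanism, with
a hypothetical cutoff value:

    def newest_state_id_stmt(cutoff: datetime):
        stmt = lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR))
        stmt += lambda q: q.where(
            States.state_id
            == (
                select(func.max(States.state_id))
                .filter(States.last_updated < cutoff)  # `cutoff` -> bound param
                .scalar_subquery()
            )
        )
        return stmt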
def _most_recent_state_ids_subquery(query: Query) -> Query:
"""Find the most recent state ids for all entiites."""
def _get_states_for_all_stmt(
schema_version: int,
run_start: datetime,
utc_point_in_time: datetime,
filters: Filters | None,
no_attributes: bool,
) -> StatementLambdaElement:
"""Baked query to get states for all entities."""
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=True
)
# We did not get an include-list of entities, query all states in the inner
# query, then filter out unwanted domains as well as applying the custom filter.
# This filtering can't be done in the inner query because the domain column is
# not indexed and we can't control what's in the custom filter.
most_recent_states_by_date = (
query.session.query(
select(
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
)
.filter(
(States.last_updated >= bindparam("run_start"))
& (States.last_updated < bindparam("utc_point_in_time"))
(States.last_updated >= run_start)
& (States.last_updated < utc_point_in_time)
)
.group_by(States.entity_id)
.subquery()
)
most_recent_state_ids = (
query.session.query(func.max(States.state_id).label("max_state_id"))
.join(
most_recent_states_by_date,
and_(
States.entity_id == most_recent_states_by_date.c.max_entity_id,
States.last_updated == most_recent_states_by_date.c.max_last_updated,
),
)
.group_by(States.entity_id)
.subquery()
stmt += lambda q: q.where(
States.state_id
== (
select(func.max(States.state_id).label("max_state_id"))
.join(
most_recent_states_by_date,
and_(
States.entity_id == most_recent_states_by_date.c.max_entity_id,
States.last_updated
== most_recent_states_by_date.c.max_last_updated,
),
)
.group_by(States.entity_id)
.subquery()
).c.max_state_id,
)
return query.join(
most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id,
)
def _get_states_baked_query_for_all(
hass: HomeAssistant,
filters: Filters | None = None,
no_attributes: bool = False,
) -> BakedQuery:
"""Baked query to get states for all entities."""
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=True
)
baked_query += _most_recent_state_ids_subquery
baked_query += _ignore_domains_filter
if filters:
filters.bake(baked_query)
stmt += _ignore_domains_filter
if filters and filters.has_config:
entity_filter = filters.entity_filter()
stmt += lambda q: q.filter(entity_filter)
if join_attributes:
baked_query += lambda q: q.outerjoin(
stmt += lambda q: q.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
)
return baked_query
return stmt
def _get_rows_with_session(
@@ -576,11 +545,15 @@ def _get_rows_with_session(
run: RecorderRuns | None = None,
filters: Filters | None = None,
no_attributes: bool = False,
) -> list[Row]:
) -> Iterable[Row]:
"""Return the states at a specific point in time."""
schema_version = _schema_version(hass)
if entity_ids and len(entity_ids) == 1:
return _get_single_entity_states_with_session(
hass, session, utc_point_in_time, entity_ids[0], no_attributes
return execute_stmt_lambda_element(
session,
_get_single_entity_states_stmt(
schema_version, utc_point_in_time, entity_ids[0], no_attributes
),
)
if run is None:
@@ -593,46 +566,41 @@
# We have more than one entity to look at so we need to do a query on states
# since the last recorder run started.
if entity_ids:
baked_query = _get_states_baked_query_for_entites(hass, no_attributes)
else:
baked_query = _get_states_baked_query_for_all(hass, filters, no_attributes)
return execute(
baked_query(session).params(
run_start=run.start,
utc_point_in_time=utc_point_in_time,
entity_ids=entity_ids,
stmt = _get_states_for_entites_stmt(
schema_version, run.start, utc_point_in_time, entity_ids, no_attributes
)
)
else:
stmt = _get_states_for_all_stmt(
schema_version, run.start, utc_point_in_time, filters, no_attributes
)
return execute_stmt_lambda_element(session, stmt)
def _get_single_entity_states_with_session(
hass: HomeAssistant,
session: Session,
def _get_single_entity_states_stmt(
schema_version: int,
utc_point_in_time: datetime,
entity_id: str,
no_attributes: bool = False,
) -> list[Row]:
) -> StatementLambdaElement:
# Use an entirely different (and extremely fast) query if we only
# have a single entity id
baked_query, join_attributes = bake_query_and_join_attributes(
hass, no_attributes, include_last_changed=True
stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=True
)
baked_query += lambda q: q.filter(
States.last_updated < bindparam("utc_point_in_time"),
States.entity_id == bindparam("entity_id"),
stmt += (
lambda q: q.filter(
States.last_updated < utc_point_in_time,
States.entity_id == entity_id,
)
.order_by(States.last_updated.desc())
.limit(1)
)
if join_attributes:
baked_query += lambda q: q.outerjoin(
stmt += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
)
baked_query += lambda q: q.order_by(States.last_updated.desc()).limit(1)
query = baked_query(session).params(
utc_point_in_time=utc_point_in_time, entity_id=entity_id
)
return execute(query)
return stmt
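
A note on why the comment calls this path extremely fast: the equality
filter on entity_id plus ORDER BY last_updated DESC LIMIT 1 lets the
database walk a composite (entity_id, last_updated) index backwards and
stop at the first row, assuming the index the States table defines.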
def _sorted_states_to_dict(

homeassistant/components/recorder/util.py

@@ -1,7 +1,7 @@
"""SQLAlchemy util functions."""
from __future__ import annotations
from collections.abc import Callable, Generator
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager
from datetime import date, datetime, timedelta
import functools
@@ -18,9 +18,12 @@ from awesomeversion import (
import ciso8601
from sqlalchemy import text
from sqlalchemy.engine.cursor import CursorFetchStrategy
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.baked import Result
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.lambdas import StatementLambdaElement
from typing_extensions import Concatenate, ParamSpec
from homeassistant.core import HomeAssistant
@@ -46,6 +49,7 @@ _LOGGER = logging.getLogger(__name__)
RETRIES = 3
QUERY_RETRY_WAIT = 0.1
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
DEFAULT_YIELD_STATES_ROWS = 32768
MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
@@ -119,8 +123,10 @@ def commit(session: Session, work: Any) -> bool:
def execute(
qry: Query, to_native: bool = False, validate_entity_ids: bool = True
) -> list:
qry: Query | Result,
to_native: bool = False,
validate_entity_ids: bool = True,
) -> list[Row]:
"""Query the database and convert the objects to HA native form.
This method also retries a few times in the case of stale connections.
@@ -163,7 +169,39 @@ def execute(
raise
time.sleep(QUERY_RETRY_WAIT)
assert False # unreachable
assert False # unreachable # pragma: no cover
def execute_stmt_lambda_element(
session: Session,
stmt: StatementLambdaElement,
start_time: datetime | None = None,
end_time: datetime | None = None,
yield_per: int | None = DEFAULT_YIELD_STATES_ROWS,
) -> Iterable[Row]:
"""Execute a StatementLambdaElement.
If the time window passed is greater than one day
the execution method will switch to yield_per to
reduce memory pressure.
It is not recommended to pass a time window
when selecting non-ranged rows (ie selecting
specific entities) since they are usually faster
with .all().
"""
executed = session.execute(stmt)
use_all = not start_time or ((end_time or dt_util.utcnow()) - start_time).days <= 1
for tryno in range(0, RETRIES):
try:
return executed.all() if use_all else executed.yield_per(yield_per) # type: ignore[no-any-return]
except SQLAlchemyError as err:
_LOGGER.error("Error executing query: %s", err)
if tryno == RETRIES - 1:
raise
time.sleep(QUERY_RETRY_WAIT)
assert False # unreachable # pragma: no cover
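
This sets up the calling convention used throughout history.py above: pass
the time window only for ranged scans, so day-plus windows stream via
yield_per while point-in-time lookups return a realized list. A hedged
usage sketch; process() is a hypothetical consumer:

    rows = execute_stmt_lambda_element(session, stmt, start_time, end_time)
    for row in rows:  # works for both list and ChunkedIteratorResult
        process(row)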
def validate_or_move_away_sqlite_database(dburl: str) -> bool:

tests/components/recorder/test_util.py

@@ -6,10 +6,13 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy import text
from sqlalchemy.engine.result import ChunkedIteratorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.lambdas import StatementLambdaElement
from homeassistant.components import recorder
from homeassistant.components.recorder import util
from homeassistant.components.recorder import history, util
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from homeassistant.components.recorder.models import RecorderRuns
from homeassistant.components.recorder.util import (
@@ -24,6 +27,7 @@ from homeassistant.util import dt as dt_util
from .common import corrupt_db_file, run_information_with_session
from tests.common import SetupRecorderInstanceT, async_test_home_assistant
from tests.components.recorder.common import wait_recording_done
def test_session_scope_not_setup(hass_recorder):
@@ -510,8 +514,10 @@ def test_basic_sanity_check(hass_recorder):
def test_combined_checks(hass_recorder, caplog):
"""Run Checks on the open database."""
hass = hass_recorder()
instance = recorder.get_instance(hass)
instance.db_retry_wait = 0
cursor = hass.data[DATA_INSTANCE].engine.raw_connection().cursor()
cursor = instance.engine.raw_connection().cursor()
assert util.run_checks_on_open_db("fake_db_path", cursor) is None
assert "could not validate that the sqlite3 database" in caplog.text
@@ -658,3 +664,54 @@ def test_build_mysqldb_conv():
assert conv["DATETIME"]("2022-05-13T22:33:12.741") == datetime(
2022, 5, 13, 22, 33, 12, 741000, tzinfo=None
)
@patch("homeassistant.components.recorder.util.QUERY_RETRY_WAIT", 0)
def test_execute_stmt_lambda_element(hass_recorder):
"""Test executing with execute_stmt_lambda_element."""
hass = hass_recorder()
instance = recorder.get_instance(hass)
hass.states.set("sensor.on", "on")
new_state = hass.states.get("sensor.on")
wait_recording_done(hass)
now = dt_util.utcnow()
tomorrow = now + timedelta(days=1)
one_week_from_now = now + timedelta(days=7)
class MockExecutor:
def __init__(self, stmt):
assert isinstance(stmt, StatementLambdaElement)
self.calls = 0
def all(self):
self.calls += 1
if self.calls == 2:
return ["mock_row"]
raise SQLAlchemyError
with session_scope(hass=hass) as session:
# No time window, we always get a list
stmt = history._get_single_entity_states_stmt(
instance.schema_version, dt_util.utcnow(), "sensor.on", False
)
rows = util.execute_stmt_lambda_element(session, stmt)
assert isinstance(rows, list)
assert rows[0].state == new_state.state
assert rows[0].entity_id == new_state.entity_id
# Time window >= 2 days, we get a ChunkedIteratorResult
rows = util.execute_stmt_lambda_element(session, stmt, now, one_week_from_now)
assert isinstance(rows, ChunkedIteratorResult)
row = next(rows)
assert row.state == new_state.state
assert row.entity_id == new_state.entity_id
# Time window < 2 days, we get a list
rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
assert isinstance(rows, list)
assert rows[0].state == new_state.state
assert rows[0].entity_id == new_state.entity_id
with patch.object(session, "execute", MockExecutor):
rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
assert rows == ["mock_row"]