Reduce latency to find stats metadata (#89824)

J. Nick Koston 2023-03-16 19:00:02 -10:00 committed by GitHub
parent 04a99fdbfc
commit f6f3565796
13 changed files with 589 additions and 255 deletions
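
The diff below speeds up statistics metadata lookups by giving the recorder an in-memory, LRU-backed map from statistic_id to (metadata_id, metadata) that is primed once at startup, so repeated frontend statistics queries are answered from memory instead of each running its own SELECT against statistics_meta. A minimal, self-contained sketch of that cache-through pattern (simplified stand-in names, not the recorder's actual classes):

    # Illustrative sketch only: a bounded cache that serves hits from memory
    # and fetches only the misses from the database in a single query.
    from collections import OrderedDict

    class MetaCache:
        """Map statistic_id -> (metadata_id, metadata), bounded like an LRU."""

        def __init__(self, max_size: int = 8192) -> None:
            self._cache: OrderedDict[str, tuple[int, dict]] = OrderedDict()
            self._max_size = max_size

        def get_many(self, statistic_ids: list[str], fetch) -> dict[str, tuple[int, dict]]:
            """Serve hits from memory; `fetch(missing)` queries the database once for the rest."""
            results = {sid: self._cache[sid] for sid in statistic_ids if sid in self._cache}
            missing = [sid for sid in statistic_ids if sid not in results]
            if missing:
                for sid, id_meta in fetch(missing).items():
                    results[sid] = id_meta
                    self._cache[sid] = id_meta
                    self._cache.move_to_end(sid)
                    if len(self._cache) > self._max_size:
                        self._cache.popitem(last=False)
            return results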

View File

@@ -86,6 +86,7 @@ from .table_managers.event_types import EventTypeManager
from .table_managers.state_attributes import StateAttributesManager
from .table_managers.states import StatesManager
from .table_managers.states_meta import StatesMetaManager
+from .table_managers.statistics_meta import StatisticsMetaManager
from .tasks import (
    AdjustLRUSizeTask,
    AdjustStatisticsTask,
@@ -172,6 +173,7 @@ class Recorder(threading.Thread):
        threading.Thread.__init__(self, name="Recorder")
        self.hass = hass
+        self.thread_id: int | None = None
        self.auto_purge = auto_purge
        self.auto_repack = auto_repack
        self.keep_days = keep_days
@@ -208,6 +210,7 @@ class Recorder(threading.Thread):
        self.state_attributes_manager = StateAttributesManager(
            self, exclude_attributes_by_domain
        )
+        self.statistics_meta_manager = StatisticsMetaManager(self)
        self.event_session: Session | None = None
        self._get_session: Callable[[], Session] | None = None
        self._completed_first_database_setup: bool | None = None
@@ -613,6 +616,7 @@ class Recorder(threading.Thread):
    def run(self) -> None:
        """Start processing events to save."""
+        self.thread_id = threading.get_ident()
        setup_result = self._setup_recorder()
        if not setup_result:
@@ -668,7 +672,7 @@
                "Database Migration Failed",
                "recorder_database_migration",
            )
-            self._activate_and_set_db_ready()
+            self.hass.add_job(self.async_set_db_ready)
            self._shutdown()
            return
@@ -687,7 +691,14 @@
    def _activate_and_set_db_ready(self) -> None:
        """Activate the table managers or schedule migrations and mark the db as ready."""
-        with session_scope(session=self.get_session()) as session:
+        with session_scope(session=self.get_session(), read_only=True) as session:
+            # Prime the statistics meta manager as soon as possible
+            # since we want the frontend queries to avoid a thundering
+            # herd of queries to find the statistics meta data if
+            # there are a lot of statistics graphs on the frontend.
+            if self.schema_version >= 23:
+                self.statistics_meta_manager.load(session)
            if (
                self.schema_version < 36
                or session.execute(has_events_context_ids_to_migrate()).scalar()
@@ -758,10 +769,11 @@
                non_state_change_events.append(event_)
        assert self.event_session is not None
-        self.event_data_manager.load(non_state_change_events, self.event_session)
-        self.event_type_manager.load(non_state_change_events, self.event_session)
-        self.states_meta_manager.load(state_change_events, self.event_session)
-        self.state_attributes_manager.load(state_change_events, self.event_session)
+        session = self.event_session
+        self.event_data_manager.load(non_state_change_events, session)
+        self.event_type_manager.load(non_state_change_events, session)
+        self.states_meta_manager.load(state_change_events, session)
+        self.state_attributes_manager.load(state_change_events, session)

    def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None:
        """Process a task, guarding against exceptions to ensure the loop does not collapse."""
@@ -1077,6 +1089,7 @@ class Recorder(threading.Thread):
        self.event_data_manager.reset()
        self.event_type_manager.reset()
        self.states_meta_manager.reset()
+        self.statistics_meta_manager.reset()
        if not self.event_session:
            return
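
The thread_id recorded in run() above backs a thread-confinement guard: cache-mutating calls on the new statistics meta manager must come from the recorder thread and raise otherwise. A simplified sketch of the pattern (a hypothetical Worker class, not the recorder itself):

    import threading

    class Worker(threading.Thread):
        def __init__(self) -> None:
            super().__init__(name="Worker")
            self.thread_id: int | None = None

        def run(self) -> None:
            # Remember which thread owns the database session.
            self.thread_id = threading.get_ident()

        def assert_in_worker_thread(self) -> None:
            # Mutations must happen on the worker thread; anything else is a bug.
            if self.thread_id != threading.get_ident():
                raise RuntimeError("Detected unsafe call not in worker thread")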

View File

@@ -873,7 +873,7 @@ def _apply_update( # noqa: C901
        # There may be duplicated statistics_meta entries, delete duplicates
        # and try again
        with session_scope(session=session_maker()) as session:
-            delete_statistics_meta_duplicates(session)
+            delete_statistics_meta_duplicates(instance, session)
        _create_index(
            session_maker, "statistics_meta", "ix_statistics_meta_statistic_id"
        )

View File

@@ -21,7 +21,7 @@ from sqlalchemy.engine import Engine
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
from sqlalchemy.orm.session import Session
-from sqlalchemy.sql.expression import literal_column, true
+from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql.lambdas import StatementLambdaElement
import voluptuous as vol
@@ -132,16 +132,6 @@ QUERY_STATISTICS_SUMMARY_SUM = (
    .label("rownum"),
)
-QUERY_STATISTIC_META = (
-    StatisticsMeta.id,
-    StatisticsMeta.statistic_id,
-    StatisticsMeta.source,
-    StatisticsMeta.unit_of_measurement,
-    StatisticsMeta.has_mean,
-    StatisticsMeta.has_sum,
-    StatisticsMeta.name,
-)

STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = {
    **{unit: DataRateConverter for unit in DataRateConverter.VALID_UNITS},
@@ -373,56 +363,6 @@ def get_start_time() -> datetime:
    return last_period
-def _update_or_add_metadata(
-    session: Session,
-    new_metadata: StatisticMetaData,
-    old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
-) -> int:
-    """Get metadata_id for a statistic_id.
-
-    If the statistic_id is previously unknown, add it. If it's already known, update
-    metadata if needed.
-
-    Updating metadata source is not possible.
-    """
-    statistic_id = new_metadata["statistic_id"]
-    if statistic_id not in old_metadata_dict:
-        meta = StatisticsMeta.from_meta(new_metadata)
-        session.add(meta)
-        session.flush()  # Flush to get the metadata id assigned
-        _LOGGER.debug(
-            "Added new statistics metadata for %s, new_metadata: %s",
-            statistic_id,
-            new_metadata,
-        )
-        return meta.id
-
-    metadata_id, old_metadata = old_metadata_dict[statistic_id]
-    if (
-        old_metadata["has_mean"] != new_metadata["has_mean"]
-        or old_metadata["has_sum"] != new_metadata["has_sum"]
-        or old_metadata["name"] != new_metadata["name"]
-        or old_metadata["unit_of_measurement"] != new_metadata["unit_of_measurement"]
-    ):
-        session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update(
-            {
-                StatisticsMeta.has_mean: new_metadata["has_mean"],
-                StatisticsMeta.has_sum: new_metadata["has_sum"],
-                StatisticsMeta.name: new_metadata["name"],
-                StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"],
-            },
-            synchronize_session=False,
-        )
-        _LOGGER.debug(
-            "Updated statistics metadata for %s, old_metadata: %s, new_metadata: %s",
-            statistic_id,
-            old_metadata,
-            new_metadata,
-        )
-
-    return metadata_id

def _find_duplicates(
    session: Session, table: type[StatisticsBase]
) -> tuple[list[int], list[dict]]:
@@ -642,13 +582,16 @@ def _delete_statistics_meta_duplicates(session: Session) -> int:
    return total_deleted_rows
-def delete_statistics_meta_duplicates(session: Session) -> None:
+def delete_statistics_meta_duplicates(instance: Recorder, session: Session) -> None:
    """Identify and delete duplicated statistics_meta.

    This is used when migrating from schema version 28 to schema version 29.
    """
    deleted_statistics_rows = _delete_statistics_meta_duplicates(session)
    if deleted_statistics_rows:
+        statistics_meta_manager = instance.statistics_meta_manager
+        statistics_meta_manager.reset()
+        statistics_meta_manager.load(session)
        _LOGGER.info(
            "Deleted %s duplicated statistics_meta rows", deleted_statistics_rows
        )
@@ -750,6 +693,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
    """
    start = dt_util.as_utc(start)
    end = start + timedelta(minutes=5)
+    statistics_meta_manager = instance.statistics_meta_manager

    # Return if we already have 5-minute statistics for the requested period
    with session_scope(
@@ -782,7 +726,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
        # Insert collected statistics in the database
        for stats in platform_stats:
-            metadata_id = _update_or_add_metadata(
+            _, metadata_id = statistics_meta_manager.update_or_add(
                session, stats["meta"], current_metadata
            )
            _insert_statistics(
@@ -877,28 +821,8 @@ def _update_statistics(
    )
-def _generate_get_metadata_stmt(
-    statistic_ids: list[str] | None = None,
-    statistic_type: Literal["mean"] | Literal["sum"] | None = None,
-    statistic_source: str | None = None,
-) -> StatementLambdaElement:
-    """Generate a statement to fetch metadata."""
-    stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
-    if statistic_ids:
-        stmt += lambda q: q.where(
-            # https://github.com/python/mypy/issues/2608
-            StatisticsMeta.statistic_id.in_(statistic_ids)  # type:ignore[arg-type]
-        )
-    if statistic_source is not None:
-        stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
-    if statistic_type == "mean":
-        stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
-    elif statistic_type == "sum":
-        stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
-    return stmt

def get_metadata_with_session(
+    instance: Recorder,
    session: Session,
    *,
    statistic_ids: list[str] | None = None,
@@ -908,31 +832,15 @@ def get_metadata_with_session(
    """Fetch meta data.

    Returns a dict of (metadata_id, StatisticMetaData) tuples indexed by statistic_id.
    If statistic_ids is given, fetch metadata only for the listed statistics_ids.
    If statistic_type is given, fetch metadata only for statistic_ids supporting it.
    """
-    # Fetch metatadata from the database
-    stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source)
-    result = execute_stmt_lambda_element(session, stmt)
-    if not result:
-        return {}
-    return {
-        meta.statistic_id: (
-            meta.id,
-            {
-                "has_mean": meta.has_mean,
-                "has_sum": meta.has_sum,
-                "name": meta.name,
-                "source": meta.source,
-                "statistic_id": meta.statistic_id,
-                "unit_of_measurement": meta.unit_of_measurement,
-            },
-        )
-        for meta in result
-    }
+    return instance.statistics_meta_manager.get_many(
+        session,
+        statistic_ids=statistic_ids,
+        statistic_type=statistic_type,
+        statistic_source=statistic_source,
+    )

def get_metadata(
@@ -945,6 +853,7 @@ def get_metadata(
    """Return metadata for statistic_ids."""
    with session_scope(hass=hass, read_only=True) as session:
        return get_metadata_with_session(
+            get_instance(hass),
            session,
            statistic_ids=statistic_ids,
            statistic_type=statistic_type,
@@ -952,17 +861,10 @@
        )
-def _clear_statistics_with_session(session: Session, statistic_ids: list[str]) -> None:
-    """Clear statistics for a list of statistic_ids."""
-    session.query(StatisticsMeta).filter(
-        StatisticsMeta.statistic_id.in_(statistic_ids)
-    ).delete(synchronize_session=False)

def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
    """Clear statistics for a list of statistic_ids."""
    with session_scope(session=instance.get_session()) as session:
-        _clear_statistics_with_session(session, statistic_ids)
+        instance.statistics_meta_manager.delete(session, statistic_ids)

def update_statistics_metadata(
@@ -972,20 +874,20 @@ def update_statistics_metadata(
    new_unit_of_measurement: str | None | UndefinedType,
) -> None:
    """Update statistics metadata for a statistic_id."""
+    statistics_meta_manager = instance.statistics_meta_manager
    if new_unit_of_measurement is not UNDEFINED:
        with session_scope(session=instance.get_session()) as session:
-            session.query(StatisticsMeta).filter(
-                StatisticsMeta.statistic_id == statistic_id
-            ).update({StatisticsMeta.unit_of_measurement: new_unit_of_measurement})
-    if new_statistic_id is not UNDEFINED:
+            statistics_meta_manager.update_unit_of_measurement(
+                session, statistic_id, new_unit_of_measurement
+            )
+    if new_statistic_id is not UNDEFINED and new_statistic_id is not None:
        with session_scope(
            session=instance.get_session(),
            exception_filter=_filter_unique_constraint_integrity_error(instance),
        ) as session:
-            session.query(StatisticsMeta).filter(
-                (StatisticsMeta.statistic_id == statistic_id)
-                & (StatisticsMeta.source == DOMAIN)
-            ).update({StatisticsMeta.statistic_id: new_statistic_id})
+            statistics_meta_manager.update_statistic_id(
+                session, DOMAIN, statistic_id, new_statistic_id
+            )

def list_statistic_ids(
@@ -1004,7 +906,7 @@ def list_statistic_ids(
    # Query the database
    with session_scope(hass=hass, read_only=True) as session:
-        metadata = get_metadata_with_session(
+        metadata = get_instance(hass).statistics_meta_manager.get_many(
            session, statistic_type=statistic_type, statistic_ids=statistic_ids
        )
@@ -1609,11 +1511,13 @@ def statistic_during_period(
    with session_scope(hass=hass, read_only=True) as session:
        # Fetch metadata for the given statistic_id
        if not (
-            metadata := get_metadata_with_session(session, statistic_ids=[statistic_id])
+            metadata := get_instance(hass).statistics_meta_manager.get(
+                session, statistic_id
+            )
        ):
            return result
-        metadata_id = metadata[statistic_id][0]
+        metadata_id = metadata[0]

        oldest_stat = _first_statistic(session, Statistics, metadata_id)
        oldest_5_min_stat = None
@@ -1724,7 +1628,7 @@ def statistic_during_period(
        else:
            result["change"] = None

-    state_unit = unit = metadata[statistic_id][1]["unit_of_measurement"]
+    state_unit = unit = metadata[1]["unit_of_measurement"]
    if state := hass.states.get(statistic_id):
        state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
    convert = _get_statistic_to_display_unit_converter(unit, state_unit, units)
@@ -1749,7 +1653,9 @@ def _statistics_during_period_with_session(
    """
    metadata = None
    # Fetch metadata for the given (or all) statistic_ids
-    metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
+    metadata = get_instance(hass).statistics_meta_manager.get_many(
+        session, statistic_ids=statistic_ids
+    )
    if not metadata:
        return {}
@@ -1885,7 +1791,9 @@ def _get_last_statistics(
    statistic_ids = [statistic_id]
    with session_scope(hass=hass, read_only=True) as session:
        # Fetch metadata for the given statistic_id
-        metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
+        metadata = get_instance(hass).statistics_meta_manager.get_many(
+            session, statistic_ids=statistic_ids
+        )
        if not metadata:
            return {}
        metadata_id = metadata[statistic_id][0]
@@ -1973,7 +1881,9 @@ def get_latest_short_term_statistics(
    with session_scope(hass=hass, read_only=True) as session:
        # Fetch metadata for the given statistic_ids
        if not metadata:
-            metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
+            metadata = get_instance(hass).statistics_meta_manager.get_many(
+                session, statistic_ids=statistic_ids
+            )
        if not metadata:
            return {}
        metadata_ids = [
@@ -2318,16 +2228,20 @@ def _filter_unique_constraint_integrity_error(

def _import_statistics_with_session(
+    instance: Recorder,
    session: Session,
    metadata: StatisticMetaData,
    statistics: Iterable[StatisticData],
    table: type[StatisticsBase],
) -> bool:
    """Import statistics to the database."""
-    old_metadata_dict = get_metadata_with_session(
+    statistics_meta_manager = instance.statistics_meta_manager
+    old_metadata_dict = statistics_meta_manager.get_many(
        session, statistic_ids=[metadata["statistic_id"]]
    )
-    metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
+    _, metadata_id = statistics_meta_manager.update_or_add(
+        session, metadata, old_metadata_dict
+    )
    for stat in statistics:
        if stat_id := _statistics_exists(session, table, metadata_id, stat["start"]):
            _update_statistics(session, table, stat_id, stat)
@@ -2350,7 +2264,9 @@ def import_statistics(
        session=instance.get_session(),
        exception_filter=_filter_unique_constraint_integrity_error(instance),
    ) as session:
-        return _import_statistics_with_session(session, metadata, statistics, table)
+        return _import_statistics_with_session(
+            instance, session, metadata, statistics, table
+        )

@retryable_database_job("adjust_statistics")
@@ -2364,7 +2280,9 @@ def adjust_statistics(
    """Process an add_statistics job."""
    with session_scope(session=instance.get_session()) as session:
-        metadata = get_metadata_with_session(session, statistic_ids=[statistic_id])
+        metadata = instance.statistics_meta_manager.get_many(
+            session, statistic_ids=[statistic_id]
+        )
        if statistic_id not in metadata:
            return True
@@ -2423,10 +2341,9 @@ def change_statistics_unit(
    old_unit: str,
) -> None:
    """Change statistics unit for a statistic_id."""
+    statistics_meta_manager = instance.statistics_meta_manager
    with session_scope(session=instance.get_session()) as session:
-        metadata = get_metadata_with_session(session, statistic_ids=[statistic_id]).get(
-            statistic_id
-        )
+        metadata = statistics_meta_manager.get(session, statistic_id)

        # Guard against the statistics being removed or updated before the
        # change_statistics_unit job executes
@@ -2447,9 +2364,10 @@
        )
        for table in tables:
            _change_statistics_unit_for_table(session, table, metadata_id, convert)
-        session.query(StatisticsMeta).filter(
-            StatisticsMeta.statistic_id == statistic_id
-        ).update({StatisticsMeta.unit_of_measurement: new_unit})
+
+        statistics_meta_manager.update_unit_of_measurement(
+            session, statistic_id, new_unit
+        )

@callback
@@ -2495,16 +2413,19 @@ def _validate_db_schema_utf8(
        "statistic_id": statistic_id,
        "unit_of_measurement": None,
    }
+    statistics_meta_manager = instance.statistics_meta_manager

    # Try inserting some metadata which needs utfmb4 support
    try:
        with session_scope(session=session_maker()) as session:
-            old_metadata_dict = get_metadata_with_session(
+            old_metadata_dict = statistics_meta_manager.get_many(
                session, statistic_ids=[statistic_id]
            )
            try:
-                _update_or_add_metadata(session, metadata, old_metadata_dict)
-                _clear_statistics_with_session(session, statistic_ids=[statistic_id])
+                statistics_meta_manager.update_or_add(
+                    session, metadata, old_metadata_dict
+                )
+                statistics_meta_manager.delete(session, statistic_ids=[statistic_id])
            except OperationalError as err:
                if err.orig and err.orig.args[0] == 1366:
                    _LOGGER.debug(
@@ -2524,6 +2445,7 @@ def _validate_db_schema(
) -> set[str]:
    """Do some basic checks for common schema errors caused by manual migration."""
    schema_errors: set[str] = set()
+    statistics_meta_manager = instance.statistics_meta_manager

    # Wrong precision is only an issue for MySQL / MariaDB / PostgreSQL
    if instance.dialect_name not in (
@@ -2586,7 +2508,9 @@
    try:
        with session_scope(session=session_maker()) as session:
            for table in tables:
-                _import_statistics_with_session(session, metadata, (statistics,), table)
+                _import_statistics_with_session(
+                    instance, session, metadata, (statistics,), table
+                )
            stored_statistics = _statistics_during_period_with_session(
                hass,
                session,
@@ -2625,7 +2549,7 @@
                    table.__tablename__,
                    "µs precision",
                )
-            _clear_statistics_with_session(session, statistic_ids=[statistic_id])
+            statistics_meta_manager.delete(session, statistic_ids=[statistic_id])
    except Exception as exc:  # pylint: disable=broad-except
        _LOGGER.exception("Error when validating DB schema: %s", exc)
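
With the module-level helpers above removed, call sites resolve metadata through the manager attached to the recorder instance; get_metadata_with_session() now only forwards to it. A hedged usage sketch (assumes a running recorder `instance`; the statistic_id is made up for illustration):

    with session_scope(session=instance.get_session(), read_only=True) as session:
        # Cached lookup; only statistic_ids not yet in the LRU hit the database.
        metadata = instance.statistics_meta_manager.get_many(
            session, statistic_ids=["sensor.energy_total"]
        )
        if "sensor.energy_total" in metadata:
            metadata_id, meta = metadata["sensor.energy_total"]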

View File

@@ -14,7 +14,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import EventData
from ..queries import get_shared_event_datas
-from ..util import chunked
+from ..util import chunked, execute_stmt_lambda_element

if TYPE_CHECKING:
    from ..core import Recorder
@@ -96,8 +96,8 @@ class EventDataManager(BaseLRUTableManager[EventData]):
        results: dict[str, int | None] = {}
        with session.no_autoflush:
            for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
-                for data_id, shared_data in session.execute(
-                    get_shared_event_datas(hashs_chunk)
+                for data_id, shared_data in execute_stmt_lambda_element(
+                    session, get_shared_event_datas(hashs_chunk)
                ):
                    results[shared_data] = self._id_map[shared_data] = cast(
                        int, data_id

View File

@@ -12,7 +12,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import EventTypes
from ..queries import find_event_type_ids
-from ..util import chunked
+from ..util import chunked, execute_stmt_lambda_element

if TYPE_CHECKING:
    from ..core import Recorder
@@ -68,8 +68,8 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]):
        with session.no_autoflush:
            for missing_chunk in chunked(missing, SQLITE_MAX_BIND_VARS):
-                for event_type_id, event_type in session.execute(
-                    find_event_type_ids(missing_chunk)
+                for event_type_id, event_type in execute_stmt_lambda_element(
+                    session, find_event_type_ids(missing_chunk)
                ):
                    results[event_type] = self._id_map[event_type] = cast(
                        int, event_type_id

View File

@@ -15,7 +15,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import StateAttributes
from ..queries import get_shared_attributes
-from ..util import chunked
+from ..util import chunked, execute_stmt_lambda_element

if TYPE_CHECKING:
    from ..core import Recorder
@@ -113,8 +113,8 @@ class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
        results: dict[str, int | None] = {}
        with session.no_autoflush:
            for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
-                for attributes_id, shared_attrs in session.execute(
-                    get_shared_attributes(hashs_chunk)
+                for attributes_id, shared_attrs in execute_stmt_lambda_element(
+                    session, get_shared_attributes(hashs_chunk)
                ):
                    results[shared_attrs] = self._id_map[shared_attrs] = cast(
                        int, attributes_id

View File

@@ -12,7 +12,7 @@ from . import BaseLRUTableManager
from ..const import SQLITE_MAX_BIND_VARS
from ..db_schema import StatesMeta
from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
-from ..util import chunked
+from ..util import chunked, execute_stmt_lambda_element

if TYPE_CHECKING:
    from ..core import Recorder
@@ -98,8 +98,8 @@ class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
        with session.no_autoflush:
            for missing_chunk in chunked(missing, SQLITE_MAX_BIND_VARS):
-                for metadata_id, entity_id in session.execute(
-                    find_states_metadata_ids(missing_chunk)
+                for metadata_id, entity_id in execute_stmt_lambda_element(
+                    session, find_states_metadata_ids(missing_chunk)
                ):
                    metadata_id = cast(int, metadata_id)
                    results[entity_id] = metadata_id

View File

@@ -0,0 +1,322 @@
"""Support managing StatesMeta."""
from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING, Literal, cast

from lru import LRU  # pylint: disable=no-name-in-module
from sqlalchemy import lambda_stmt, select
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import true
from sqlalchemy.sql.lambdas import StatementLambdaElement

from ..db_schema import StatisticsMeta
from ..models import StatisticMetaData
from ..util import execute_stmt_lambda_element

if TYPE_CHECKING:
    from ..core import Recorder

CACHE_SIZE = 8192

_LOGGER = logging.getLogger(__name__)

QUERY_STATISTIC_META = (
    StatisticsMeta.id,
    StatisticsMeta.statistic_id,
    StatisticsMeta.source,
    StatisticsMeta.unit_of_measurement,
    StatisticsMeta.has_mean,
    StatisticsMeta.has_sum,
    StatisticsMeta.name,
)


def _generate_get_metadata_stmt(
    statistic_ids: list[str] | None = None,
    statistic_type: Literal["mean"] | Literal["sum"] | None = None,
    statistic_source: str | None = None,
) -> StatementLambdaElement:
    """Generate a statement to fetch metadata."""
    stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
    if statistic_ids:
        stmt += lambda q: q.where(
            # https://github.com/python/mypy/issues/2608
            StatisticsMeta.statistic_id.in_(statistic_ids)  # type:ignore[arg-type]
        )
    if statistic_source is not None:
        stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
    if statistic_type == "mean":
        stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
    elif statistic_type == "sum":
        stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
    return stmt


def _statistics_meta_to_id_statistics_metadata(
    meta: StatisticsMeta,
) -> tuple[int, StatisticMetaData]:
    """Convert StatisticsMeta tuple of metadata_id and StatisticMetaData."""
    return (
        meta.id,
        {
            "has_mean": meta.has_mean,  # type: ignore[typeddict-item]
            "has_sum": meta.has_sum,  # type: ignore[typeddict-item]
            "name": meta.name,
            "source": meta.source,  # type: ignore[typeddict-item]
            "statistic_id": meta.statistic_id,  # type: ignore[typeddict-item]
            "unit_of_measurement": meta.unit_of_measurement,
        },
    )


class StatisticsMetaManager:
    """Manage the StatisticsMeta table."""

    def __init__(self, recorder: Recorder) -> None:
        """Initialize the statistics meta manager."""
        self.recorder = recorder
        self._stat_id_to_id_meta: dict[str, tuple[int, StatisticMetaData]] = LRU(
            CACHE_SIZE
        )

    def _clear_cache(self, statistic_ids: list[str]) -> None:
        """Clear the cache."""
        for statistic_id in statistic_ids:
            self._stat_id_to_id_meta.pop(statistic_id, None)

    def _get_from_database(
        self,
        session: Session,
        statistic_ids: list[str] | None = None,
        statistic_type: Literal["mean"] | Literal["sum"] | None = None,
        statistic_source: str | None = None,
    ) -> dict[str, tuple[int, StatisticMetaData]]:
        """Fetch meta data and process it into results and/or cache."""
        # Only update the cache if we are in the recorder thread and there are no
        # new objects that are not yet committed to the database in the session.
        update_cache = (
            not session.new
            and not session.dirty
            and self.recorder.thread_id == threading.get_ident()
        )
        results: dict[str, tuple[int, StatisticMetaData]] = {}
        with session.no_autoflush:
            stat_id_to_id_meta = self._stat_id_to_id_meta
            for row in execute_stmt_lambda_element(
                session,
                _generate_get_metadata_stmt(
                    statistic_ids, statistic_type, statistic_source
                ),
            ):
                statistics_meta = cast(StatisticsMeta, row)
                id_meta = _statistics_meta_to_id_statistics_metadata(statistics_meta)
                statistic_id = cast(str, statistics_meta.statistic_id)
                results[statistic_id] = id_meta
                if update_cache:
                    stat_id_to_id_meta[statistic_id] = id_meta
        return results

    def _assert_in_recorder_thread(self) -> None:
        """Assert that we are in the recorder thread."""
        if self.recorder.thread_id != threading.get_ident():
            raise RuntimeError("Detected unsafe call not in recorder thread")

    def _add_metadata(
        self, session: Session, statistic_id: str, new_metadata: StatisticMetaData
    ) -> int:
        """Add metadata to the database.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        self._assert_in_recorder_thread()
        meta = StatisticsMeta.from_meta(new_metadata)
        session.add(meta)
        # Flush to assign an ID
        session.flush()
        _LOGGER.debug(
            "Added new statistics metadata for %s, new_metadata: %s",
            statistic_id,
            new_metadata,
        )
        return meta.id

    def _update_metadata(
        self,
        session: Session,
        statistic_id: str,
        new_metadata: StatisticMetaData,
        old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
    ) -> tuple[bool, int]:
        """Update metadata in the database.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        metadata_id, old_metadata = old_metadata_dict[statistic_id]
        if not (
            old_metadata["has_mean"] != new_metadata["has_mean"]
            or old_metadata["has_sum"] != new_metadata["has_sum"]
            or old_metadata["name"] != new_metadata["name"]
            or old_metadata["unit_of_measurement"]
            != new_metadata["unit_of_measurement"]
        ):
            return False, metadata_id

        self._assert_in_recorder_thread()
        session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update(
            {
                StatisticsMeta.has_mean: new_metadata["has_mean"],
                StatisticsMeta.has_sum: new_metadata["has_sum"],
                StatisticsMeta.name: new_metadata["name"],
                StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"],
            },
            synchronize_session=False,
        )
        self._clear_cache([statistic_id])
        _LOGGER.debug(
            "Updated statistics metadata for %s, old_metadata: %s, new_metadata: %s",
            statistic_id,
            old_metadata,
            new_metadata,
        )
        return True, metadata_id

    def load(self, session: Session) -> None:
        """Load the statistic_id to metadata_id mapping into memory.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        self.get_many(session)

    def get(
        self, session: Session, statistic_id: str
    ) -> tuple[int, StatisticMetaData] | None:
        """Resolve statistic_id to the metadata_id."""
        return self.get_many(session, [statistic_id]).get(statistic_id)

    def get_many(
        self,
        session: Session,
        statistic_ids: list[str] | None = None,
        statistic_type: Literal["mean"] | Literal["sum"] | None = None,
        statistic_source: str | None = None,
    ) -> dict[str, tuple[int, StatisticMetaData]]:
        """Fetch meta data.

        Returns a dict of (metadata_id, StatisticMetaData) tuples indexed by statistic_id.

        If statistic_ids is given, fetch metadata only for the listed statistics_ids.
        If statistic_type is given, fetch metadata only for statistic_ids supporting it.
        """
        if statistic_ids is None:
            # Fetch metadata from the database
            return self._get_from_database(
                session,
                statistic_type=statistic_type,
                statistic_source=statistic_source,
            )

        if statistic_type is not None or statistic_source is not None:
            # This was originally implemented but we never used it
            # so the code was ripped out to reduce the maintenance
            # burden.
            raise ValueError(
                "Providing statistic_type and statistic_source is mutually exclusive of statistic_ids"
            )

        results: dict[str, tuple[int, StatisticMetaData]] = {}
        missing_statistic_id: list[str] = []
        for statistic_id in statistic_ids:
            if id_meta := self._stat_id_to_id_meta.get(statistic_id):
                results[statistic_id] = id_meta
            else:
                missing_statistic_id.append(statistic_id)

        if not missing_statistic_id:
            return results

        # Fetch metadata from the database
        return results | self._get_from_database(
            session, statistic_ids=missing_statistic_id
        )

    def update_or_add(
        self,
        session: Session,
        new_metadata: StatisticMetaData,
        old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
    ) -> tuple[bool, int]:
        """Get metadata_id for a statistic_id.

        If the statistic_id is previously unknown, add it. If it's already known, update
        metadata if needed.

        Updating metadata source is not possible.

        Returns a tuple of (updated, metadata_id).

        updated is True if the metadata was updated, False if it was not updated.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        statistic_id = new_metadata["statistic_id"]
        if statistic_id not in old_metadata_dict:
            return True, self._add_metadata(session, statistic_id, new_metadata)
        return self._update_metadata(
            session, statistic_id, new_metadata, old_metadata_dict
        )

    def update_unit_of_measurement(
        self, session: Session, statistic_id: str, new_unit: str | None
    ) -> None:
        """Update the unit of measurement for a statistic_id.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        self._assert_in_recorder_thread()
        session.query(StatisticsMeta).filter(
            StatisticsMeta.statistic_id == statistic_id
        ).update({StatisticsMeta.unit_of_measurement: new_unit})
        self._clear_cache([statistic_id])

    def update_statistic_id(
        self,
        session: Session,
        source: str,
        old_statistic_id: str,
        new_statistic_id: str,
    ) -> None:
        """Update the statistic_id for a statistic_id.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        self._assert_in_recorder_thread()
        session.query(StatisticsMeta).filter(
            (StatisticsMeta.statistic_id == old_statistic_id)
            & (StatisticsMeta.source == source)
        ).update({StatisticsMeta.statistic_id: new_statistic_id})
        self._clear_cache([old_statistic_id, new_statistic_id])

    def delete(self, session: Session, statistic_ids: list[str]) -> None:
        """Clear statistics for a list of statistic_ids.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        self._assert_in_recorder_thread()
        session.query(StatisticsMeta).filter(
            StatisticsMeta.statistic_id.in_(statistic_ids)
        ).delete(synchronize_session=False)
        self._clear_cache(statistic_ids)

    def reset(self) -> None:
        """Reset the cache."""
        self._stat_id_to_id_meta = {}
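
How the manager is expected to be exercised from the recorder thread, sketched with a hypothetical statistic_id: the first lookup fills the LRU, later lookups for the same id are served from memory, and mixing statistic_ids with the type/source filters is rejected (covered by the new test further down):

    with session_scope(session=instance.get_session(), read_only=True) as session:
        manager = instance.statistics_meta_manager
        manager.load(session)  # prime the cache once at startup
        first = manager.get(session, "sensor.outdoor_temperature")  # may query the database
        again = manager.get(session, "sensor.outdoor_temperature")  # served from the LRU, no SQL
        # Combining statistic_ids with statistic_type/statistic_source raises ValueError.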

View File

@@ -145,31 +145,36 @@ def _parse_float(state: str) -> float:
    return fstate


+def _float_or_none(state: str) -> float | None:
+    """Return a float or None."""
+    try:
+        return _parse_float(state)
+    except (ValueError, TypeError):
+        return None
+
+
+def _entity_history_to_float_and_state(
+    entity_history: Iterable[State],
+) -> list[tuple[float, State]]:
+    """Return a list of (float, state) tuples for the given entity."""
+    return [
+        (fstate, state)
+        for state in entity_history
+        if (fstate := _float_or_none(state.state)) is not None
+    ]
+
+
def _normalize_states(
    hass: HomeAssistant,
-    session: Session,
    old_metadatas: dict[str, tuple[int, StatisticMetaData]],
-    entity_history: Iterable[State],
+    fstates: list[tuple[float, State]],
    entity_id: str,
) -> tuple[str | None, list[tuple[float, State]]]:
    """Normalize units."""
-    old_metadata = old_metadatas[entity_id][1] if entity_id in old_metadatas else None
    state_unit: str | None = None
-    fstates: list[tuple[float, State]] = []
-    for state in entity_history:
-        try:
-            fstate = _parse_float(state.state)
-        except (ValueError, TypeError):  # TypeError to guard for NULL state in DB
-            continue
-        fstates.append((fstate, state))
-
-    if not fstates:
-        return None, fstates
-
-    state_unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT)
-
    statistics_unit: str | None
+    state_unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT)
+    old_metadata = old_metadatas[entity_id][1] if entity_id in old_metadatas else None
    if not old_metadata:
        # We've not seen this sensor before, the first valid state determines the unit
        # used for statistics
@@ -379,7 +384,15 @@ def compile_statistics(
    Note: This will query the database and must not be run in the event loop
    """
-    with recorder_util.session_scope(hass=hass) as session:
+    # There is already an active session when this code is called since
+    # it is called from the recorder statistics. We need to make sure
+    # this session never gets committed since it would be out of sync
+    # with the recorder statistics session so we mark it as read only.
+    #
+    # If we ever need to write to the database from this function we
+    # will need to refactor the recorder statistics to use a single
+    # session.
+    with recorder_util.session_scope(hass=hass, read_only=True) as session:
        compiled = _compile_statistics(hass, session, start, end)
    return compiled
@@ -395,10 +408,6 @@ def _compile_statistics( # noqa: C901
    sensor_states = _get_sensor_states(hass)
    wanted_statistics = _wanted_statistics(sensor_states)
-    old_metadatas = statistics.get_metadata_with_session(
-        session, statistic_ids=[i.entity_id for i in sensor_states]
-    )

    # Get history between start and end
    entities_full_history = [
        i.entity_id for i in sensor_states if "sum" in wanted_statistics[i.entity_id]
@@ -427,34 +436,41 @@ def _compile_statistics( # noqa: C901
        entity_ids=entities_significant_history,
    )
    history_list = {**history_list, **_history_list}

-    # If there are no recent state changes, the sensor's state may already be pruned
-    # from the recorder. Get the state from the state machine instead.
-    for _state in sensor_states:
-        if _state.entity_id not in history_list:
-            history_list[_state.entity_id] = [_state]
-
-    to_process = []
-    to_query = []
-    for _state in sensor_states:
-        entity_id = _state.entity_id
-        if entity_id not in history_list:
-            continue
-
-        entity_history = history_list[entity_id]
-        statistics_unit, fstates = _normalize_states(
-            hass,
-            session,
-            old_metadatas,
-            entity_history,
-            entity_id,
-        )
-
-        if not fstates:
-            continue
-
-        state_class = _state.attributes[ATTR_STATE_CLASS]
-
-        to_process.append((entity_id, statistics_unit, state_class, fstates))
+    entities_with_float_states: dict[str, list[tuple[float, State]]] = {}
+    for _state in sensor_states:
+        entity_id = _state.entity_id
+        # If there are no recent state changes, the sensor's state may already be pruned
+        # from the recorder. Get the state from the state machine instead.
+        if not (entity_history := history_list.get(entity_id, [_state])):
+            continue
+        if not (float_states := _entity_history_to_float_and_state(entity_history)):
+            continue
+        entities_with_float_states[entity_id] = float_states
+
+    # Only lookup metadata for entities that have valid float states
+    # since it will result in cache misses for statistic_ids
+    # that are not in the metadata table and we are not working
+    # with them anyway.
+    old_metadatas = statistics.get_metadata_with_session(
+        get_instance(hass), session, statistic_ids=list(entities_with_float_states)
+    )
+
+    to_process: list[tuple[str, str | None, str, list[tuple[float, State]]]] = []
+    to_query: list[str] = []
+    for _state in sensor_states:
+        entity_id = _state.entity_id
+        if not (maybe_float_states := entities_with_float_states.get(entity_id)):
+            continue
+        statistics_unit, valid_float_states = _normalize_states(
+            hass,
+            old_metadatas,
+            maybe_float_states,
+            entity_id,
+        )
+        if not valid_float_states:
+            continue
+        state_class: str = _state.attributes[ATTR_STATE_CLASS]
+        to_process.append((entity_id, statistics_unit, state_class, valid_float_states))
        if "sum" in wanted_statistics[entity_id]:
            to_query.append(entity_id)
@@ -465,7 +481,7 @@ def _compile_statistics( # noqa: C901
        entity_id,
        statistics_unit,
        state_class,
-        fstates,
+        valid_float_states,
    ) in to_process:
        # Check metadata
        if old_metadata := old_metadatas.get(entity_id):
@@ -507,20 +523,20 @@ def _compile_statistics( # noqa: C901
        if "max" in wanted_statistics[entity_id]:
            stat["max"] = max(
                *itertools.islice(
-                    zip(*fstates),  # type: ignore[typeddict-item]
+                    zip(*valid_float_states),  # type: ignore[typeddict-item]
                    1,
                )
            )
        if "min" in wanted_statistics[entity_id]:
            stat["min"] = min(
                *itertools.islice(
-                    zip(*fstates),  # type: ignore[typeddict-item]
+                    zip(*valid_float_states),  # type: ignore[typeddict-item]
                    1,
                )
            )
        if "mean" in wanted_statistics[entity_id]:
-            stat["mean"] = _time_weighted_average(fstates, start, end)
+            stat["mean"] = _time_weighted_average(valid_float_states, start, end)

        if "sum" in wanted_statistics[entity_id]:
            last_reset = old_last_reset = None
@@ -535,7 +551,7 @@ def _compile_statistics( # noqa: C901
                new_state = old_state = last_stat["state"]
                _sum = last_stat["sum"] or 0.0

-            for fstate, state in fstates:
+            for fstate, state in valid_float_states:
                reset = False
                if (
                    state_class != SensorStateClass.TOTAL_INCREASING
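
The compiler above now turns state history into numeric (float, State) pairs before asking for any metadata, so entities that produced no usable numbers never reach the metadata cache. A standalone sketch of that filtering step (plain strings stand in for State objects; the real helper goes through _parse_float, which additionally rejects inf/nan):

    def _float_or_none(state: str) -> float | None:
        try:
            return float(state)
        except (ValueError, TypeError):
            return None

    history = ["21.5", "unknown", "unavailable", "22.0"]
    floats = [(f, s) for s in history if (f := _float_or_none(s)) is not None]
    assert floats == [(21.5, "21.5"), (22.0, "22.0")]  # non-numeric states dropped up front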

View File

@@ -0,0 +1 @@
"""Tests for the recorder table managers."""

View File

@@ -0,0 +1,53 @@
"""The tests for the Recorder component."""
from __future__ import annotations

import pytest

from homeassistant.components import recorder
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant

from tests.typing import RecorderInstanceGenerator


async def test_passing_mutually_exclusive_options_to_get_many(
    async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None:
    """Test passing mutually exclusive options to get_many."""
    instance = await async_setup_recorder_instance(
        hass, {recorder.CONF_COMMIT_INTERVAL: 0}
    )
    with session_scope(session=instance.get_session()) as session:
        with pytest.raises(ValueError):
            instance.statistics_meta_manager.get_many(
                session,
                statistic_ids=["light.kitchen"],
                statistic_type="mean",
            )
        with pytest.raises(ValueError):
            instance.statistics_meta_manager.get_many(
                session, statistic_ids=["light.kitchen"], statistic_source="sensor"
            )
        assert (
            instance.statistics_meta_manager.get_many(
                session,
                statistic_ids=["light.kitchen"],
            )
            == {}
        )


async def test_unsafe_calls_to_statistics_meta_manager(
    async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None:
    """Test we raise when trying to call non-threadsafe functions on statistics_meta_manager."""
    instance = await async_setup_recorder_instance(
        hass, {recorder.CONF_COMMIT_INTERVAL: 0}
    )
    with session_scope(session=instance.get_session()) as session, pytest.raises(
        RuntimeError, match="Detected unsafe call not in recorder thread"
    ):
        instance.statistics_meta_manager.delete(
            session,
            statistic_ids=["light.kitchen"],
        )

View File

@@ -564,32 +564,6 @@ def _add_entities(hass, entity_ids):
    return states


-def _add_events(hass, events):
-    with session_scope(hass=hass) as session:
-        session.query(Events).delete(synchronize_session=False)
-    for event_type in events:
-        hass.bus.fire(event_type)
-    wait_recording_done(hass)
-
-    with session_scope(hass=hass) as session:
-        events = []
-        for event, event_data, event_types in (
-            session.query(Events, EventData, EventTypes)
-            .outerjoin(EventTypes, (Events.event_type_id == EventTypes.event_type_id))
-            .outerjoin(EventData, Events.data_id == EventData.data_id)
-        ):
-            event = cast(Events, event)
-            event_data = cast(EventData, event_data)
-            event_types = cast(EventTypes, event_types)
-            native_event = event.to_native()
-            if event_data:
-                native_event.data = event_data.to_native()
-            native_event.event_type = event_types.event_type
-            events.append(native_event)
-        return events


def _state_with_context(hass, entity_id):
    # We don't restore context unless we need it by joining the
    # events table on the event_id for state_changed events
@@ -646,12 +620,12 @@ def test_saving_state_incl_entities(
    assert _state_with_context(hass, "test2.recorder").as_dict() == states[0].as_dict()


-def test_saving_event_exclude_event_type(
-    hass_recorder: Callable[..., HomeAssistant]
+async def test_saving_event_exclude_event_type(
+    async_setup_recorder_instance: RecorderInstanceGenerator,
+    hass: HomeAssistant,
) -> None:
    """Test saving and restoring an event."""
-    hass = hass_recorder(
-        {
-            "exclude": {
-                "event_types": [
-                    "service_registered",
+    config = {
+        "exclude": {
+            "event_types": [
+                "service_registered",
@@ -663,8 +637,36 @@ def test_saving_event_exclude_event_type(
-                ]
-            }
-        }
-    )
-    events = _add_events(hass, ["test", "test2"])
+            ]
+        }
+    }
+    instance = await async_setup_recorder_instance(hass, config)
+    events = ["test", "test2"]
+    for event_type in events:
+        hass.bus.async_fire(event_type)
+    await async_wait_recording_done(hass)
+
+    def _get_events(hass: HomeAssistant, event_types: list[str]) -> list[Event]:
+        with session_scope(hass=hass) as session:
+            events = []
+            for event, event_data, event_types in (
+                session.query(Events, EventData, EventTypes)
+                .outerjoin(
+                    EventTypes, (Events.event_type_id == EventTypes.event_type_id)
+                )
+                .outerjoin(EventData, Events.data_id == EventData.data_id)
+                .where(EventTypes.event_type.in_(event_types))
+            ):
+                event = cast(Events, event)
+                event_data = cast(EventData, event_data)
+                event_types = cast(EventTypes, event_types)
+                native_event = event.to_native()
+                if event_data:
+                    native_event.data = event_data.to_native()
+                native_event.event_type = event_types.event_type
+                events.append(native_event)
+            return events
+
+    events = await instance.async_add_executor_job(_get_events, hass, ["test", "test2"])

    assert len(events) == 1
    assert events[0].event_type == "test2"

View File

@@ -22,12 +22,10 @@ from homeassistant.components.recorder.models (
)
from homeassistant.components.recorder.statistics import (
    STATISTIC_UNIT_TO_UNIT_CONVERTER,
-    _generate_get_metadata_stmt,
    _generate_max_mean_min_statistic_in_sub_period_stmt,
    _generate_statistics_at_time_stmt,
    _generate_statistics_during_period_stmt,
    _statistics_during_period_with_session,
-    _update_or_add_metadata,
    async_add_external_statistics,
    async_import_statistics,
    delete_statistics_duplicates,
@@ -38,6 +36,10 @@ from homeassistant.components.recorder.statistics import (
    get_metadata,
    list_statistic_ids,
)
+from homeassistant.components.recorder.table_managers.statistics_meta import (
+    StatisticsMetaManager,
+    _generate_get_metadata_stmt,
+)
from homeassistant.components.recorder.util import session_scope
from homeassistant.components.sensor import UNIT_CONVERTERS
from homeassistant.const import UnitOfTemperature
@@ -1520,7 +1522,8 @@ def test_delete_metadata_duplicates_no_duplicates(
    hass = hass_recorder()
    wait_recording_done(hass)
    with session_scope(hass=hass) as session:
-        delete_statistics_meta_duplicates(session)
+        instance = recorder.get_instance(hass)
+        delete_statistics_meta_duplicates(instance, session)
    assert "duplicated statistics_meta rows" not in caplog.text
@@ -1562,9 +1565,9 @@ async def test_validate_db_schema_fix_utf8_issue(
    with patch(
        "homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"
    ), patch(
-        "homeassistant.components.recorder.statistics._update_or_add_metadata",
+        "homeassistant.components.recorder.table_managers.statistics_meta.StatisticsMetaManager.update_or_add",
+        wraps=StatisticsMetaManager.update_or_add,
        side_effect=[utf8_error, DEFAULT, DEFAULT],
-        wraps=_update_or_add_metadata,
    ):
        await async_setup_recorder_instance(hass)
        await async_wait_recording_done(hass)