Avoid database executor job to fetch statistic metadata on cache hit (#89960)

* Avoid database executor job to fetch statistic metadata on cache hit

Since we will almost always have a cache hit when fetching
statistic metadata, we can avoid an executor job.

* remove exception catch since threading.excepthook will catch this in production

* fix a few missed ones

* threadsafe

* Update homeassistant/components/recorder/table_managers/statistics_meta.py

* coverage and optimistic caching
commit 5ffb233004 (parent d7de23fa65)
J. Nick Koston, 2023-03-19 16:01:16 -10:00, committed by GitHub
15 changed files with 232 additions and 118 deletions
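
The idea is easier to see outside the diff. Below is a minimal, self-contained sketch of the fast path this commit adds; all names are illustrative stand-ins, not the recorder's real API (the actual implementation lives in homeassistant/components/recorder/table_managers/statistics_meta.py):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    # Hypothetical stand-ins for the recorder's cache and database.
    _cache: dict[str, dict] = {}  # statistic_id -> metadata, kept current by the recorder thread
    _executor = ThreadPoolExecutor(max_workers=1)

    def _load_from_database(statistic_ids: set[str]) -> dict[str, dict]:
        """Stand-in for the blocking metadata query."""
        return {stat_id: {"statistic_id": stat_id} for stat_id in statistic_ids}

    async def get_metadata(statistic_ids: set[str]) -> dict[str, dict]:
        # Fast path: plain per-key dict reads are thread-safe in CPython, so we
        # can answer from the cache without leaving the event loop.
        cached = {sid: meta for sid in statistic_ids if (meta := _cache.get(sid))}
        if not statistic_ids.difference(cached):
            return cached
        # Slow path: only on a cache miss do we pay for an executor job.
        loop = asyncio.get_running_loop()
        fetched = await loop.run_in_executor(
            _executor, _load_from_database, statistic_ids.difference(cached)
        )
        _cache.update(fetched)
        return cached | fetched

Since the cache is loaded at startup and kept current by the recorder thread, the miss branch is rarely taken.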


@@ -603,7 +603,7 @@ async def async_validate(hass: HomeAssistant) -> EnergyPreferencesValidation:
             functools.partial(
                 recorder.statistics.get_metadata,
                 hass,
-                statistic_ids=list(wanted_statistics_metadata),
+                statistic_ids=set(wanted_statistics_metadata),
             )
         )
     )


@@ -262,8 +262,8 @@ async def ws_get_fossil_energy_consumption(
         connection.send_error(msg["id"], "invalid_end_time", "Invalid end_time")
         return
 
-    statistic_ids = list(msg["energy_statistic_ids"])
-    statistic_ids.append(msg["co2_statistic_id"])
+    statistic_ids = set(msg["energy_statistic_ids"])
+    statistic_ids.add(msg["co2_statistic_id"])
 
     # Fetch energy + CO2 statistics
     statistics = await recorder.get_instance(hass).async_add_executor_job(


@@ -501,6 +501,7 @@ class Recorder(threading.Thread):
         new_size = self.hass.states.async_entity_ids_count() * 2
         self.state_attributes_manager.adjust_lru_size(new_size)
         self.states_meta_manager.adjust_lru_size(new_size)
+        self.statistics_meta_manager.adjust_lru_size(new_size)
 
     @callback
     def async_periodic_statistics(self) -> None:


@@ -713,10 +713,10 @@ def compile_missing_statistics(instance: Recorder) -> bool:
             periods_without_commit += 1
             end = start + timedelta(minutes=period_size)
             _LOGGER.debug("Compiling missing statistics for %s-%s", start, end)
-            metadata_modified = _compile_statistics(
+            modified_statistic_ids = _compile_statistics(
                 instance, session, start, end >= last_period
             )
-            if periods_without_commit == commit_interval or metadata_modified:
+            if periods_without_commit == commit_interval or modified_statistic_ids:
                 session.commit()
                 session.expunge_all()
                 periods_without_commit = 0
@@ -736,29 +736,40 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
         session=instance.get_session(),
         exception_filter=_filter_unique_constraint_integrity_error(instance),
     ) as session:
-        _compile_statistics(instance, session, start, fire_events)
+        modified_statistic_ids = _compile_statistics(
+            instance, session, start, fire_events
+        )
+
+    if modified_statistic_ids:
+        # In the rare case that we have modified statistic_ids, we reload the modified
+        # statistics meta data into the cache in a fresh session to ensure that the
+        # cache is up to date and future calls to get statistics meta data will
+        # not have to hit the database again.
+        with session_scope(session=instance.get_session(), read_only=True) as session:
+            instance.statistics_meta_manager.get_many(session, modified_statistic_ids)
+
     return True
 
 
 def _compile_statistics(
     instance: Recorder, session: Session, start: datetime, fire_events: bool
-) -> bool:
+) -> set[str]:
     """Compile 5-minute statistics for all integrations with a recorder platform.
 
     This is a helper function for compile_statistics and compile_missing_statistics
     that does not retry on database errors since both callers already retry.
 
-    returns True if metadata was modified, False otherwise
+    returns a set of modified statistic_ids if any were modified.
     """
     assert start.tzinfo == dt_util.UTC, "start must be in UTC"
     end = start + timedelta(minutes=5)
     statistics_meta_manager = instance.statistics_meta_manager
-    metadata_modified = False
+    modified_statistic_ids: set[str] = set()
 
     # Return if we already have 5-minute statistics for the requested period
     if session.query(StatisticsRuns).filter_by(start=start).first():
         _LOGGER.debug("Statistics already compiled for %s-%s", start, end)
-        return metadata_modified
+        return modified_statistic_ids
 
     _LOGGER.debug("Compiling statistics for %s-%s", start, end)
     platform_stats: list[StatisticResult] = []
@@ -782,10 +793,11 @@ def _compile_statistics(
 
     # Insert collected statistics in the database
     for stats in platform_stats:
-        updated, metadata_id = statistics_meta_manager.update_or_add(
+        modified_statistic_id, metadata_id = statistics_meta_manager.update_or_add(
             session, stats["meta"], current_metadata
         )
-        metadata_modified |= updated
+        if modified_statistic_id is not None:
+            modified_statistic_ids.add(modified_statistic_id)
         _insert_statistics(
             session,
             StatisticsShortTerm,
@@ -804,7 +816,7 @@ def _compile_statistics(
     if start.minute == 55:
         instance.hass.bus.fire(EVENT_RECORDER_HOURLY_STATISTICS_GENERATED)
 
-    return metadata_modified
+    return modified_statistic_ids
 
 
 def _adjust_sum_statistics(
@@ -882,7 +894,7 @@ def get_metadata_with_session(
     instance: Recorder,
     session: Session,
     *,
-    statistic_ids: list[str] | None = None,
+    statistic_ids: set[str] | None = None,
     statistic_type: Literal["mean"] | Literal["sum"] | None = None,
     statistic_source: str | None = None,
 ) -> dict[str, tuple[int, StatisticMetaData]]:
@@ -903,7 +915,7 @@ def get_metadata_with_session(
 def get_metadata(
     hass: HomeAssistant,
     *,
-    statistic_ids: list[str] | None = None,
+    statistic_ids: set[str] | None = None,
     statistic_type: Literal["mean"] | Literal["sum"] | None = None,
     statistic_source: str | None = None,
 ) -> dict[str, tuple[int, StatisticMetaData]]:
@@ -947,9 +959,9 @@ def update_statistics_metadata(
     )
 
 
-def list_statistic_ids(
+async def async_list_statistic_ids(
     hass: HomeAssistant,
-    statistic_ids: list[str] | None = None,
+    statistic_ids: set[str] | None = None,
     statistic_type: Literal["mean"] | Literal["sum"] | None = None,
 ) -> list[dict]:
     """Return all statistic_ids (or filtered one) and unit of measurement.
@@ -958,16 +970,31 @@ def list_statistic_ids(
     a recorder platform for statistic_ids which will be added in the next statistics
     period.
     """
-    result = {}
-    statistic_ids_set = set(statistic_ids) if statistic_ids else None
+    instance = get_instance(hass)
 
-    # Query the database
-    with session_scope(hass=hass, read_only=True) as session:
-        metadata = get_instance(hass).statistics_meta_manager.get_many(
-            session, statistic_type=statistic_type, statistic_ids=statistic_ids
-        )
+    if statistic_ids is not None:
+        # Try to get the results from the cache since there is nearly
+        # always a cache hit.
+        statistics_meta_manager = instance.statistics_meta_manager
+        metadata = statistics_meta_manager.get_from_cache_threadsafe(statistic_ids)
+        if not statistic_ids.difference(metadata):
+            result = _statistic_by_id_from_metadata(hass, metadata)
+            return _flatten_list_statistic_ids_metadata_result(result)
 
-        result = {
+    return await instance.async_add_executor_job(
+        list_statistic_ids,
+        hass,
+        statistic_ids,
+        statistic_type,
+    )
+
+
+def _statistic_by_id_from_metadata(
+    hass: HomeAssistant,
+    metadata: dict[str, tuple[int, StatisticMetaData]],
+) -> dict[str, dict[str, Any]]:
+    """Return a list of results for a given metadata dict."""
+    return {
         meta["statistic_id"]: {
             "display_unit_of_measurement": get_display_unit(
                 hass, meta["statistic_id"], meta["unit_of_measurement"]
@@ -982,7 +1009,49 @@ def list_statistic_ids(
             for _, meta in metadata.values()
         }
 
-        if not statistic_ids_set or statistic_ids_set.difference(result):
+
+
+def _flatten_list_statistic_ids_metadata_result(
+    result: dict[str, dict[str, Any]]
+) -> list[dict]:
+    """Return a flat dict of metadata."""
+    return [
+        {
+            "statistic_id": _id,
+            "display_unit_of_measurement": info["display_unit_of_measurement"],
+            "has_mean": info["has_mean"],
+            "has_sum": info["has_sum"],
+            "name": info.get("name"),
+            "source": info["source"],
+            "statistics_unit_of_measurement": info["unit_of_measurement"],
+            "unit_class": info["unit_class"],
+        }
+        for _id, info in result.items()
+    ]
+
+
+def list_statistic_ids(
+    hass: HomeAssistant,
+    statistic_ids: set[str] | None = None,
+    statistic_type: Literal["mean"] | Literal["sum"] | None = None,
+) -> list[dict]:
+    """Return all statistic_ids (or filtered one) and unit of measurement.
+
+    Queries the database for existing statistic_ids, as well as integrations with
+    a recorder platform for statistic_ids which will be added in the next statistics
+    period.
+    """
+    result = {}
+    instance = get_instance(hass)
+    statistics_meta_manager = instance.statistics_meta_manager
+
+    # Query the database
+    with session_scope(hass=hass, read_only=True) as session:
+        metadata = statistics_meta_manager.get_many(
+            session, statistic_type=statistic_type, statistic_ids=statistic_ids
+        )
+
+        result = _statistic_by_id_from_metadata(hass, metadata)
+
+        if not statistic_ids or statistic_ids.difference(result):
             # If we want all statistic_ids, or some are missing, we need to query
             # the integrations for the missing ones.
             #
@@ -1009,19 +1078,7 @@ def list_statistic_ids(
         }
 
     # Return a list of statistic_id + metadata
-    return [
-        {
-            "statistic_id": _id,
-            "display_unit_of_measurement": info["display_unit_of_measurement"],
-            "has_mean": info["has_mean"],
-            "has_sum": info["has_sum"],
-            "name": info.get("name"),
-            "source": info["source"],
-            "statistics_unit_of_measurement": info["unit_of_measurement"],
-            "unit_class": info["unit_class"],
-        }
-        for _id, info in result.items()
-    ]
+    return _flatten_list_statistic_ids_metadata_result(result)
 
 
 def _reduce_statistics(
@@ -1698,7 +1755,7 @@ def _statistics_during_period_with_session(
     session: Session,
     start_time: datetime,
     end_time: datetime | None,
-    statistic_ids: list[str] | None,
+    statistic_ids: set[str] | None,
     period: Literal["5minute", "day", "hour", "week", "month"],
     units: dict[str, str] | None,
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
@@ -1708,6 +1765,10 @@ def _statistics_during_period_with_session(
     If end_time is omitted, returns statistics newer than or equal to start_time.
     If statistic_ids is omitted, returns statistics for all statistics ids.
     """
+    if statistic_ids is not None and not isinstance(statistic_ids, set):
+        # This is for backwards compatibility to avoid a breaking change
+        # for custom integrations that call this method.
+        statistic_ids = set(statistic_ids)  # type: ignore[unreachable]
     metadata = None
     # Fetch metadata for the given (or all) statistic_ids
     metadata = get_instance(hass).statistics_meta_manager.get_many(
@@ -1784,7 +1845,7 @@ def statistics_during_period(
     hass: HomeAssistant,
     start_time: datetime,
     end_time: datetime | None,
-    statistic_ids: list[str] | None,
+    statistic_ids: set[str] | None,
     period: Literal["5minute", "day", "hour", "week", "month"],
     units: dict[str, str] | None,
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
@@ -1845,7 +1906,7 @@ def _get_last_statistics(
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
 ) -> dict[str, list[StatisticsRow]]:
     """Return the last number_of_stats statistics for a given statistic_id."""
-    statistic_ids = [statistic_id]
+    statistic_ids = {statistic_id}
     with session_scope(hass=hass, read_only=True) as session:
         # Fetch metadata for the given statistic_id
         metadata = get_instance(hass).statistics_meta_manager.get_many(
@@ -1930,7 +1991,7 @@ def _latest_short_term_statistics_stmt(
 
 def get_latest_short_term_statistics(
     hass: HomeAssistant,
-    statistic_ids: list[str],
+    statistic_ids: set[str],
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
     metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
 ) -> dict[str, list[StatisticsRow]]:
@@ -2031,7 +2092,7 @@ def _sorted_statistics_to_dict(
     hass: HomeAssistant,
     session: Session,
     stats: Sequence[Row[Any]],
-    statistic_ids: list[str] | None,
+    statistic_ids: set[str] | None,
     _metadata: dict[str, tuple[int, StatisticMetaData]],
     convert_units: bool,
     table: type[StatisticsBase],
@@ -2294,7 +2355,7 @@ def _import_statistics_with_session(
     """Import statistics to the database."""
     statistics_meta_manager = instance.statistics_meta_manager
     old_metadata_dict = statistics_meta_manager.get_many(
-        session, statistic_ids=[metadata["statistic_id"]]
+        session, statistic_ids={metadata["statistic_id"]}
     )
     _, metadata_id = statistics_meta_manager.update_or_add(
         session, metadata, old_metadata_dict
@@ -2338,7 +2399,7 @@ def adjust_statistics(
 
     with session_scope(session=instance.get_session()) as session:
         metadata = instance.statistics_meta_manager.get_many(
-            session, statistic_ids=[statistic_id]
+            session, statistic_ids={statistic_id}
         )
         if statistic_id not in metadata:
             return True
@@ -2476,7 +2537,7 @@ def _validate_db_schema_utf8(
     try:
         with session_scope(session=session_maker()) as session:
             old_metadata_dict = statistics_meta_manager.get_many(
-                session, statistic_ids=[statistic_id]
+                session, statistic_ids={statistic_id}
             )
             try:
                 statistics_meta_manager.update_or_add(
@@ -2573,7 +2634,7 @@ def _validate_db_schema(
                 session,
                 start_time,
                 None,
-                [statistic_id],
+                {statistic_id},
                 "hour" if table == Statistics else "5minute",
                 None,
                 {"last_reset", "max", "mean", "min", "state", "sum"},


@@ -34,7 +34,7 @@ QUERY_STATISTIC_META = (
 
 def _generate_get_metadata_stmt(
-    statistic_ids: list[str] | None = None,
+    statistic_ids: set[str] | None = None,
     statistic_type: Literal["mean"] | Literal["sum"] | None = None,
     statistic_source: str | None = None,
 ) -> StatementLambdaElement:
@@ -89,7 +89,7 @@ class StatisticsMetaManager:
     def _get_from_database(
         self,
         session: Session,
-        statistic_ids: list[str] | None = None,
+        statistic_ids: set[str] | None = None,
         statistic_type: Literal["mean"] | Literal["sum"] | None = None,
         statistic_source: str | None = None,
     ) -> dict[str, tuple[int, StatisticMetaData]]:
@@ -112,6 +112,7 @@ class StatisticsMetaManager:
         ):
             statistics_meta = cast(StatisticsMeta, row)
             id_meta = _statistics_meta_to_id_statistics_metadata(statistics_meta)
             statistic_id = cast(str, statistics_meta.statistic_id)
             results[statistic_id] = id_meta
             if update_cache:
@@ -149,7 +150,7 @@ class StatisticsMetaManager:
         statistic_id: str,
         new_metadata: StatisticMetaData,
         old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
-    ) -> tuple[bool, int]:
+    ) -> tuple[str | None, int]:
         """Update metadata in the database.
 
         This call is not thread-safe and must be called from the
@@ -163,7 +164,7 @@ class StatisticsMetaManager:
             or old_metadata["unit_of_measurement"]
             != new_metadata["unit_of_measurement"]
         ):
-            return False, metadata_id
+            return None, metadata_id
 
         self._assert_in_recorder_thread()
         session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update(
@@ -182,7 +183,7 @@ class StatisticsMetaManager:
             old_metadata,
             new_metadata,
         )
-        return True, metadata_id
+        return statistic_id, metadata_id
 
     def load(self, session: Session) -> None:
         """Load the statistic_id to metadata_id mapping into memory.
@@ -196,12 +197,12 @@ class StatisticsMetaManager:
         self, session: Session, statistic_id: str
     ) -> tuple[int, StatisticMetaData] | None:
         """Resolve statistic_id to the metadata_id."""
-        return self.get_many(session, [statistic_id]).get(statistic_id)
+        return self.get_many(session, {statistic_id}).get(statistic_id)
 
     def get_many(
         self,
         session: Session,
-        statistic_ids: list[str] | None = None,
+        statistic_ids: set[str] | None = None,
         statistic_type: Literal["mean"] | Literal["sum"] | None = None,
         statistic_source: str | None = None,
     ) -> dict[str, tuple[int, StatisticMetaData]]:
@@ -228,16 +229,8 @@ class StatisticsMetaManager:
                 "Providing statistic_type and statistic_source is mutually exclusive of statistic_ids"
             )
 
-        results: dict[str, tuple[int, StatisticMetaData]] = {}
-        missing_statistic_id: list[str] = []
-
-        for statistic_id in statistic_ids:
-            if id_meta := self._stat_id_to_id_meta.get(statistic_id):
-                results[statistic_id] = id_meta
-            else:
-                missing_statistic_id.append(statistic_id)
-
-        if not missing_statistic_id:
+        results = self.get_from_cache_threadsafe(statistic_ids)
+        if not (missing_statistic_id := statistic_ids.difference(results)):
             return results
 
         # Fetch metadata from the database
@@ -245,12 +238,29 @@ class StatisticsMetaManager:
             session, statistic_ids=missing_statistic_id
         )
 
+    def get_from_cache_threadsafe(
+        self, statistic_ids: set[str]
+    ) -> dict[str, tuple[int, StatisticMetaData]]:
+        """Get metadata from cache.
+
+        This call is thread safe and can be run in the event loop,
+        the database executor, or the recorder thread.
+        """
+        return {
+            statistic_id: id_meta
+            for statistic_id in statistic_ids
+            # We must use a get call here and never iterate over the dict
+            # because the dict can be modified by the recorder thread
+            # while we are iterating over it.
+            if (id_meta := self._stat_id_to_id_meta.get(statistic_id))
+        }
+
     def update_or_add(
         self,
         session: Session,
         new_metadata: StatisticMetaData,
         old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
-    ) -> tuple[bool, int]:
+    ) -> tuple[str | None, int]:
         """Get metadata_id for a statistic_id.
 
         If the statistic_id is previously unknown, add it. If it's already known, update
@@ -258,16 +268,16 @@ class StatisticsMetaManager:
 
         Updating metadata source is not possible.
 
-        Returns a tuple of (updated, metadata_id).
+        Returns a tuple of (statistic_id | None, metadata_id).
 
-        updated is True if the metadata was updated, False if it was not updated.
+        statistic_id is None if the metadata was not updated.
 
         This call is not thread-safe and must be called from the
         recorder thread.
         """
         statistic_id = new_metadata["statistic_id"]
         if statistic_id not in old_metadata_dict:
-            return True, self._add_metadata(session, statistic_id, new_metadata)
+            return statistic_id, self._add_metadata(session, statistic_id, new_metadata)
         return self._update_metadata(
             session, statistic_id, new_metadata, old_metadata_dict
         )
@@ -319,4 +329,14 @@ class StatisticsMetaManager:
 
     def reset(self) -> None:
         """Reset the cache."""
-        self._stat_id_to_id_meta = {}
+        self._stat_id_to_id_meta.clear()
+
+    def adjust_lru_size(self, new_size: int) -> None:
+        """Adjust the LRU cache size.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        lru: LRU = self._stat_id_to_id_meta
+        if new_size > lru.get_size():
+            lru.set_size(new_size)
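
The per-key .get() reads in get_from_cache_threadsafe above are what make the cache safe to consult from the event loop while the recorder thread writes to it. A small demonstration of the hazard being avoided, assuming CPython's atomic single-key dict operations (toy dict here, not the real LRU):

    import threading

    cache = {f"sensor.test{i}": i for i in range(1000)}

    def writer() -> None:
        for i in range(1000, 2000):
            cache[f"sensor.test{i}"] = i  # single-key writes are atomic in CPython

    threading.Thread(target=writer).start()

    # Safe: each .get() is an independent atomic lookup, as in the diff above.
    wanted = {"sensor.test1", "sensor.test1500"}
    found = {sid: val for sid in wanted if (val := cache.get(sid)) is not None}

    # Unsafe: iterating cache.items() concurrently can raise
    # "RuntimeError: dictionary changed size during iteration".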


@@ -37,6 +37,7 @@ from .statistics import (
     async_add_external_statistics,
     async_change_statistics_unit,
     async_import_statistics,
+    async_list_statistic_ids,
     list_statistic_ids,
     statistic_during_period,
     statistics_during_period,
@@ -151,7 +152,7 @@ def _ws_get_statistics_during_period(
     msg_id: int,
     start_time: dt,
     end_time: dt | None,
-    statistic_ids: list[str] | None,
+    statistic_ids: set[str] | None,
     period: Literal["5minute", "day", "hour", "week", "month"],
     units: dict[str, str],
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
@@ -208,7 +209,7 @@ async def ws_handle_get_statistics_during_period(
             msg["id"],
             start_time,
             end_time,
-            msg["statistic_ids"],
+            set(msg["statistic_ids"]),
             msg.get("period"),
             msg.get("units"),
             types,
@@ -329,11 +330,10 @@ async def ws_get_statistics_metadata(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
 ) -> None:
     """Get metadata for a list of statistic_ids."""
-    instance = get_instance(hass)
-    statistic_ids = await instance.async_add_executor_job(
-        list_statistic_ids, hass, msg.get("statistic_ids")
-    )
-    connection.send_result(msg["id"], statistic_ids)
+    statistic_ids = msg.get("statistic_ids")
+    statistic_ids_set_or_none = set(statistic_ids) if statistic_ids else None
+    metadata = await async_list_statistic_ids(hass, statistic_ids_set_or_none)
+    connection.send_result(msg["id"], metadata)
 
 
 @websocket_api.require_admin
@@ -413,7 +413,7 @@ async def ws_adjust_sum_statistics(
 
     instance = get_instance(hass)
     metadatas = await instance.async_add_executor_job(
-        list_statistic_ids, hass, (msg["statistic_id"],)
+        list_statistic_ids, hass, {msg["statistic_id"]}
     )
     if not metadatas:
         connection.send_error(msg["id"], "unknown_statistic_id", "Unknown statistic ID")


@@ -453,10 +453,10 @@ def _compile_statistics(  # noqa: C901
     # that are not in the metadata table and we are not working
     # with them anyway.
     old_metadatas = statistics.get_metadata_with_session(
-        get_instance(hass), session, statistic_ids=list(entities_with_float_states)
+        get_instance(hass), session, statistic_ids=set(entities_with_float_states)
     )
     to_process: list[tuple[str, str | None, str, list[tuple[float, State]]]] = []
-    to_query: list[str] = []
+    to_query: set[str] = set()
     for _state in sensor_states:
         entity_id = _state.entity_id
         if not (maybe_float_states := entities_with_float_states.get(entity_id)):
@@ -472,7 +472,7 @@ def _compile_statistics(  # noqa: C901
         state_class: str = _state.attributes[ATTR_STATE_CLASS]
         to_process.append((entity_id, statistics_unit, state_class, valid_float_states))
         if "sum" in wanted_statistics[entity_id]:
-            to_query.append(entity_id)
+            to_query.add(entity_id)
 
     last_stats = statistics.get_latest_short_term_statistics(
         hass, to_query, {"last_reset", "state", "sum"}, metadata=old_metadatas


@@ -636,7 +636,7 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]):
             self.hass,
             start,
             None,
-            [statistic_id],
+            {statistic_id},
             "hour",
             None,
             {"sum"},


@@ -144,13 +144,15 @@ def statistics_during_period(
     hass: HomeAssistant,
     start_time: datetime,
     end_time: datetime | None = None,
-    statistic_ids: list[str] | None = None,
+    statistic_ids: set[str] | None = None,
     period: Literal["5minute", "day", "hour", "week", "month"] = "hour",
     units: dict[str, str] | None = None,
     types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]]
     | None = None,
 ) -> dict[str, list[dict[str, Any]]]:
     """Call statistics_during_period with defaults for simpler tests."""
+    if statistic_ids is not None and not isinstance(statistic_ids, set):
+        statistic_ids = set(statistic_ids)
     if types is None:
         types = {"last_reset", "max", "mean", "min", "state", "sum"}
     return statistics.statistics_during_period(


@@ -292,6 +292,15 @@ class States(Base):  # type: ignore[misc,valid-type]
     context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
     context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
     origin_idx = Column(SmallInteger)  # 0 is local, 1 is remote
+    context_id_bin = Column(
+        LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH)
+    )  # *** Not originally in v28, only added for recorder to startup ok
+    context_user_id_bin = Column(
+        LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH)
+    )  # *** Not originally in v28, only added for recorder to startup ok
+    context_parent_id_bin = Column(
+        LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH)
+    )  # *** Not originally in v28, only added for recorder to startup ok
     metadata_id = Column(
         Integer, ForeignKey("states_meta.metadata_id"), index=True
     )  # *** Not originally in v28, only added for recorder to startup ok


@@ -26,12 +26,12 @@ async def test_passing_mutually_exclusive_options_to_get_many(
         )
         with pytest.raises(ValueError):
             instance.statistics_meta_manager.get_many(
-                session, statistic_ids=["light.kitchen"], statistic_source="sensor"
+                session, statistic_ids={"light.kitchen"}, statistic_source="sensor"
             )
         assert (
             instance.statistics_meta_manager.get_many(
                 session,
-                statistic_ids=["light.kitchen"],
+                statistic_ids={"light.kitchen"},
             )
             == {}
         )


@@ -84,7 +84,7 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
 
     # Should not fail if there is nothing there yet
     stats = get_latest_short_term_statistics(
-        hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"}
+        hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
     )
     assert stats == {}
@@ -169,15 +169,15 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
     assert stats == {"sensor.test1": [expected_2]}
 
     stats = get_latest_short_term_statistics(
-        hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"}
+        hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
     )
     assert stats == {"sensor.test1": [expected_2]}
 
-    metadata = get_metadata(hass, statistic_ids=["sensor.test1"])
+    metadata = get_metadata(hass, statistic_ids={"sensor.test1"})
 
     stats = get_latest_short_term_statistics(
         hass,
-        ["sensor.test1"],
+        {"sensor.test1"},
         {"last_reset", "max", "mean", "min", "state", "sum"},
         metadata=metadata,
     )
@@ -213,7 +213,7 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
     instance.get_session().query(StatisticsShortTerm).delete()
     # Should not fail there is nothing in the table
     stats = get_latest_short_term_statistics(
-        hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"}
+        hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
     )
     assert stats == {}
@@ -243,7 +243,7 @@ def mock_sensor_statistics():
             sensor_stats("sensor.test3", start),
         ],
         get_metadata(
-            _hass, statistic_ids=["sensor.test1", "sensor.test2", "sensor.test3"]
+            _hass, statistic_ids={"sensor.test1", "sensor.test2", "sensor.test3"}
         ),
     )
@@ -385,6 +385,27 @@ def test_rename_entity(hass_recorder: Callable[..., HomeAssistant]) -> None:
     assert stats == {"sensor.test99": expected_stats99, "sensor.test2": expected_stats2}
 
 
+def test_statistics_during_period_set_back_compat(
+    hass_recorder: Callable[..., HomeAssistant]
+) -> None:
+    """Test statistics_during_period can handle a list instead of a set."""
+    hass = hass_recorder()
+    setup_component(hass, "sensor", {})
+    # This should not throw an exception when passed a list instead of a set
+    assert (
+        statistics.statistics_during_period(
+            hass,
+            dt_util.utcnow(),
+            None,
+            statistic_ids=["sensor.test1"],
+            period="5minute",
+            units=None,
+            types=set(),
+        )
+        == {}
+    )
+
+
 def test_rename_entity_collision(
     hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture
 ) -> None:
@@ -595,7 +616,7 @@ async def test_import_statistics(
             "unit_class": "energy",
         }
     ]
-    metadata = get_metadata(hass, statistic_ids=(statistic_id,))
+    metadata = get_metadata(hass, statistic_ids={statistic_id})
     assert metadata == {
         statistic_id: (
             1,
@@ -692,7 +713,7 @@ async def test_import_statistics(
             "unit_class": "energy",
         }
     ]
-    metadata = get_metadata(hass, statistic_ids=(statistic_id,))
+    metadata = get_metadata(hass, statistic_ids={statistic_id})
     assert metadata == {
         statistic_id: (
             1,
@@ -814,7 +835,7 @@ def test_external_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
 
     # Attempt to insert statistics for the wrong domain
     external_metadata = {**_external_metadata, "source": "other"}
@@ -824,7 +845,7 @@ def test_external_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
 
     # Attempt to insert statistics for a naive starting time
     external_metadata = {**_external_metadata}
@@ -837,7 +858,7 @@ def test_external_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
 
     # Attempt to insert statistics for an invalid starting time
     external_metadata = {**_external_metadata}
@@ -847,7 +868,7 @@ def test_external_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
 
     # Attempt to insert statistics with a naive last_reset
     external_metadata = {**_external_metadata}
@@ -860,7 +881,7 @@ def test_external_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
 
 
 def test_import_statistics_errors(
@@ -903,7 +924,7 @@ def test_import_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
 
     # Attempt to insert statistics for the wrong domain
     external_metadata = {**_external_metadata, "source": "sensor"}
@@ -913,7 +934,7 @@ def test_import_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
 
     # Attempt to insert statistics for a naive starting time
     external_metadata = {**_external_metadata}
@@ -926,7 +947,7 @@ def test_import_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
 
     # Attempt to insert statistics for an invalid starting time
     external_metadata = {**_external_metadata}
@@ -936,7 +957,7 @@ def test_import_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
 
     # Attempt to insert statistics with a naive last_reset
     external_metadata = {**_external_metadata}
@@ -949,7 +970,7 @@ def test_import_statistics_errors(
     wait_recording_done(hass)
     assert statistics_during_period(hass, zero, period="hour") == {}
     assert list_statistic_ids(hass) == []
-    assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
+    assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
 
 
 @pytest.mark.parametrize("timezone", ["America/Regina", "Europe/Vienna", "UTC"])


@@ -2588,7 +2588,7 @@ async def test_import_statistics(
             "unit_class": "energy",
         }
     ]
-    metadata = get_metadata(hass, statistic_ids=(statistic_id,))
+    metadata = get_metadata(hass, statistic_ids={statistic_id})
     assert metadata == {
         statistic_id: (
             1,
@@ -2820,7 +2820,7 @@ async def test_adjust_sum_statistics_energy(
             "unit_class": "energy",
         }
     ]
-    metadata = get_metadata(hass, statistic_ids=(statistic_id,))
+    metadata = get_metadata(hass, statistic_ids={statistic_id})
     assert metadata == {
         statistic_id: (
             1,
@@ -3016,7 +3016,7 @@ async def test_adjust_sum_statistics_gas(
             "unit_class": "volume",
         }
     ]
-    metadata = get_metadata(hass, statistic_ids=(statistic_id,))
+    metadata = get_metadata(hass, statistic_ids={statistic_id})
     assert metadata == {
         statistic_id: (
             1,
@@ -3230,7 +3230,7 @@ async def test_adjust_sum_statistics_errors(
             "unit_class": unit_class,
         }
     ]
-    metadata = get_metadata(hass, statistic_ids=(statistic_id,))
+    metadata = get_metadata(hass, statistic_ids={statistic_id})
     assert metadata == {
         statistic_id: (
             1,


@@ -3067,7 +3067,7 @@ def test_compile_hourly_statistics_changing_state_class(
             "unit_class": unit_class,
         },
     ]
-    metadata = get_metadata(hass, statistic_ids=("sensor.test1",))
+    metadata = get_metadata(hass, statistic_ids={"sensor.test1"})
     assert metadata == {
         "sensor.test1": (
             1,
@@ -3103,7 +3103,7 @@ def test_compile_hourly_statistics_changing_state_class(
             "unit_class": unit_class,
         },
     ]
-    metadata = get_metadata(hass, statistic_ids=("sensor.test1",))
+    metadata = get_metadata(hass, statistic_ids={"sensor.test1"})
    assert metadata == {
         "sensor.test1": (
             1,


@@ -35,7 +35,7 @@ async def test_async_setup_entry(recorder_mock: Recorder, hass: HomeAssistant) -
         hass,
         dt_util.parse_datetime(data[0]["from"]),
         None,
-        [statistic_id],
+        {statistic_id},
         "hour",
         None,
         {"start", "state", "mean", "min", "max", "last_reset", "sum"},