Avoid database executor job to fetch statistic metadata on cache hit (#89960)

* Avoid database executor job to fetch statistic metadata on cache hit

Since we will almost always have a cache hit fetching
statistic meta data we can avoid an executor job

* Avoid database executor job to fetch statistic metadata on cache hit

Since we will almost always have a cache hit fetching
statistic meta data we can avoid an executor job

* Avoid database executor job to fetch statistic metadata on cache hit

Since we will almost always have a cache hit fetching
statistic meta data we can avoid an executor job

* remove exception catch since the threading.excepthook will actually catch this in production

* fix a few missed ones

* threadsafe

* Update homeassistant/components/recorder/table_managers/statistics_meta.py

* coverage and optimistic caching
This commit is contained in:
J. Nick Koston 2023-03-19 16:01:16 -10:00 committed by GitHub
parent d7de23fa65
commit 5ffb233004
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 232 additions and 118 deletions

View File

@ -603,7 +603,7 @@ async def async_validate(hass: HomeAssistant) -> EnergyPreferencesValidation:
functools.partial(
recorder.statistics.get_metadata,
hass,
statistic_ids=list(wanted_statistics_metadata),
statistic_ids=set(wanted_statistics_metadata),
)
)
)

View File

@ -262,8 +262,8 @@ async def ws_get_fossil_energy_consumption(
connection.send_error(msg["id"], "invalid_end_time", "Invalid end_time")
return
statistic_ids = list(msg["energy_statistic_ids"])
statistic_ids.append(msg["co2_statistic_id"])
statistic_ids = set(msg["energy_statistic_ids"])
statistic_ids.add(msg["co2_statistic_id"])
# Fetch energy + CO2 statistics
statistics = await recorder.get_instance(hass).async_add_executor_job(

View File

@ -501,6 +501,7 @@ class Recorder(threading.Thread):
new_size = self.hass.states.async_entity_ids_count() * 2
self.state_attributes_manager.adjust_lru_size(new_size)
self.states_meta_manager.adjust_lru_size(new_size)
self.statistics_meta_manager.adjust_lru_size(new_size)
@callback
def async_periodic_statistics(self) -> None:

View File

@ -713,10 +713,10 @@ def compile_missing_statistics(instance: Recorder) -> bool:
periods_without_commit += 1
end = start + timedelta(minutes=period_size)
_LOGGER.debug("Compiling missing statistics for %s-%s", start, end)
metadata_modified = _compile_statistics(
modified_statistic_ids = _compile_statistics(
instance, session, start, end >= last_period
)
if periods_without_commit == commit_interval or metadata_modified:
if periods_without_commit == commit_interval or modified_statistic_ids:
session.commit()
session.expunge_all()
periods_without_commit = 0
@ -736,29 +736,40 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
_compile_statistics(instance, session, start, fire_events)
modified_statistic_ids = _compile_statistics(
instance, session, start, fire_events
)
if modified_statistic_ids:
# In the rare case that we have modified statistic_ids, we reload the modified
# statistics meta data into the cache in a fresh session to ensure that the
# cache is up to date and future calls to get statistics meta data will
# not have to hit the database again.
with session_scope(session=instance.get_session(), read_only=True) as session:
instance.statistics_meta_manager.get_many(session, modified_statistic_ids)
return True
def _compile_statistics(
instance: Recorder, session: Session, start: datetime, fire_events: bool
) -> bool:
) -> set[str]:
"""Compile 5-minute statistics for all integrations with a recorder platform.
This is a helper function for compile_statistics and compile_missing_statistics
that does not retry on database errors since both callers already retry.
returns True if metadata was modified, False otherwise
returns a set of modified statistic_ids if any were modified.
"""
assert start.tzinfo == dt_util.UTC, "start must be in UTC"
end = start + timedelta(minutes=5)
statistics_meta_manager = instance.statistics_meta_manager
metadata_modified = False
modified_statistic_ids: set[str] = set()
# Return if we already have 5-minute statistics for the requested period
if session.query(StatisticsRuns).filter_by(start=start).first():
_LOGGER.debug("Statistics already compiled for %s-%s", start, end)
return metadata_modified
return modified_statistic_ids
_LOGGER.debug("Compiling statistics for %s-%s", start, end)
platform_stats: list[StatisticResult] = []
@ -782,10 +793,11 @@ def _compile_statistics(
# Insert collected statistics in the database
for stats in platform_stats:
updated, metadata_id = statistics_meta_manager.update_or_add(
modified_statistic_id, metadata_id = statistics_meta_manager.update_or_add(
session, stats["meta"], current_metadata
)
metadata_modified |= updated
if modified_statistic_id is not None:
modified_statistic_ids.add(modified_statistic_id)
_insert_statistics(
session,
StatisticsShortTerm,
@ -804,7 +816,7 @@ def _compile_statistics(
if start.minute == 55:
instance.hass.bus.fire(EVENT_RECORDER_HOURLY_STATISTICS_GENERATED)
return metadata_modified
return modified_statistic_ids
def _adjust_sum_statistics(
@ -882,7 +894,7 @@ def get_metadata_with_session(
instance: Recorder,
session: Session,
*,
statistic_ids: list[str] | None = None,
statistic_ids: set[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
@ -903,7 +915,7 @@ def get_metadata_with_session(
def get_metadata(
hass: HomeAssistant,
*,
statistic_ids: list[str] | None = None,
statistic_ids: set[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
@ -947,9 +959,79 @@ def update_statistics_metadata(
)
async def async_list_statistic_ids(
hass: HomeAssistant,
statistic_ids: set[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
) -> list[dict]:
"""Return all statistic_ids (or filtered one) and unit of measurement.
Queries the database for existing statistic_ids, as well as integrations with
a recorder platform for statistic_ids which will be added in the next statistics
period.
"""
instance = get_instance(hass)
if statistic_ids is not None:
# Try to get the results from the cache since there is nearly
# always a cache hit.
statistics_meta_manager = instance.statistics_meta_manager
metadata = statistics_meta_manager.get_from_cache_threadsafe(statistic_ids)
if not statistic_ids.difference(metadata):
result = _statistic_by_id_from_metadata(hass, metadata)
return _flatten_list_statistic_ids_metadata_result(result)
return await instance.async_add_executor_job(
list_statistic_ids,
hass,
statistic_ids,
statistic_type,
)
def _statistic_by_id_from_metadata(
hass: HomeAssistant,
metadata: dict[str, tuple[int, StatisticMetaData]],
) -> dict[str, dict[str, Any]]:
"""Return a list of results for a given metadata dict."""
return {
meta["statistic_id"]: {
"display_unit_of_measurement": get_display_unit(
hass, meta["statistic_id"], meta["unit_of_measurement"]
),
"has_mean": meta["has_mean"],
"has_sum": meta["has_sum"],
"name": meta["name"],
"source": meta["source"],
"unit_class": _get_unit_class(meta["unit_of_measurement"]),
"unit_of_measurement": meta["unit_of_measurement"],
}
for _, meta in metadata.values()
}
def _flatten_list_statistic_ids_metadata_result(
result: dict[str, dict[str, Any]]
) -> list[dict]:
"""Return a flat dict of metadata."""
return [
{
"statistic_id": _id,
"display_unit_of_measurement": info["display_unit_of_measurement"],
"has_mean": info["has_mean"],
"has_sum": info["has_sum"],
"name": info.get("name"),
"source": info["source"],
"statistics_unit_of_measurement": info["unit_of_measurement"],
"unit_class": info["unit_class"],
}
for _id, info in result.items()
]
def list_statistic_ids(
hass: HomeAssistant,
statistic_ids: list[str] | None = None,
statistic_ids: set[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
) -> list[dict]:
"""Return all statistic_ids (or filtered one) and unit of measurement.
@ -959,30 +1041,17 @@ def list_statistic_ids(
period.
"""
result = {}
statistic_ids_set = set(statistic_ids) if statistic_ids else None
instance = get_instance(hass)
statistics_meta_manager = instance.statistics_meta_manager
# Query the database
with session_scope(hass=hass, read_only=True) as session:
metadata = get_instance(hass).statistics_meta_manager.get_many(
metadata = statistics_meta_manager.get_many(
session, statistic_type=statistic_type, statistic_ids=statistic_ids
)
result = _statistic_by_id_from_metadata(hass, metadata)
result = {
meta["statistic_id"]: {
"display_unit_of_measurement": get_display_unit(
hass, meta["statistic_id"], meta["unit_of_measurement"]
),
"has_mean": meta["has_mean"],
"has_sum": meta["has_sum"],
"name": meta["name"],
"source": meta["source"],
"unit_class": _get_unit_class(meta["unit_of_measurement"]),
"unit_of_measurement": meta["unit_of_measurement"],
}
for _, meta in metadata.values()
}
if not statistic_ids_set or statistic_ids_set.difference(result):
if not statistic_ids or statistic_ids.difference(result):
# If we want all statistic_ids, or some are missing, we need to query
# the integrations for the missing ones.
#
@ -1009,19 +1078,7 @@ def list_statistic_ids(
}
# Return a list of statistic_id + metadata
return [
{
"statistic_id": _id,
"display_unit_of_measurement": info["display_unit_of_measurement"],
"has_mean": info["has_mean"],
"has_sum": info["has_sum"],
"name": info.get("name"),
"source": info["source"],
"statistics_unit_of_measurement": info["unit_of_measurement"],
"unit_class": info["unit_class"],
}
for _id, info in result.items()
]
return _flatten_list_statistic_ids_metadata_result(result)
def _reduce_statistics(
@ -1698,7 +1755,7 @@ def _statistics_during_period_with_session(
session: Session,
start_time: datetime,
end_time: datetime | None,
statistic_ids: list[str] | None,
statistic_ids: set[str] | None,
period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
@ -1708,6 +1765,10 @@ def _statistics_during_period_with_session(
If end_time is omitted, returns statistics newer than or equal to start_time.
If statistic_ids is omitted, returns statistics for all statistics ids.
"""
if statistic_ids is not None and not isinstance(statistic_ids, set):
# This is for backwards compatibility to avoid a breaking change
# for custom integrations that call this method.
statistic_ids = set(statistic_ids) # type: ignore[unreachable]
metadata = None
# Fetch metadata for the given (or all) statistic_ids
metadata = get_instance(hass).statistics_meta_manager.get_many(
@ -1784,7 +1845,7 @@ def statistics_during_period(
hass: HomeAssistant,
start_time: datetime,
end_time: datetime | None,
statistic_ids: list[str] | None,
statistic_ids: set[str] | None,
period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
@ -1845,7 +1906,7 @@ def _get_last_statistics(
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[StatisticsRow]]:
"""Return the last number_of_stats statistics for a given statistic_id."""
statistic_ids = [statistic_id]
statistic_ids = {statistic_id}
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_id
metadata = get_instance(hass).statistics_meta_manager.get_many(
@ -1930,7 +1991,7 @@ def _latest_short_term_statistics_stmt(
def get_latest_short_term_statistics(
hass: HomeAssistant,
statistic_ids: list[str],
statistic_ids: set[str],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
) -> dict[str, list[StatisticsRow]]:
@ -2031,7 +2092,7 @@ def _sorted_statistics_to_dict(
hass: HomeAssistant,
session: Session,
stats: Sequence[Row[Any]],
statistic_ids: list[str] | None,
statistic_ids: set[str] | None,
_metadata: dict[str, tuple[int, StatisticMetaData]],
convert_units: bool,
table: type[StatisticsBase],
@ -2294,7 +2355,7 @@ def _import_statistics_with_session(
"""Import statistics to the database."""
statistics_meta_manager = instance.statistics_meta_manager
old_metadata_dict = statistics_meta_manager.get_many(
session, statistic_ids=[metadata["statistic_id"]]
session, statistic_ids={metadata["statistic_id"]}
)
_, metadata_id = statistics_meta_manager.update_or_add(
session, metadata, old_metadata_dict
@ -2338,7 +2399,7 @@ def adjust_statistics(
with session_scope(session=instance.get_session()) as session:
metadata = instance.statistics_meta_manager.get_many(
session, statistic_ids=[statistic_id]
session, statistic_ids={statistic_id}
)
if statistic_id not in metadata:
return True
@ -2476,7 +2537,7 @@ def _validate_db_schema_utf8(
try:
with session_scope(session=session_maker()) as session:
old_metadata_dict = statistics_meta_manager.get_many(
session, statistic_ids=[statistic_id]
session, statistic_ids={statistic_id}
)
try:
statistics_meta_manager.update_or_add(
@ -2573,7 +2634,7 @@ def _validate_db_schema(
session,
start_time,
None,
[statistic_id],
{statistic_id},
"hour" if table == Statistics else "5minute",
None,
{"last_reset", "max", "mean", "min", "state", "sum"},

View File

@ -34,7 +34,7 @@ QUERY_STATISTIC_META = (
def _generate_get_metadata_stmt(
statistic_ids: list[str] | None = None,
statistic_ids: set[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> StatementLambdaElement:
@ -89,7 +89,7 @@ class StatisticsMetaManager:
def _get_from_database(
self,
session: Session,
statistic_ids: list[str] | None = None,
statistic_ids: set[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
@ -112,6 +112,7 @@ class StatisticsMetaManager:
):
statistics_meta = cast(StatisticsMeta, row)
id_meta = _statistics_meta_to_id_statistics_metadata(statistics_meta)
statistic_id = cast(str, statistics_meta.statistic_id)
results[statistic_id] = id_meta
if update_cache:
@ -149,7 +150,7 @@ class StatisticsMetaManager:
statistic_id: str,
new_metadata: StatisticMetaData,
old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
) -> tuple[bool, int]:
) -> tuple[str | None, int]:
"""Update metadata in the database.
This call is not thread-safe and must be called from the
@ -163,7 +164,7 @@ class StatisticsMetaManager:
or old_metadata["unit_of_measurement"]
!= new_metadata["unit_of_measurement"]
):
return False, metadata_id
return None, metadata_id
self._assert_in_recorder_thread()
session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update(
@ -182,7 +183,7 @@ class StatisticsMetaManager:
old_metadata,
new_metadata,
)
return True, metadata_id
return statistic_id, metadata_id
def load(self, session: Session) -> None:
"""Load the statistic_id to metadata_id mapping into memory.
@ -196,12 +197,12 @@ class StatisticsMetaManager:
self, session: Session, statistic_id: str
) -> tuple[int, StatisticMetaData] | None:
"""Resolve statistic_id to the metadata_id."""
return self.get_many(session, [statistic_id]).get(statistic_id)
return self.get_many(session, {statistic_id}).get(statistic_id)
def get_many(
self,
session: Session,
statistic_ids: list[str] | None = None,
statistic_ids: set[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
@ -228,16 +229,8 @@ class StatisticsMetaManager:
"Providing statistic_type and statistic_source is mutually exclusive of statistic_ids"
)
results: dict[str, tuple[int, StatisticMetaData]] = {}
missing_statistic_id: list[str] = []
for statistic_id in statistic_ids:
if id_meta := self._stat_id_to_id_meta.get(statistic_id):
results[statistic_id] = id_meta
else:
missing_statistic_id.append(statistic_id)
if not missing_statistic_id:
results = self.get_from_cache_threadsafe(statistic_ids)
if not (missing_statistic_id := statistic_ids.difference(results)):
return results
# Fetch metadata from the database
@ -245,12 +238,29 @@ class StatisticsMetaManager:
session, statistic_ids=missing_statistic_id
)
def get_from_cache_threadsafe(
self, statistic_ids: set[str]
) -> dict[str, tuple[int, StatisticMetaData]]:
"""Get metadata from cache.
This call is thread safe and can be run in the event loop,
the database executor, or the recorder thread.
"""
return {
statistic_id: id_meta
for statistic_id in statistic_ids
# We must use a get call here and never iterate over the dict
# because the dict can be modified by the recorder thread
# while we are iterating over it.
if (id_meta := self._stat_id_to_id_meta.get(statistic_id))
}
def update_or_add(
self,
session: Session,
new_metadata: StatisticMetaData,
old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
) -> tuple[bool, int]:
) -> tuple[str | None, int]:
"""Get metadata_id for a statistic_id.
If the statistic_id is previously unknown, add it. If it's already known, update
@ -258,16 +268,16 @@ class StatisticsMetaManager:
Updating metadata source is not possible.
Returns a tuple of (updated, metadata_id).
Returns a tuple of (statistic_id | None, metadata_id).
updated is True if the metadata was updated, False if it was not updated.
statistic_id is None if the metadata was not updated
This call is not thread-safe and must be called from the
recorder thread.
"""
statistic_id = new_metadata["statistic_id"]
if statistic_id not in old_metadata_dict:
return True, self._add_metadata(session, statistic_id, new_metadata)
return statistic_id, self._add_metadata(session, statistic_id, new_metadata)
return self._update_metadata(
session, statistic_id, new_metadata, old_metadata_dict
)
@ -319,4 +329,14 @@ class StatisticsMetaManager:
def reset(self) -> None:
"""Reset the cache."""
self._stat_id_to_id_meta = {}
self._stat_id_to_id_meta.clear()
def adjust_lru_size(self, new_size: int) -> None:
"""Adjust the LRU cache size.
This call is not thread-safe and must be called from the
recorder thread.
"""
lru: LRU = self._stat_id_to_id_meta
if new_size > lru.get_size():
lru.set_size(new_size)

View File

@ -37,6 +37,7 @@ from .statistics import (
async_add_external_statistics,
async_change_statistics_unit,
async_import_statistics,
async_list_statistic_ids,
list_statistic_ids,
statistic_during_period,
statistics_during_period,
@ -151,7 +152,7 @@ def _ws_get_statistics_during_period(
msg_id: int,
start_time: dt,
end_time: dt | None,
statistic_ids: list[str] | None,
statistic_ids: set[str] | None,
period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
@ -208,7 +209,7 @@ async def ws_handle_get_statistics_during_period(
msg["id"],
start_time,
end_time,
msg["statistic_ids"],
set(msg["statistic_ids"]),
msg.get("period"),
msg.get("units"),
types,
@ -329,11 +330,10 @@ async def ws_get_statistics_metadata(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Get metadata for a list of statistic_ids."""
instance = get_instance(hass)
statistic_ids = await instance.async_add_executor_job(
list_statistic_ids, hass, msg.get("statistic_ids")
)
connection.send_result(msg["id"], statistic_ids)
statistic_ids = msg.get("statistic_ids")
statistic_ids_set_or_none = set(statistic_ids) if statistic_ids else None
metadata = await async_list_statistic_ids(hass, statistic_ids_set_or_none)
connection.send_result(msg["id"], metadata)
@websocket_api.require_admin
@ -413,7 +413,7 @@ async def ws_adjust_sum_statistics(
instance = get_instance(hass)
metadatas = await instance.async_add_executor_job(
list_statistic_ids, hass, (msg["statistic_id"],)
list_statistic_ids, hass, {msg["statistic_id"]}
)
if not metadatas:
connection.send_error(msg["id"], "unknown_statistic_id", "Unknown statistic ID")

View File

@ -453,10 +453,10 @@ def _compile_statistics( # noqa: C901
# that are not in the metadata table and we are not working
# with them anyway.
old_metadatas = statistics.get_metadata_with_session(
get_instance(hass), session, statistic_ids=list(entities_with_float_states)
get_instance(hass), session, statistic_ids=set(entities_with_float_states)
)
to_process: list[tuple[str, str | None, str, list[tuple[float, State]]]] = []
to_query: list[str] = []
to_query: set[str] = set()
for _state in sensor_states:
entity_id = _state.entity_id
if not (maybe_float_states := entities_with_float_states.get(entity_id)):
@ -472,7 +472,7 @@ def _compile_statistics( # noqa: C901
state_class: str = _state.attributes[ATTR_STATE_CLASS]
to_process.append((entity_id, statistics_unit, state_class, valid_float_states))
if "sum" in wanted_statistics[entity_id]:
to_query.append(entity_id)
to_query.add(entity_id)
last_stats = statistics.get_latest_short_term_statistics(
hass, to_query, {"last_reset", "state", "sum"}, metadata=old_metadatas

View File

@ -636,7 +636,7 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]):
self.hass,
start,
None,
[statistic_id],
{statistic_id},
"hour",
None,
{"sum"},

View File

@ -144,13 +144,15 @@ def statistics_during_period(
hass: HomeAssistant,
start_time: datetime,
end_time: datetime | None = None,
statistic_ids: list[str] | None = None,
statistic_ids: set[str] | None = None,
period: Literal["5minute", "day", "hour", "week", "month"] = "hour",
units: dict[str, str] | None = None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]]
| None = None,
) -> dict[str, list[dict[str, Any]]]:
"""Call statistics_during_period with defaults for simpler tests."""
if statistic_ids is not None and not isinstance(statistic_ids, set):
statistic_ids = set(statistic_ids)
if types is None:
types = {"last_reset", "max", "mean", "min", "state", "sum"}
return statistics.statistics_during_period(

View File

@ -292,6 +292,15 @@ class States(Base): # type: ignore[misc,valid-type]
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
origin_idx = Column(SmallInteger) # 0 is local, 1 is remote
context_id_bin = Column(
LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH)
) # *** Not originally in v28, only added for recorder to startup ok
context_user_id_bin = Column(
LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH)
) # *** Not originally in v28, only added for recorder to startup ok
context_parent_id_bin = Column(
LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH)
) # *** Not originally in v28, only added for recorder to startup ok
metadata_id = Column(
Integer, ForeignKey("states_meta.metadata_id"), index=True
) # *** Not originally in v28, only added for recorder to startup ok

View File

@ -26,12 +26,12 @@ async def test_passing_mutually_exclusive_options_to_get_many(
)
with pytest.raises(ValueError):
instance.statistics_meta_manager.get_many(
session, statistic_ids=["light.kitchen"], statistic_source="sensor"
session, statistic_ids={"light.kitchen"}, statistic_source="sensor"
)
assert (
instance.statistics_meta_manager.get_many(
session,
statistic_ids=["light.kitchen"],
statistic_ids={"light.kitchen"},
)
== {}
)

View File

@ -84,7 +84,7 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
# Should not fail if there is nothing there yet
stats = get_latest_short_term_statistics(
hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"}
hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
)
assert stats == {}
@ -169,15 +169,15 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
assert stats == {"sensor.test1": [expected_2]}
stats = get_latest_short_term_statistics(
hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"}
hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
)
assert stats == {"sensor.test1": [expected_2]}
metadata = get_metadata(hass, statistic_ids=['sensor.test1"'])
metadata = get_metadata(hass, statistic_ids={"sensor.test1"})
stats = get_latest_short_term_statistics(
hass,
["sensor.test1"],
{"sensor.test1"},
{"last_reset", "max", "mean", "min", "state", "sum"},
metadata=metadata,
)
@ -213,7 +213,7 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
instance.get_session().query(StatisticsShortTerm).delete()
# Should not fail there is nothing in the table
stats = get_latest_short_term_statistics(
hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"}
hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
)
assert stats == {}
@ -243,7 +243,7 @@ def mock_sensor_statistics():
sensor_stats("sensor.test3", start),
],
get_metadata(
_hass, statistic_ids=["sensor.test1", "sensor.test2", "sensor.test3"]
_hass, statistic_ids={"sensor.test1", "sensor.test2", "sensor.test3"}
),
)
@ -385,6 +385,27 @@ def test_rename_entity(hass_recorder: Callable[..., HomeAssistant]) -> None:
assert stats == {"sensor.test99": expected_stats99, "sensor.test2": expected_stats2}
def test_statistics_during_period_set_back_compat(
hass_recorder: Callable[..., HomeAssistant]
) -> None:
"""Test statistics_during_period can handle a list instead of a set."""
hass = hass_recorder()
setup_component(hass, "sensor", {})
# This should not throw an exception when passed a list instead of a set
assert (
statistics.statistics_during_period(
hass,
dt_util.utcnow(),
None,
statistic_ids=["sensor.test1"],
period="5minute",
units=None,
types=set(),
)
== {}
)
def test_rename_entity_collision(
hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture
) -> None:
@ -595,7 +616,7 @@ async def test_import_statistics(
"unit_class": "energy",
}
]
metadata = get_metadata(hass, statistic_ids=(statistic_id,))
metadata = get_metadata(hass, statistic_ids={statistic_id})
assert metadata == {
statistic_id: (
1,
@ -692,7 +713,7 @@ async def test_import_statistics(
"unit_class": "energy",
}
]
metadata = get_metadata(hass, statistic_ids=(statistic_id,))
metadata = get_metadata(hass, statistic_ids={statistic_id})
assert metadata == {
statistic_id: (
1,
@ -814,7 +835,7 @@ def test_external_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
# Attempt to insert statistics for the wrong domain
external_metadata = {**_external_metadata, "source": "other"}
@ -824,7 +845,7 @@ def test_external_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
# Attempt to insert statistics for a naive starting time
external_metadata = {**_external_metadata}
@ -837,7 +858,7 @@ def test_external_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
# Attempt to insert statistics for an invalid starting time
external_metadata = {**_external_metadata}
@ -847,7 +868,7 @@ def test_external_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
# Attempt to insert statistics with a naive last_reset
external_metadata = {**_external_metadata}
@ -860,7 +881,7 @@ def test_external_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
def test_import_statistics_errors(
@ -903,7 +924,7 @@ def test_import_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {}
# Attempt to insert statistics for the wrong domain
external_metadata = {**_external_metadata, "source": "sensor"}
@ -913,7 +934,7 @@ def test_import_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
# Attempt to insert statistics for a naive starting time
external_metadata = {**_external_metadata}
@ -926,7 +947,7 @@ def test_import_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
# Attempt to insert statistics for an invalid starting time
external_metadata = {**_external_metadata}
@ -936,7 +957,7 @@ def test_import_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
# Attempt to insert statistics with a naive last_reset
external_metadata = {**_external_metadata}
@ -949,7 +970,7 @@ def test_import_statistics_errors(
wait_recording_done(hass)
assert statistics_during_period(hass, zero, period="hour") == {}
assert list_statistic_ids(hass) == []
assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {}
assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {}
@pytest.mark.parametrize("timezone", ["America/Regina", "Europe/Vienna", "UTC"])

View File

@ -2588,7 +2588,7 @@ async def test_import_statistics(
"unit_class": "energy",
}
]
metadata = get_metadata(hass, statistic_ids=(statistic_id,))
metadata = get_metadata(hass, statistic_ids={statistic_id})
assert metadata == {
statistic_id: (
1,
@ -2820,7 +2820,7 @@ async def test_adjust_sum_statistics_energy(
"unit_class": "energy",
}
]
metadata = get_metadata(hass, statistic_ids=(statistic_id,))
metadata = get_metadata(hass, statistic_ids={statistic_id})
assert metadata == {
statistic_id: (
1,
@ -3016,7 +3016,7 @@ async def test_adjust_sum_statistics_gas(
"unit_class": "volume",
}
]
metadata = get_metadata(hass, statistic_ids=(statistic_id,))
metadata = get_metadata(hass, statistic_ids={statistic_id})
assert metadata == {
statistic_id: (
1,
@ -3230,7 +3230,7 @@ async def test_adjust_sum_statistics_errors(
"unit_class": unit_class,
}
]
metadata = get_metadata(hass, statistic_ids=(statistic_id,))
metadata = get_metadata(hass, statistic_ids={statistic_id})
assert metadata == {
statistic_id: (
1,

View File

@ -3067,7 +3067,7 @@ def test_compile_hourly_statistics_changing_state_class(
"unit_class": unit_class,
},
]
metadata = get_metadata(hass, statistic_ids=("sensor.test1",))
metadata = get_metadata(hass, statistic_ids={"sensor.test1"})
assert metadata == {
"sensor.test1": (
1,
@ -3103,7 +3103,7 @@ def test_compile_hourly_statistics_changing_state_class(
"unit_class": unit_class,
},
]
metadata = get_metadata(hass, statistic_ids=("sensor.test1",))
metadata = get_metadata(hass, statistic_ids={"sensor.test1"})
assert metadata == {
"sensor.test1": (
1,

View File

@ -35,7 +35,7 @@ async def test_async_setup_entry(recorder_mock: Recorder, hass: HomeAssistant) -
hass,
dt_util.parse_datetime(data[0]["from"]),
None,
[statistic_id],
{statistic_id},
"hour",
None,
{"start", "state", "mean", "min", "max", "last_reset", "sum"},