From 5ffb2330043ecfe36a4c8272d239b0883a5d6339 Mon Sep 17 00:00:00 2001
From: "J. Nick Koston"
Date: Sun, 19 Mar 2023 16:01:16 -1000
Subject: [PATCH] Avoid database executor job to fetch statistic metadata on cache hit (#89960)

* Avoid database executor job to fetch statistic metadata on cache hit

Since we will almost always have a cache hit fetching
statistic metadata, we can avoid an executor job.

* remove exception catch since the threading.excepthook will actually catch this in production

* fix a few missed ones

* threadsafe

* Update homeassistant/components/recorder/table_managers/statistics_meta.py

* coverage and optimistic caching

---
 homeassistant/components/energy/validate.py   |   2 +-
 .../components/energy/websocket_api.py        |   4 +-
 homeassistant/components/recorder/core.py     |   1 +
 .../components/recorder/statistics.py         | 167 ++++++++++++------
 .../table_managers/statistics_meta.py         |  64 ++++---
 .../components/recorder/websocket_api.py      |  16 +-
 homeassistant/components/sensor/recorder.py   |   6 +-
 homeassistant/components/tibber/sensor.py     |   2 +-
 tests/components/recorder/common.py           |   4 +-
 tests/components/recorder/db_schema_28.py     |   9 +
 .../table_managers/test_statistics_meta.py    |   4 +-
 tests/components/recorder/test_statistics.py  |  57 ++++--
 .../components/recorder/test_websocket_api.py |   8 +-
 tests/components/sensor/test_recorder.py      |   4 +-
 tests/components/tibber/test_statistics.py    |   2 +-
 15 files changed, 232 insertions(+), 118 deletions(-)

diff --git a/homeassistant/components/energy/validate.py b/homeassistant/components/energy/validate.py
index a2c3ad094da..0a89c3d9270 100644
--- a/homeassistant/components/energy/validate.py
+++ b/homeassistant/components/energy/validate.py
@@ -603,7 +603,7 @@ async def async_validate(hass: HomeAssistant) -> EnergyPreferencesValidation:
             functools.partial(
                 recorder.statistics.get_metadata,
                 hass,
-                statistic_ids=list(wanted_statistics_metadata),
+                statistic_ids=set(wanted_statistics_metadata),
             )
         )
     )
diff --git a/homeassistant/components/energy/websocket_api.py b/homeassistant/components/energy/websocket_api.py
index 15ffc6a2804..7830d3649f2 100644
--- a/homeassistant/components/energy/websocket_api.py
+++ b/homeassistant/components/energy/websocket_api.py
@@ -262,8 +262,8 @@ async def ws_get_fossil_energy_consumption(
         connection.send_error(msg["id"], "invalid_end_time", "Invalid end_time")
         return

-    statistic_ids = list(msg["energy_statistic_ids"])
-    statistic_ids.append(msg["co2_statistic_id"])
+    statistic_ids = set(msg["energy_statistic_ids"])
+    statistic_ids.add(msg["co2_statistic_id"])

     # Fetch energy + CO2 statistics
     statistics = await recorder.get_instance(hass).async_add_executor_job(
diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py
index 538d07eb4d7..30dd311c0e6 100644
--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py
@@ -501,6 +501,7 @@ class Recorder(threading.Thread):
         new_size = self.hass.states.async_entity_ids_count() * 2
         self.state_attributes_manager.adjust_lru_size(new_size)
         self.states_meta_manager.adjust_lru_size(new_size)
+        self.statistics_meta_manager.adjust_lru_size(new_size)

     @callback
     def async_periodic_statistics(self) -> None:
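The core.py hunk above sizes the new statistics-metadata LRU the same way as the existing state-attributes and states-meta caches: twice the current entity count. The fast path the rest of the patch builds on can be sketched in isolation as follows; MetadataCache and fetch_from_db are hypothetical stand-ins for the manager and a SELECT, not the recorder's actual API. Serve from an in-process dict when every requested id is cached, and fall back to a database executor job only on a miss:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    DB = {"sensor.energy": (1, {"statistic_id": "sensor.energy"})}

    class MetadataCache:
        def __init__(self) -> None:
            self._cache: dict[str, tuple[int, dict]] = {}

        def get_cached(self, ids: set[str]) -> dict[str, tuple[int, dict]]:
            # Single .get() calls are atomic under the GIL, so this is safe
            # to run in the event loop while another thread writes.
            return {i: m for i in ids if (m := self._cache.get(i))}

        def fetch_from_db(self, ids: set[str]) -> dict[str, tuple[int, dict]]:
            found = {i: DB[i] for i in ids if i in DB}  # stands in for a SELECT
            self._cache.update(found)
            return found

    async def get_metadata(cache, pool, ids: set[str]) -> dict:
        hits = cache.get_cached(ids)
        if not ids.difference(hits):  # every id was cached: no executor job
            return hits
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(pool, cache.fetch_from_db, ids)

    async def main() -> None:
        cache, pool = MetadataCache(), ThreadPoolExecutor(1)
        await get_metadata(cache, pool, {"sensor.energy"})  # miss: executor job
        await get_metadata(cache, pool, {"sensor.energy"})  # hit: stays in loop

    asyncio.run(main())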
diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index fcd934270d1..2f2deeeaeee 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -713,10 +713,10 @@ def compile_missing_statistics(instance: Recorder) -> bool:
             periods_without_commit += 1
             end = start + timedelta(minutes=period_size)
             _LOGGER.debug("Compiling missing statistics for %s-%s", start, end)
-            metadata_modified = _compile_statistics(
+            modified_statistic_ids = _compile_statistics(
                 instance, session, start, end >= last_period
             )
-            if periods_without_commit == commit_interval or metadata_modified:
+            if periods_without_commit == commit_interval or modified_statistic_ids:
                 session.commit()
                 session.expunge_all()
                 periods_without_commit = 0
@@ -736,29 +736,40 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) -
         session=instance.get_session(),
         exception_filter=_filter_unique_constraint_integrity_error(instance),
     ) as session:
-        _compile_statistics(instance, session, start, fire_events)
+        modified_statistic_ids = _compile_statistics(
+            instance, session, start, fire_events
+        )
+
+    if modified_statistic_ids:
+        # In the rare case that we have modified statistic_ids, we reload the
+        # modified statistics metadata into the cache in a fresh session to
+        # ensure that the cache is up to date and future calls to get
+        # statistics metadata will not have to hit the database again.
+        with session_scope(session=instance.get_session(), read_only=True) as session:
+            instance.statistics_meta_manager.get_many(session, modified_statistic_ids)
+
     return True


 def _compile_statistics(
     instance: Recorder, session: Session, start: datetime, fire_events: bool
-) -> bool:
+) -> set[str]:
     """Compile 5-minute statistics for all integrations with a recorder platform.

     This is a helper function for compile_statistics and compile_missing_statistics
     that does not retry on database errors since both callers already retry.

-    returns True if metadata was modified, False otherwise
+    Returns a set of modified statistic_ids if any were modified.
     """
     assert start.tzinfo == dt_util.UTC, "start must be in UTC"
     end = start + timedelta(minutes=5)
     statistics_meta_manager = instance.statistics_meta_manager
-    metadata_modified = False
+    modified_statistic_ids: set[str] = set()

     # Return if we already have 5-minute statistics for the requested period
     if session.query(StatisticsRuns).filter_by(start=start).first():
         _LOGGER.debug("Statistics already compiled for %s-%s", start, end)
-        return metadata_modified
+        return modified_statistic_ids

     _LOGGER.debug("Compiling statistics for %s-%s", start, end)
     platform_stats: list[StatisticResult] = []
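The reload just added to compile_statistics is the "optimistic caching" named in the commit message: after a commit changes metadata rows, the changed keys are re-read at once, so the cache already holds the committed values and later readers never fall through to the database. A minimal sketch of the idea with hypothetical names (Cache, get_many); the real manager keys an LRU by statistic_id and reloads inside a fresh read-only session:

    class Cache:
        """Hypothetical stand-in for the statistics metadata manager."""

        def __init__(self, db: dict[str, str]) -> None:
            self.db = db                    # stands in for the metadata table
            self.mem: dict[str, str] = {}   # stands in for the LRU cache

        def get_many(self, keys: set[str]) -> dict[str, str]:
            # Only hit the "database" for keys the cache does not hold.
            if missing := keys - self.mem.keys():
                self.mem.update({k: self.db[k] for k in missing if k in self.db})
            return {k: self.mem[k] for k in keys if k in self.mem}

    db = {"sensor.power": "W"}
    cache = Cache(db)
    cache.get_many({"sensor.power"})          # first read populates the cache

    db["sensor.power"] = "kW"                 # a commit modifies the row...
    cache.mem.pop("sensor.power", None)       # ...and evicts the stale entry
    cache.get_many({"sensor.power"})          # optimistic reload after commit
    assert cache.mem["sensor.power"] == "kW"  # later readers get cache hits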
""" assert start.tzinfo == dt_util.UTC, "start must be in UTC" end = start + timedelta(minutes=5) statistics_meta_manager = instance.statistics_meta_manager - metadata_modified = False + modified_statistic_ids: set[str] = set() # Return if we already have 5-minute statistics for the requested period if session.query(StatisticsRuns).filter_by(start=start).first(): _LOGGER.debug("Statistics already compiled for %s-%s", start, end) - return metadata_modified + return modified_statistic_ids _LOGGER.debug("Compiling statistics for %s-%s", start, end) platform_stats: list[StatisticResult] = [] @@ -782,10 +793,11 @@ def _compile_statistics( # Insert collected statistics in the database for stats in platform_stats: - updated, metadata_id = statistics_meta_manager.update_or_add( + modified_statistic_id, metadata_id = statistics_meta_manager.update_or_add( session, stats["meta"], current_metadata ) - metadata_modified |= updated + if modified_statistic_id is not None: + modified_statistic_ids.add(modified_statistic_id) _insert_statistics( session, StatisticsShortTerm, @@ -804,7 +816,7 @@ def _compile_statistics( if start.minute == 55: instance.hass.bus.fire(EVENT_RECORDER_HOURLY_STATISTICS_GENERATED) - return metadata_modified + return modified_statistic_ids def _adjust_sum_statistics( @@ -882,7 +894,7 @@ def get_metadata_with_session( instance: Recorder, session: Session, *, - statistic_ids: list[str] | None = None, + statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: @@ -903,7 +915,7 @@ def get_metadata_with_session( def get_metadata( hass: HomeAssistant, *, - statistic_ids: list[str] | None = None, + statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: @@ -947,9 +959,79 @@ def update_statistics_metadata( ) +async def async_list_statistic_ids( + hass: HomeAssistant, + statistic_ids: set[str] | None = None, + statistic_type: Literal["mean"] | Literal["sum"] | None = None, +) -> list[dict]: + """Return all statistic_ids (or filtered one) and unit of measurement. + + Queries the database for existing statistic_ids, as well as integrations with + a recorder platform for statistic_ids which will be added in the next statistics + period. + """ + instance = get_instance(hass) + + if statistic_ids is not None: + # Try to get the results from the cache since there is nearly + # always a cache hit. 
@@ -947,9 +959,79 @@ def update_statistics_metadata(
     )


+async def async_list_statistic_ids(
+    hass: HomeAssistant,
+    statistic_ids: set[str] | None = None,
+    statistic_type: Literal["mean"] | Literal["sum"] | None = None,
+) -> list[dict]:
+    """Return all statistic_ids (or filtered ones) and unit of measurement.
+
+    Queries the database for existing statistic_ids, as well as integrations with
+    a recorder platform for statistic_ids which will be added in the next statistics
+    period.
+    """
+    instance = get_instance(hass)
+
+    if statistic_ids is not None:
+        # Try to get the results from the cache since there is nearly
+        # always a cache hit.
+        statistics_meta_manager = instance.statistics_meta_manager
+        metadata = statistics_meta_manager.get_from_cache_threadsafe(statistic_ids)
+        if not statistic_ids.difference(metadata):
+            result = _statistic_by_id_from_metadata(hass, metadata)
+            return _flatten_list_statistic_ids_metadata_result(result)
+
+    return await instance.async_add_executor_job(
+        list_statistic_ids,
+        hass,
+        statistic_ids,
+        statistic_type,
+    )
+
+
+def _statistic_by_id_from_metadata(
+    hass: HomeAssistant,
+    metadata: dict[str, tuple[int, StatisticMetaData]],
+) -> dict[str, dict[str, Any]]:
+    """Return a dict of metadata info indexed by statistic_id."""
+    return {
+        meta["statistic_id"]: {
+            "display_unit_of_measurement": get_display_unit(
+                hass, meta["statistic_id"], meta["unit_of_measurement"]
+            ),
+            "has_mean": meta["has_mean"],
+            "has_sum": meta["has_sum"],
+            "name": meta["name"],
+            "source": meta["source"],
+            "unit_class": _get_unit_class(meta["unit_of_measurement"]),
+            "unit_of_measurement": meta["unit_of_measurement"],
+        }
+        for _, meta in metadata.values()
+    }
+
+
+def _flatten_list_statistic_ids_metadata_result(
+    result: dict[str, dict[str, Any]]
+) -> list[dict]:
+    """Return a flat list of metadata dicts."""
+    return [
+        {
+            "statistic_id": _id,
+            "display_unit_of_measurement": info["display_unit_of_measurement"],
+            "has_mean": info["has_mean"],
+            "has_sum": info["has_sum"],
+            "name": info.get("name"),
+            "source": info["source"],
+            "statistics_unit_of_measurement": info["unit_of_measurement"],
+            "unit_class": info["unit_class"],
+        }
+        for _id, info in result.items()
+    ]
+
+
 def list_statistic_ids(
     hass: HomeAssistant,
-    statistic_ids: list[str] | None = None,
+    statistic_ids: set[str] | None = None,
     statistic_type: Literal["mean"] | Literal["sum"] | None = None,
 ) -> list[dict]:
     """Return all statistic_ids (or filtered ones) and unit of measurement.
@@ -959,30 +1041,17 @@ def list_statistic_ids(
     period.
     """
     result = {}
-    statistic_ids_set = set(statistic_ids) if statistic_ids else None
+    instance = get_instance(hass)
+    statistics_meta_manager = instance.statistics_meta_manager

     # Query the database
     with session_scope(hass=hass, read_only=True) as session:
-        metadata = get_instance(hass).statistics_meta_manager.get_many(
+        metadata = statistics_meta_manager.get_many(
             session, statistic_type=statistic_type, statistic_ids=statistic_ids
         )
+        result = _statistic_by_id_from_metadata(hass, metadata)

-        result = {
-            meta["statistic_id"]: {
-                "display_unit_of_measurement": get_display_unit(
-                    hass, meta["statistic_id"], meta["unit_of_measurement"]
-                ),
-                "has_mean": meta["has_mean"],
-                "has_sum": meta["has_sum"],
-                "name": meta["name"],
-                "source": meta["source"],
-                "unit_class": _get_unit_class(meta["unit_of_measurement"]),
-                "unit_of_measurement": meta["unit_of_measurement"],
-            }
-            for _, meta in metadata.values()
-        }
-
-        if not statistic_ids_set or statistic_ids_set.difference(result):
+        if not statistic_ids or statistic_ids.difference(result):
             # If we want all statistic_ids, or some are missing, we need to query
             # the integrations for the missing ones.
# @@ -1009,19 +1078,7 @@ def list_statistic_ids( } # Return a list of statistic_id + metadata - return [ - { - "statistic_id": _id, - "display_unit_of_measurement": info["display_unit_of_measurement"], - "has_mean": info["has_mean"], - "has_sum": info["has_sum"], - "name": info.get("name"), - "source": info["source"], - "statistics_unit_of_measurement": info["unit_of_measurement"], - "unit_class": info["unit_class"], - } - for _id, info in result.items() - ] + return _flatten_list_statistic_ids_metadata_result(result) def _reduce_statistics( @@ -1698,7 +1755,7 @@ def _statistics_during_period_with_session( session: Session, start_time: datetime, end_time: datetime | None, - statistic_ids: list[str] | None, + statistic_ids: set[str] | None, period: Literal["5minute", "day", "hour", "week", "month"], units: dict[str, str] | None, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], @@ -1708,6 +1765,10 @@ def _statistics_during_period_with_session( If end_time is omitted, returns statistics newer than or equal to start_time. If statistic_ids is omitted, returns statistics for all statistics ids. """ + if statistic_ids is not None and not isinstance(statistic_ids, set): + # This is for backwards compatibility to avoid a breaking change + # for custom integrations that call this method. + statistic_ids = set(statistic_ids) # type: ignore[unreachable] metadata = None # Fetch metadata for the given (or all) statistic_ids metadata = get_instance(hass).statistics_meta_manager.get_many( @@ -1784,7 +1845,7 @@ def statistics_during_period( hass: HomeAssistant, start_time: datetime, end_time: datetime | None, - statistic_ids: list[str] | None, + statistic_ids: set[str] | None, period: Literal["5minute", "day", "hour", "week", "month"], units: dict[str, str] | None, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], @@ -1845,7 +1906,7 @@ def _get_last_statistics( types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], ) -> dict[str, list[StatisticsRow]]: """Return the last number_of_stats statistics for a given statistic_id.""" - statistic_ids = [statistic_id] + statistic_ids = {statistic_id} with session_scope(hass=hass, read_only=True) as session: # Fetch metadata for the given statistic_id metadata = get_instance(hass).statistics_meta_manager.get_many( @@ -1930,7 +1991,7 @@ def _latest_short_term_statistics_stmt( def get_latest_short_term_statistics( hass: HomeAssistant, - statistic_ids: list[str], + statistic_ids: set[str], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], metadata: dict[str, tuple[int, StatisticMetaData]] | None = None, ) -> dict[str, list[StatisticsRow]]: @@ -2031,7 +2092,7 @@ def _sorted_statistics_to_dict( hass: HomeAssistant, session: Session, stats: Sequence[Row[Any]], - statistic_ids: list[str] | None, + statistic_ids: set[str] | None, _metadata: dict[str, tuple[int, StatisticMetaData]], convert_units: bool, table: type[StatisticsBase], @@ -2294,7 +2355,7 @@ def _import_statistics_with_session( """Import statistics to the database.""" statistics_meta_manager = instance.statistics_meta_manager old_metadata_dict = statistics_meta_manager.get_many( - session, statistic_ids=[metadata["statistic_id"]] + session, statistic_ids={metadata["statistic_id"]} ) _, metadata_id = statistics_meta_manager.update_or_add( session, metadata, old_metadata_dict @@ -2338,7 +2399,7 @@ def adjust_statistics( with session_scope(session=instance.get_session()) as session: metadata = 
instance.statistics_meta_manager.get_many( - session, statistic_ids=[statistic_id] + session, statistic_ids={statistic_id} ) if statistic_id not in metadata: return True @@ -2476,7 +2537,7 @@ def _validate_db_schema_utf8( try: with session_scope(session=session_maker()) as session: old_metadata_dict = statistics_meta_manager.get_many( - session, statistic_ids=[statistic_id] + session, statistic_ids={statistic_id} ) try: statistics_meta_manager.update_or_add( @@ -2573,7 +2634,7 @@ def _validate_db_schema( session, start_time, None, - [statistic_id], + {statistic_id}, "hour" if table == Statistics else "5minute", None, {"last_reset", "max", "mean", "min", "state", "sum"}, diff --git a/homeassistant/components/recorder/table_managers/statistics_meta.py b/homeassistant/components/recorder/table_managers/statistics_meta.py index 93417b43253..ba47b3600d6 100644 --- a/homeassistant/components/recorder/table_managers/statistics_meta.py +++ b/homeassistant/components/recorder/table_managers/statistics_meta.py @@ -34,7 +34,7 @@ QUERY_STATISTIC_META = ( def _generate_get_metadata_stmt( - statistic_ids: list[str] | None = None, + statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> StatementLambdaElement: @@ -89,7 +89,7 @@ class StatisticsMetaManager: def _get_from_database( self, session: Session, - statistic_ids: list[str] | None = None, + statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: @@ -112,6 +112,7 @@ class StatisticsMetaManager: ): statistics_meta = cast(StatisticsMeta, row) id_meta = _statistics_meta_to_id_statistics_metadata(statistics_meta) + statistic_id = cast(str, statistics_meta.statistic_id) results[statistic_id] = id_meta if update_cache: @@ -149,7 +150,7 @@ class StatisticsMetaManager: statistic_id: str, new_metadata: StatisticMetaData, old_metadata_dict: dict[str, tuple[int, StatisticMetaData]], - ) -> tuple[bool, int]: + ) -> tuple[str | None, int]: """Update metadata in the database. This call is not thread-safe and must be called from the @@ -163,7 +164,7 @@ class StatisticsMetaManager: or old_metadata["unit_of_measurement"] != new_metadata["unit_of_measurement"] ): - return False, metadata_id + return None, metadata_id self._assert_in_recorder_thread() session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update( @@ -182,7 +183,7 @@ class StatisticsMetaManager: old_metadata, new_metadata, ) - return True, metadata_id + return statistic_id, metadata_id def load(self, session: Session) -> None: """Load the statistic_id to metadata_id mapping into memory. 
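_update_metadata (and update_or_add in the next hunk) now returns the modified statistic_id rather than a bool, which is what lets _compile_statistics accumulate the set of changed ids that compile_statistics later reloads. The shape of that change in isolation, with a hypothetical update() helper rather than the manager's real signature:

    def update(row: dict[str, str], key: str, value: str) -> str | None:
        """Return the key when the row was modified, None when unchanged."""
        if row.get(key) == value:
            return None  # the old API returned False here
        row[key] = value
        return key       # the old API returned True here

    rows = {"sensor.a": "kWh", "sensor.b": "W"}
    results = (update(rows, "sensor.a", "kWh"), update(rows, "sensor.b", "kW"))
    modified = {k for k in results if k is not None}
    assert modified == {"sensor.b"}  # feeds the post-commit cache reload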
@@ -196,12 +197,12 @@ class StatisticsMetaManager: self, session: Session, statistic_id: str ) -> tuple[int, StatisticMetaData] | None: """Resolve statistic_id to the metadata_id.""" - return self.get_many(session, [statistic_id]).get(statistic_id) + return self.get_many(session, {statistic_id}).get(statistic_id) def get_many( self, session: Session, - statistic_ids: list[str] | None = None, + statistic_ids: set[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: @@ -228,16 +229,8 @@ class StatisticsMetaManager: "Providing statistic_type and statistic_source is mutually exclusive of statistic_ids" ) - results: dict[str, tuple[int, StatisticMetaData]] = {} - missing_statistic_id: list[str] = [] - - for statistic_id in statistic_ids: - if id_meta := self._stat_id_to_id_meta.get(statistic_id): - results[statistic_id] = id_meta - else: - missing_statistic_id.append(statistic_id) - - if not missing_statistic_id: + results = self.get_from_cache_threadsafe(statistic_ids) + if not (missing_statistic_id := statistic_ids.difference(results)): return results # Fetch metadata from the database @@ -245,12 +238,29 @@ class StatisticsMetaManager: session, statistic_ids=missing_statistic_id ) + def get_from_cache_threadsafe( + self, statistic_ids: set[str] + ) -> dict[str, tuple[int, StatisticMetaData]]: + """Get metadata from cache. + + This call is thread safe and can be run in the event loop, + the database executor, or the recorder thread. + """ + return { + statistic_id: id_meta + for statistic_id in statistic_ids + # We must use a get call here and never iterate over the dict + # because the dict can be modified by the recorder thread + # while we are iterating over it. + if (id_meta := self._stat_id_to_id_meta.get(statistic_id)) + } + def update_or_add( self, session: Session, new_metadata: StatisticMetaData, old_metadata_dict: dict[str, tuple[int, StatisticMetaData]], - ) -> tuple[bool, int]: + ) -> tuple[str | None, int]: """Get metadata_id for a statistic_id. If the statistic_id is previously unknown, add it. If it's already known, update @@ -258,16 +268,16 @@ class StatisticsMetaManager: Updating metadata source is not possible. - Returns a tuple of (updated, metadata_id). + Returns a tuple of (statistic_id | None, metadata_id). - updated is True if the metadata was updated, False if it was not updated. + statistic_id is None if the metadata was not updated This call is not thread-safe and must be called from the recorder thread. """ statistic_id = new_metadata["statistic_id"] if statistic_id not in old_metadata_dict: - return True, self._add_metadata(session, statistic_id, new_metadata) + return statistic_id, self._add_metadata(session, statistic_id, new_metadata) return self._update_metadata( session, statistic_id, new_metadata, old_metadata_dict ) @@ -319,4 +329,14 @@ class StatisticsMetaManager: def reset(self) -> None: """Reset the cache.""" - self._stat_id_to_id_meta = {} + self._stat_id_to_id_meta.clear() + + def adjust_lru_size(self, new_size: int) -> None: + """Adjust the LRU cache size. + + This call is not thread-safe and must be called from the + recorder thread. 
+ """ + lru: LRU = self._stat_id_to_id_meta + if new_size > lru.get_size(): + lru.set_size(new_size) diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index 29c0808e6ad..df42c519fe2 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -37,6 +37,7 @@ from .statistics import ( async_add_external_statistics, async_change_statistics_unit, async_import_statistics, + async_list_statistic_ids, list_statistic_ids, statistic_during_period, statistics_during_period, @@ -151,7 +152,7 @@ def _ws_get_statistics_during_period( msg_id: int, start_time: dt, end_time: dt | None, - statistic_ids: list[str] | None, + statistic_ids: set[str] | None, period: Literal["5minute", "day", "hour", "week", "month"], units: dict[str, str], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], @@ -208,7 +209,7 @@ async def ws_handle_get_statistics_during_period( msg["id"], start_time, end_time, - msg["statistic_ids"], + set(msg["statistic_ids"]), msg.get("period"), msg.get("units"), types, @@ -329,11 +330,10 @@ async def ws_get_statistics_metadata( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: """Get metadata for a list of statistic_ids.""" - instance = get_instance(hass) - statistic_ids = await instance.async_add_executor_job( - list_statistic_ids, hass, msg.get("statistic_ids") - ) - connection.send_result(msg["id"], statistic_ids) + statistic_ids = msg.get("statistic_ids") + statistic_ids_set_or_none = set(statistic_ids) if statistic_ids else None + metadata = await async_list_statistic_ids(hass, statistic_ids_set_or_none) + connection.send_result(msg["id"], metadata) @websocket_api.require_admin @@ -413,7 +413,7 @@ async def ws_adjust_sum_statistics( instance = get_instance(hass) metadatas = await instance.async_add_executor_job( - list_statistic_ids, hass, (msg["statistic_id"],) + list_statistic_ids, hass, {msg["statistic_id"]} ) if not metadatas: connection.send_error(msg["id"], "unknown_statistic_id", "Unknown statistic ID") diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py index 8d5af155fd7..c0df642ed36 100644 --- a/homeassistant/components/sensor/recorder.py +++ b/homeassistant/components/sensor/recorder.py @@ -453,10 +453,10 @@ def _compile_statistics( # noqa: C901 # that are not in the metadata table and we are not working # with them anyway. 
old_metadatas = statistics.get_metadata_with_session( - get_instance(hass), session, statistic_ids=list(entities_with_float_states) + get_instance(hass), session, statistic_ids=set(entities_with_float_states) ) to_process: list[tuple[str, str | None, str, list[tuple[float, State]]]] = [] - to_query: list[str] = [] + to_query: set[str] = set() for _state in sensor_states: entity_id = _state.entity_id if not (maybe_float_states := entities_with_float_states.get(entity_id)): @@ -472,7 +472,7 @@ def _compile_statistics( # noqa: C901 state_class: str = _state.attributes[ATTR_STATE_CLASS] to_process.append((entity_id, statistics_unit, state_class, valid_float_states)) if "sum" in wanted_statistics[entity_id]: - to_query.append(entity_id) + to_query.add(entity_id) last_stats = statistics.get_latest_short_term_statistics( hass, to_query, {"last_reset", "state", "sum"}, metadata=old_metadatas diff --git a/homeassistant/components/tibber/sensor.py b/homeassistant/components/tibber/sensor.py index 4d847c19205..874ec5be673 100644 --- a/homeassistant/components/tibber/sensor.py +++ b/homeassistant/components/tibber/sensor.py @@ -636,7 +636,7 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]): self.hass, start, None, - [statistic_id], + {statistic_id}, "hour", None, {"sum"}, diff --git a/tests/components/recorder/common.py b/tests/components/recorder/common.py index aec5bf81349..17e8c47f6b4 100644 --- a/tests/components/recorder/common.py +++ b/tests/components/recorder/common.py @@ -144,13 +144,15 @@ def statistics_during_period( hass: HomeAssistant, start_time: datetime, end_time: datetime | None = None, - statistic_ids: list[str] | None = None, + statistic_ids: set[str] | None = None, period: Literal["5minute", "day", "hour", "week", "month"] = "hour", units: dict[str, str] | None = None, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]] | None = None, ) -> dict[str, list[dict[str, Any]]]: """Call statistics_during_period with defaults for simpler tests.""" + if statistic_ids is not None and not isinstance(statistic_ids, set): + statistic_ids = set(statistic_ids) if types is None: types = {"last_reset", "max", "mean", "min", "state", "sum"} return statistics.statistics_during_period( diff --git a/tests/components/recorder/db_schema_28.py b/tests/components/recorder/db_schema_28.py index d7a9ec0af4e..8127cb3f26f 100644 --- a/tests/components/recorder/db_schema_28.py +++ b/tests/components/recorder/db_schema_28.py @@ -292,6 +292,15 @@ class States(Base): # type: ignore[misc,valid-type] context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) origin_idx = Column(SmallInteger) # 0 is local, 1 is remote + context_id_bin = Column( + LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH) + ) # *** Not originally in v28, only added for recorder to startup ok + context_user_id_bin = Column( + LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH) + ) # *** Not originally in v28, only added for recorder to startup ok + context_parent_id_bin = Column( + LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH) + ) # *** Not originally in v28, only added for recorder to startup ok metadata_id = Column( Integer, ForeignKey("states_meta.metadata_id"), index=True ) # *** Not originally in v28, only added for recorder to startup ok diff --git a/tests/components/recorder/table_managers/test_statistics_meta.py b/tests/components/recorder/table_managers/test_statistics_meta.py index 8ec3f9367d6..ab6615c6dd0 100644 --- 
a/tests/components/recorder/table_managers/test_statistics_meta.py +++ b/tests/components/recorder/table_managers/test_statistics_meta.py @@ -26,12 +26,12 @@ async def test_passing_mutually_exclusive_options_to_get_many( ) with pytest.raises(ValueError): instance.statistics_meta_manager.get_many( - session, statistic_ids=["light.kitchen"], statistic_source="sensor" + session, statistic_ids={"light.kitchen"}, statistic_source="sensor" ) assert ( instance.statistics_meta_manager.get_many( session, - statistic_ids=["light.kitchen"], + statistic_ids={"light.kitchen"}, ) == {} ) diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 522a3eff2e6..de75052d344 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -84,7 +84,7 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant]) # Should not fail if there is nothing there yet stats = get_latest_short_term_statistics( - hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"} + hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"} ) assert stats == {} @@ -169,15 +169,15 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant]) assert stats == {"sensor.test1": [expected_2]} stats = get_latest_short_term_statistics( - hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"} + hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"} ) assert stats == {"sensor.test1": [expected_2]} - metadata = get_metadata(hass, statistic_ids=['sensor.test1"']) + metadata = get_metadata(hass, statistic_ids={"sensor.test1"}) stats = get_latest_short_term_statistics( hass, - ["sensor.test1"], + {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}, metadata=metadata, ) @@ -213,7 +213,7 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant]) instance.get_session().query(StatisticsShortTerm).delete() # Should not fail there is nothing in the table stats = get_latest_short_term_statistics( - hass, ["sensor.test1"], {"last_reset", "max", "mean", "min", "state", "sum"} + hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"} ) assert stats == {} @@ -243,7 +243,7 @@ def mock_sensor_statistics(): sensor_stats("sensor.test3", start), ], get_metadata( - _hass, statistic_ids=["sensor.test1", "sensor.test2", "sensor.test3"] + _hass, statistic_ids={"sensor.test1", "sensor.test2", "sensor.test3"} ), ) @@ -385,6 +385,27 @@ def test_rename_entity(hass_recorder: Callable[..., HomeAssistant]) -> None: assert stats == {"sensor.test99": expected_stats99, "sensor.test2": expected_stats2} +def test_statistics_during_period_set_back_compat( + hass_recorder: Callable[..., HomeAssistant] +) -> None: + """Test statistics_during_period can handle a list instead of a set.""" + hass = hass_recorder() + setup_component(hass, "sensor", {}) + # This should not throw an exception when passed a list instead of a set + assert ( + statistics.statistics_during_period( + hass, + dt_util.utcnow(), + None, + statistic_ids=["sensor.test1"], + period="5minute", + units=None, + types=set(), + ) + == {} + ) + + def test_rename_entity_collision( hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture ) -> None: @@ -595,7 +616,7 @@ async def test_import_statistics( "unit_class": "energy", } ] - metadata = get_metadata(hass, statistic_ids=(statistic_id,)) + metadata = 
get_metadata(hass, statistic_ids={statistic_id}) assert metadata == { statistic_id: ( 1, @@ -692,7 +713,7 @@ async def test_import_statistics( "unit_class": "energy", } ] - metadata = get_metadata(hass, statistic_ids=(statistic_id,)) + metadata = get_metadata(hass, statistic_ids={statistic_id}) assert metadata == { statistic_id: ( 1, @@ -814,7 +835,7 @@ def test_external_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {} # Attempt to insert statistics for the wrong domain external_metadata = {**_external_metadata, "source": "other"} @@ -824,7 +845,7 @@ def test_external_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {} # Attempt to insert statistics for a naive starting time external_metadata = {**_external_metadata} @@ -837,7 +858,7 @@ def test_external_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {} # Attempt to insert statistics for an invalid starting time external_metadata = {**_external_metadata} @@ -847,7 +868,7 @@ def test_external_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {} # Attempt to insert statistics with a naive last_reset external_metadata = {**_external_metadata} @@ -860,7 +881,7 @@ def test_external_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {} def test_import_statistics_errors( @@ -903,7 +924,7 @@ def test_import_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"test:total_energy_import"}) == {} # Attempt to insert statistics for the wrong domain external_metadata = {**_external_metadata, "source": "sensor"} @@ -913,7 +934,7 @@ def test_import_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {} # Attempt to insert statistics for a naive starting time external_metadata = {**_external_metadata} @@ -926,7 +947,7 @@ def test_import_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} 
assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {} # Attempt to insert statistics for an invalid starting time external_metadata = {**_external_metadata} @@ -936,7 +957,7 @@ def test_import_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {} # Attempt to insert statistics with a naive last_reset external_metadata = {**_external_metadata} @@ -949,7 +970,7 @@ def test_import_statistics_errors( wait_recording_done(hass) assert statistics_during_period(hass, zero, period="hour") == {} assert list_statistic_ids(hass) == [] - assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {} + assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {} @pytest.mark.parametrize("timezone", ["America/Regina", "Europe/Vienna", "UTC"]) diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index 4ed6747ac0b..5244a33f0bc 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -2588,7 +2588,7 @@ async def test_import_statistics( "unit_class": "energy", } ] - metadata = get_metadata(hass, statistic_ids=(statistic_id,)) + metadata = get_metadata(hass, statistic_ids={statistic_id}) assert metadata == { statistic_id: ( 1, @@ -2820,7 +2820,7 @@ async def test_adjust_sum_statistics_energy( "unit_class": "energy", } ] - metadata = get_metadata(hass, statistic_ids=(statistic_id,)) + metadata = get_metadata(hass, statistic_ids={statistic_id}) assert metadata == { statistic_id: ( 1, @@ -3016,7 +3016,7 @@ async def test_adjust_sum_statistics_gas( "unit_class": "volume", } ] - metadata = get_metadata(hass, statistic_ids=(statistic_id,)) + metadata = get_metadata(hass, statistic_ids={statistic_id}) assert metadata == { statistic_id: ( 1, @@ -3230,7 +3230,7 @@ async def test_adjust_sum_statistics_errors( "unit_class": unit_class, } ] - metadata = get_metadata(hass, statistic_ids=(statistic_id,)) + metadata = get_metadata(hass, statistic_ids={statistic_id}) assert metadata == { statistic_id: ( 1, diff --git a/tests/components/sensor/test_recorder.py b/tests/components/sensor/test_recorder.py index ae044c535b5..8881bef8edc 100644 --- a/tests/components/sensor/test_recorder.py +++ b/tests/components/sensor/test_recorder.py @@ -3067,7 +3067,7 @@ def test_compile_hourly_statistics_changing_state_class( "unit_class": unit_class, }, ] - metadata = get_metadata(hass, statistic_ids=("sensor.test1",)) + metadata = get_metadata(hass, statistic_ids={"sensor.test1"}) assert metadata == { "sensor.test1": ( 1, @@ -3103,7 +3103,7 @@ def test_compile_hourly_statistics_changing_state_class( "unit_class": unit_class, }, ] - metadata = get_metadata(hass, statistic_ids=("sensor.test1",)) + metadata = get_metadata(hass, statistic_ids={"sensor.test1"}) assert metadata == { "sensor.test1": ( 1, diff --git a/tests/components/tibber/test_statistics.py b/tests/components/tibber/test_statistics.py index ca6500e6327..6de7549c285 100644 --- a/tests/components/tibber/test_statistics.py +++ b/tests/components/tibber/test_statistics.py @@ -35,7 +35,7 @@ async def test_async_setup_entry(recorder_mock: 
Recorder, hass: HomeAssistant) - hass, dt_util.parse_datetime(data[0]["from"]), None, - [statistic_id], + {statistic_id}, "hour", None, {"start", "state", "mean", "min", "max", "last_reset", "sum"},
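A closing note on get_from_cache_threadsafe in statistics_meta.py above: its comment warns that the shared dict may be mutated by the recorder thread at any time, so readers in other threads fetch key by key with .get() and never iterate the dict itself. A minimal demonstration of why, using plain threading and hypothetical names; in CPython, an insert that resizes a dict mid-iteration raises "RuntimeError: dictionary changed size during iteration":

    import threading

    cache: dict[int, int] = {i: i for i in range(1000)}

    def writer() -> None:
        # Plays the recorder thread: inserts grow (and re-hash) the dict.
        for i in range(1000, 100_000):
            cache[i] = i

    t = threading.Thread(target=writer)
    t.start()

    wanted = set(range(500))
    # Safe: iterate our own set and read the shared dict one atomic .get()
    # at a time, mirroring get_from_cache_threadsafe.
    snapshot = {k: v for k in wanted if (v := cache.get(k)) is not None}
    t.join()
    assert len(snapshot) == 500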