From c9a55c7f84da8b8e35a57722c910ac4a33fd3f58 Mon Sep 17 00:00:00 2001
From: "J. Nick Koston"
Date: Tue, 26 Sep 2023 09:57:59 -0500
Subject: [PATCH] Cache the latest short term stat id for each metadata_id on
 each run (#100535)

---
 .../components/recorder/statistics.py         | 199 +++++++++++++++---
 tests/components/recorder/test_statistics.py  |  21 ++
 .../components/recorder/test_websocket_api.py |   8 +
 3 files changed, 204 insertions(+), 24 deletions(-)

diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index 005859b865b..24fb209ae07 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -24,6 +24,7 @@ import voluptuous as vol
 from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT
 from homeassistant.core import HomeAssistant, callback, valid_entity_id
 from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers.singleton import singleton
 from homeassistant.helpers.typing import UNDEFINED, UndefinedType
 from homeassistant.util import dt as dt_util
 from homeassistant.util.unit_conversion import (
@@ -141,10 +142,39 @@ STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = {
     **{unit: VolumeConverter for unit in VolumeConverter.VALID_UNITS},
 }
 
+DATA_SHORT_TERM_STATISTICS_RUN_CACHE = "recorder_short_term_statistics_run_cache"
+
 _LOGGER = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass(slots=True)
+class ShortTermStatisticsRunCache:
+    """Cache for short term statistics runs."""
+
+    # This is a mapping of metadata_id:id of the short term statistics
+    # row inserted by the last run for each metadata_id
+    _latest_id_by_metadata_id: dict[int, int] = dataclasses.field(default_factory=dict)
+
+    def get_latest_ids(self, metadata_ids: set[int]) -> dict[int, int]:
+        """Return the latest short term statistics ids for the metadata_ids."""
+        return {
+            metadata_id: id_
+            for metadata_id, id_ in self._latest_id_by_metadata_id.items()
+            if metadata_id in metadata_ids
+        }
+
+    def set_latest_id_for_metadata_id(self, metadata_id: int, id_: int) -> None:
+        """Cache the latest id for the metadata_id."""
+        self._latest_id_by_metadata_id[metadata_id] = id_
+
+    def set_latest_ids_for_metadata_ids(
+        self, metadata_id_to_id: dict[int, int]
+    ) -> None:
+        """Cache the latest id for each metadata_id."""
+        self._latest_id_by_metadata_id.update(metadata_id_to_id)
+
+
 class BaseStatisticsRow(TypedDict, total=False):
     """A processed row of statistic data."""
 
@@ -508,6 +538,8 @@ def _compile_statistics(
         platform_stats.extend(compiled.platform_stats)
         current_metadata.update(compiled.current_metadata)
 
+    new_short_term_stats: list[StatisticsBase] = []
+    updated_metadata_ids: set[int] = set()
     # Insert collected statistics in the database
     for stats in platform_stats:
         modified_statistic_id, metadata_id = statistics_meta_manager.update_or_add(
@@ -515,12 +547,14 @@ def _compile_statistics(
         )
         if modified_statistic_id is not None:
             modified_statistic_ids.add(modified_statistic_id)
-        _insert_statistics(
+        updated_metadata_ids.add(metadata_id)
+        if new_stat := _insert_statistics(
             session,
             StatisticsShortTerm,
             metadata_id,
             stats["stat"],
-        )
+        ):
+            new_short_term_stats.append(new_stat)
 
     if start.minute == 55:
         # A full hour is ready, summarize it
@@ -533,6 +567,23 @@ def _compile_statistics(
     if start.minute == 55:
         instance.hass.bus.fire(EVENT_RECORDER_HOURLY_STATISTICS_GENERATED)
 
+    if updated_metadata_ids:
+        # These are always the newest statistics, so we can update
+        # the run cache without having to check the start_ts.
+        session.flush()  # populate the ids of the new StatisticsShortTerm rows
+        run_cache = get_short_term_statistics_run_cache(instance.hass)
+        # metadata_id is typed to allow None, but we know it's not None here
+        # so we can safely cast it to int.
+        run_cache.set_latest_ids_for_metadata_ids(
+            cast(
+                dict[int, int],
+                {
+                    new_stat.metadata_id: new_stat.id
+                    for new_stat in new_short_term_stats
+                },
+            )
+        )
+
     return modified_statistic_ids
 
 
@@ -566,16 +617,19 @@ def _insert_statistics(
     table: type[StatisticsBase],
     metadata_id: int,
     statistic: StatisticData,
-) -> None:
+) -> StatisticsBase | None:
     """Insert statistics in the database."""
     try:
-        session.add(table.from_stats(metadata_id, statistic))
+        stat = table.from_stats(metadata_id, statistic)
+        session.add(stat)
+        return stat
     except SQLAlchemyError:
         _LOGGER.exception(
             "Unexpected exception when inserting statistics %s:%s ",
             metadata_id,
             statistic,
         )
+    return None
 
 
 def _update_statistics(
@@ -1809,24 +1863,26 @@ def get_last_short_term_statistics(
     )
 
 
-def _latest_short_term_statistics_stmt(
-    metadata_ids: list[int],
+def get_latest_short_term_statistics_by_ids(
+    session: Session, ids: Iterable[int]
+) -> list[Row]:
+    """Return the latest short term statistics for a list of ids."""
+    stmt = _latest_short_term_statistics_by_ids_stmt(ids)
+    return list(
+        cast(
+            Sequence[Row],
+            execute_stmt_lambda_element(session, stmt, orm_rows=False),
+        )
+    )
+
+
+def _latest_short_term_statistics_by_ids_stmt(
+    ids: Iterable[int],
 ) -> StatementLambdaElement:
-    """Create the statement for finding the latest short term stat rows."""
+    """Create the statement for finding the latest short term stat rows by id."""
     return lambda_stmt(
-        lambda: select(*QUERY_STATISTICS_SHORT_TERM).join(
-            (
-                most_recent_statistic_row := (
-                    select(
-                        StatisticsShortTerm.metadata_id,
-                        func.max(StatisticsShortTerm.start_ts).label("start_max"),
-                    )
-                    .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
-                    .group_by(StatisticsShortTerm.metadata_id)
-                ).subquery()
-            ),
-            (StatisticsShortTerm.metadata_id == most_recent_statistic_row.c.metadata_id)
-            & (StatisticsShortTerm.start_ts == most_recent_statistic_row.c.start_max),
+        lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter(
+            StatisticsShortTerm.id.in_(ids)
         )
     )
 
@@ -1846,11 +1902,38 @@ def get_latest_short_term_statistics(
     )
     if not metadata:
         return {}
-    metadata_ids = _extract_metadata_and_discard_impossible_columns(metadata, types)
-    stmt = _latest_short_term_statistics_stmt(metadata_ids)
-    stats = cast(
-        Sequence[Row], execute_stmt_lambda_element(session, stmt, orm_rows=False)
+    metadata_ids = set(
+        _extract_metadata_and_discard_impossible_columns(metadata, types)
     )
+    run_cache = get_short_term_statistics_run_cache(hass)
+    # Try to find the latest short term statistics ids for the metadata_ids
+    # from the run cache first if we have it. If the run cache references
+    # a non-existent id because of a purge, we will detect it missing in the
+    # next step and run a query to re-populate the cache.
+    stats: list[Row] = []
+    if metadata_id_to_id := run_cache.get_latest_ids(metadata_ids):
+        stats = get_latest_short_term_statistics_by_ids(
+            session, metadata_id_to_id.values()
+        )
+    # If we are missing some metadata_ids in the run cache, we need to run a query
+    # to populate the cache for each metadata_id, and then run another query
+    # to get the latest short term statistics for the missing metadata_ids.
+    if (missing_metadata_ids := metadata_ids - set(metadata_id_to_id)) and (
+        found_latest_ids := {
+            latest_id
+            for metadata_id in missing_metadata_ids
+            if (
+                latest_id := cache_latest_short_term_statistic_id_for_metadata_id(
+                    run_cache, session, metadata_id
+                )
+            )
+            is not None
+        }
+    ):
+        stats.extend(
+            get_latest_short_term_statistics_by_ids(session, found_latest_ids)
+        )
+
     if not stats:
         return {}
 
@@ -2221,9 +2304,77 @@ def _import_statistics_with_session(
         else:
             _insert_statistics(session, table, metadata_id, stat)
 
+    if table != StatisticsShortTerm:
+        return True
+
+    # We just inserted new short term statistics, so we need to update the
+    # ShortTermStatisticsRunCache with the latest id for the metadata_id
+    run_cache = get_short_term_statistics_run_cache(instance.hass)
+    cache_latest_short_term_statistic_id_for_metadata_id(
+        run_cache, session, metadata_id
+    )
+
     return True
 
 
+@singleton(DATA_SHORT_TERM_STATISTICS_RUN_CACHE)
+def get_short_term_statistics_run_cache(
+    hass: HomeAssistant,
+) -> ShortTermStatisticsRunCache:
+    """Get the short term statistics run cache."""
+    return ShortTermStatisticsRunCache()
+
+
+def cache_latest_short_term_statistic_id_for_metadata_id(
+    run_cache: ShortTermStatisticsRunCache, session: Session, metadata_id: int
+) -> int | None:
+    """Cache the latest short term statistic for a given metadata_id.
+
+    Returns the id of the latest short term statistic for the metadata_id
+    that was added to the cache, or None if no latest short term statistic
+    was found for the metadata_id.
+    """
+    if latest := cast(
+        Sequence[Row],
+        execute_stmt_lambda_element(
+            session,
+            _find_latest_short_term_statistic_for_metadata_id_stmt(metadata_id),
+            orm_rows=False,
+        ),
+    ):
+        id_: int = latest[0].id
+        run_cache.set_latest_id_for_metadata_id(metadata_id, id_)
+        return id_
+    return None
+
+
+def _find_latest_short_term_statistic_for_metadata_id_stmt(
+    metadata_id: int,
+) -> StatementLambdaElement:
+    """Create a statement to find the latest short term statistics for a metadata_id."""
+    #
+    # This query only looks up one row, and should not be refactored to
+    # look up multiple metadata_ids with func.max or similar, as that
+    # will make the query significantly slower on DBMSes such as
+    # PostgreSQL, which will have to do a full scan.
+    #
+    # For PostgreSQL a combined query plan looks like:
+    # (actual time=2.218..893.909 rows=170531 loops=1)
+    #
+    # For PostgreSQL a separate query plan looks like:
+    # (actual time=0.301..0.301 rows=1 loops=1)
+    #
+    return lambda_stmt(
+        lambda: select(
+            StatisticsShortTerm.id,
+        )
+        .where(StatisticsShortTerm.metadata_id == metadata_id)
+        .order_by(StatisticsShortTerm.start_ts.desc())
+        .limit(1)
+    )
+
+
 @retryable_database_job("statistics")
 def import_statistics(
     instance: Recorder,
diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py
index ab89b82d713..e56b2b83274 100644
--- a/tests/components/recorder/test_statistics.py
+++ b/tests/components/recorder/test_statistics.py
@@ -24,6 +24,7 @@ from homeassistant.components.recorder.statistics import (
     get_last_statistics,
     get_latest_short_term_statistics,
     get_metadata,
+    get_short_term_statistics_run_cache,
     list_statistic_ids,
 )
 from homeassistant.components.recorder.table_managers.statistics_meta import (
@@ -176,6 +177,15 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
     )
     assert stats == {"sensor.test1": [expected_2]}
 
+    # Now wipe the short term statistics run cache and test again
+    # to make sure we can rebuild the missing data
+    run_cache = get_short_term_statistics_run_cache(instance.hass)
+    run_cache._latest_id_by_metadata_id = {}
+    stats = get_latest_short_term_statistics(
+        hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
+    )
+    assert stats == {"sensor.test1": [expected_2]}
+
     metadata = get_metadata(hass, statistic_ids={"sensor.test1"})
 
     stats = get_latest_short_term_statistics(
@@ -220,6 +230,17 @@ def test_compile_hourly_statistics(hass_recorder: Callable[..., HomeAssistant])
     )
     assert stats == {}
 
+    # Delete again, and manually wipe the cache since we deleted all the data
+    instance.get_session().query(StatisticsShortTerm).delete()
+    run_cache = get_short_term_statistics_run_cache(instance.hass)
+    run_cache._latest_id_by_metadata_id = {}
+
+    # And test again to make sure there is no data
+    stats = get_latest_short_term_statistics(
+        hass, {"sensor.test1"}, {"last_reset", "max", "mean", "min", "state", "sum"}
+    )
+    assert stats == {}
+
 
 @pytest.fixture
 def mock_sensor_statistics():
diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py
index a9dc23ef5b3..38b657945f7 100644
--- a/tests/components/recorder/test_websocket_api.py
+++ b/tests/components/recorder/test_websocket_api.py
@@ -15,6 +15,7 @@ from homeassistant.components.recorder.statistics import (
     async_add_external_statistics,
     get_last_statistics,
     get_metadata,
+    get_short_term_statistics_run_cache,
     list_statistic_ids,
 )
 from homeassistant.components.recorder.websocket_api import UNIT_SCHEMA
@@ -302,6 +303,13 @@ async def test_statistic_during_period(
     )
     await async_wait_recording_done(hass)
 
+    metadata = get_metadata(hass, statistic_ids={"sensor.test"})
+    metadata_id = metadata["sensor.test"][0]
+    run_cache = get_short_term_statistics_run_cache(hass)
+    # Verify the import of the short term statistics
+    # also updates the run cache
+    assert run_cache.get_latest_ids({metadata_id})
+
     # No data for this period yet
     await client.send_json(
         {
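
Note on the flow this patch implements: serve the latest-row ids from the in-process run cache, and for any metadata_id missing from the cache fall back to one ORDER BY start_ts DESC LIMIT 1 query, repopulating the cache as you go. Below is a minimal standalone sketch of that cache-then-fallback flow; the in-memory TABLE dict and the helper names (latest_id_for_metadata_id, get_latest) are illustrative stand-ins, not part of the recorder API.

import dataclasses


@dataclasses.dataclass(slots=True)
class RunCache:
    """Maps each metadata_id to the row id of its newest short term statistic."""

    _latest_id_by_metadata_id: dict[int, int] = dataclasses.field(default_factory=dict)

    def get_latest_ids(self, metadata_ids: set[int]) -> dict[int, int]:
        """Return the cached latest row ids for the requested metadata_ids."""
        return {
            metadata_id: id_
            for metadata_id, id_ in self._latest_id_by_metadata_id.items()
            if metadata_id in metadata_ids
        }

    def set_latest_id_for_metadata_id(self, metadata_id: int, id_: int) -> None:
        """Cache the latest row id for the metadata_id."""
        self._latest_id_by_metadata_id[metadata_id] = id_


# Illustrative stand-in for the statistics_short_term table:
# row id -> (metadata_id, start_ts)
TABLE: dict[int, tuple[int, float]] = {
    1: (10, 100.0),
    2: (10, 200.0),  # newest row for metadata_id 10
    3: (11, 150.0),  # newest (and only) row for metadata_id 11
}


def latest_id_for_metadata_id(metadata_id: int) -> int | None:
    """Stand-in for the per-id ORDER BY start_ts DESC LIMIT 1 query."""
    rows = [
        (start_ts, id_)
        for id_, (mid, start_ts) in TABLE.items()
        if mid == metadata_id
    ]
    return max(rows)[1] if rows else None


def get_latest(cache: RunCache, metadata_ids: set[int]) -> dict[int, int]:
    """Cache-first lookup with a per-id fallback query, as in the patch."""
    latest = cache.get_latest_ids(metadata_ids)
    for metadata_id in metadata_ids - set(latest):
        # Cache miss: run the cheap single-row query and repopulate the cache
        if (id_ := latest_id_for_metadata_id(metadata_id)) is not None:
            cache.set_latest_id_for_metadata_id(metadata_id, id_)
            latest[metadata_id] = id_
    return latest


cache = RunCache()
# Cold cache: every id comes from the fallback query; metadata_id 12 has
# no rows, so it is simply absent from the result.
assert get_latest(cache, {10, 11, 12}) == {10: 2, 11: 3}
# Warm cache: both ids are now served without touching the table.
assert cache.get_latest_ids({10, 11}) == {10: 2, 11: 3}

Keeping the fallback as one LIMIT 1 lookup per missing metadata_id mirrors the comment in _find_latest_short_term_statistic_for_metadata_id_stmt: the combined func.max query forces PostgreSQL into a full scan (~894 ms in the quoted plan), while the separate per-id plan completes in ~0.3 ms.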