diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index f045af45a7a..0056a81fb60 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -157,6 +157,14 @@ DISPLAY_UNIT_TO_STATISTIC_UNIT_CONVERSIONS: dict[
 _LOGGER = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class PlatformCompiledStatistics:
+    """Compiled Statistics from a platform."""
+
+    platform_stats: list[StatisticResult]
+    current_metadata: dict[str, tuple[int, StatisticMetaData]]
+
+
 def split_statistic_id(entity_id: str) -> list[str]:
     """Split a state entity ID into domain and object ID."""
     return entity_id.split(":", 1)
@@ -550,28 +558,32 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
     _LOGGER.debug("Compiling statistics for %s-%s", start, end)
 
     platform_stats: list[StatisticResult] = []
+    current_metadata: dict[str, tuple[int, StatisticMetaData]] = {}
     # Collect statistics from all platforms implementing support
     for domain, platform in instance.hass.data[DOMAIN].items():
         if not hasattr(platform, "compile_statistics"):
             continue
-        platform_stat = platform.compile_statistics(instance.hass, start, end)
-        _LOGGER.debug(
-            "Statistics for %s during %s-%s: %s", domain, start, end, platform_stat
+        compiled: PlatformCompiledStatistics = platform.compile_statistics(
+            instance.hass, start, end
         )
-        platform_stats.extend(platform_stat)
+        _LOGGER.debug(
+            "Statistics for %s during %s-%s: %s",
+            domain,
+            start,
+            end,
+            compiled.platform_stats,
+        )
+        platform_stats.extend(compiled.platform_stats)
+        current_metadata.update(compiled.current_metadata)
 
     # Insert collected statistics in the database
     with session_scope(
         session=instance.get_session(),  # type: ignore[misc]
         exception_filter=_filter_unique_constraint_integrity_error(instance),
     ) as session:
-        statistic_ids = [stats["meta"]["statistic_id"] for stats in platform_stats]
-        old_metadata_dict = get_metadata_with_session(
-            instance.hass, session, statistic_ids=statistic_ids
-        )
         for stats in platform_stats:
             metadata_id = _update_or_add_metadata(
-                session, stats["meta"], old_metadata_dict
+                session, stats["meta"], current_metadata
             )
             _insert_statistics(
                 session,
@@ -1102,14 +1114,19 @@ def get_last_short_term_statistics(
 
 
 def get_latest_short_term_statistics(
-    hass: HomeAssistant, statistic_ids: list[str]
+    hass: HomeAssistant,
+    statistic_ids: list[str],
+    metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
 ) -> dict[str, list[dict]]:
     """Return the latest short term statistics for a list of statistic_ids."""
     # This function doesn't use a baked query, we instead rely on the
     # "Transparent SQL Compilation Caching" feature introduced in SQLAlchemy 1.4
     with session_scope(hass=hass) as session:
         # Fetch metadata for the given statistic_ids
-        metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids)
+        if not metadata:
+            metadata = get_metadata_with_session(
+                hass, session, statistic_ids=statistic_ids
+            )
         if not metadata:
             return {}
         metadata_ids = [
diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py
index 0d691b74974..3fc5cbec7ee 100644
--- a/homeassistant/components/sensor/recorder.py
+++ b/homeassistant/components/sensor/recorder.py
@@ -387,14 +387,14 @@ def _last_reset_as_utc_isoformat(last_reset_s: Any, entity_id: str) -> str | Non
 
 def compile_statistics(
     hass: HomeAssistant, start: datetime.datetime, end: datetime.datetime
-) -> list[StatisticResult]:
+) -> statistics.PlatformCompiledStatistics:
     """Compile statistics for all entities during start-end.
 
     Note: This will query the database and must not be run in the event loop
     """
     with recorder_util.session_scope(hass=hass) as session:
-        result = _compile_statistics(hass, session, start, end)
-        return result
+        compiled = _compile_statistics(hass, session, start, end)
+        return compiled
 
 
 def _compile_statistics(  # noqa: C901
@@ -402,7 +402,7 @@ def _compile_statistics(  # noqa: C901
     session: Session,
     start: datetime.datetime,
     end: datetime.datetime,
-) -> list[StatisticResult]:
+) -> statistics.PlatformCompiledStatistics:
     """Compile statistics for all entities during start-end."""
     result: list[StatisticResult] = []
 
@@ -473,7 +473,9 @@ def _compile_statistics(  # noqa: C901
         if "sum" in wanted_statistics[entity_id]:
             to_query.append(entity_id)
 
-    last_stats = statistics.get_latest_short_term_statistics(hass, to_query)
+    last_stats = statistics.get_latest_short_term_statistics(
+        hass, to_query, metadata=old_metadatas
+    )
     for (  # pylint: disable=too-many-nested-blocks
         entity_id,
         unit,
@@ -609,7 +611,7 @@ def _compile_statistics(  # noqa: C901
 
         result.append({"meta": meta, "stat": stat})
 
-    return result
+    return statistics.PlatformCompiledStatistics(result, old_metadatas)
 
 
 def list_statistic_ids(
diff --git a/tests/components/energy/test_sensor.py b/tests/components/energy/test_sensor.py
index fa350329e97..57074c2bfde 100644
--- a/tests/components/energy/test_sensor.py
+++ b/tests/components/energy/test_sensor.py
@@ -106,7 +106,7 @@ async def test_cost_sensor_price_entity_total_increasing(
     """Test energy cost price from total_increasing type sensor entity."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
@@ -311,7 +311,7 @@ async def test_cost_sensor_price_entity_total(
     """Test energy cost price from total type sensor entity."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
@@ -518,7 +518,7 @@ async def test_cost_sensor_price_entity_total_no_reset(
     """Test energy cost price from total type sensor entity with no last_reset."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py
index 51548e69ca5..ccdbcc4f74d 100644
--- a/tests/components/recorder/test_statistics.py
+++ b/tests/components/recorder/test_statistics.py
@@ -124,6 +124,11 @@ def test_compile_hourly_statistics(hass_recorder):
     stats = get_latest_short_term_statistics(hass, ["sensor.test1"])
     assert stats == {"sensor.test1": [{**expected_2, "statistic_id": "sensor.test1"}]}
 
+    metadata = get_metadata(hass, statistic_ids=["sensor.test1"])
+
+    stats = get_latest_short_term_statistics(hass, ["sensor.test1"], metadata=metadata)
+    assert stats == {"sensor.test1": [{**expected_2, "statistic_id": "sensor.test1"}]}
+
     stats = get_last_short_term_statistics(hass, 2, "sensor.test1", True)
     assert stats == {"sensor.test1": expected_stats1[::-1]}
 
@@ -156,11 +161,16 @@ def mock_sensor_statistics():
         }
 
     def get_fake_stats(_hass, start, _end):
-        return [
-            sensor_stats("sensor.test1", start),
-            sensor_stats("sensor.test2", start),
-            sensor_stats("sensor.test3", start),
-        ]
+        return statistics.PlatformCompiledStatistics(
+            [
+                sensor_stats("sensor.test1", start),
+                sensor_stats("sensor.test2", start),
+                sensor_stats("sensor.test3", start),
+            ],
+            get_metadata(
+                _hass, statistic_ids=["sensor.test1", "sensor.test2", "sensor.test3"]
+            ),
+        )
 
     with patch(
         "homeassistant.components.sensor.recorder.compile_statistics",
@@ -327,7 +337,8 @@ def test_statistics_duplicated(hass_recorder, caplog):
     assert "Statistics already compiled" not in caplog.text
 
     with patch(
-        "homeassistant.components.sensor.recorder.compile_statistics"
+        "homeassistant.components.sensor.recorder.compile_statistics",
+        return_value=statistics.PlatformCompiledStatistics([], {}),
     ) as compile_statistics:
         recorder.do_adhoc_statistics(start=zero)
         wait_recording_done(hass)
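
Taken together, the diff threads a single metadata lookup through the whole compile cycle: each platform's `compile_statistics` now bundles its results with the metadata it already fetched in a `PlatformCompiledStatistics`, the recorder merges those maps into `current_metadata`, and both `_update_or_add_metadata` and `get_latest_short_term_statistics` reuse the map instead of querying the metadata table again. Below is a minimal, self-contained sketch of that hand-off; `FakeMetadataStore`, `platform_compile_statistics`, and `recorder_compile_statistics` are invented stand-ins for illustration, not Home Assistant's real API.

```python
import dataclasses

# Toy stand-ins for the recorder's StatisticResult / StatisticMetaData TypedDicts.
StatisticResult = dict
StatisticMetaData = dict


@dataclasses.dataclass
class PlatformCompiledStatistics:
    """Compiled statistics from a platform (shape mirrors the diff)."""

    platform_stats: list[StatisticResult]
    current_metadata: dict[str, tuple[int, StatisticMetaData]]


class FakeMetadataStore:
    """Illustrative stand-in for the metadata table; counts queries."""

    def __init__(self) -> None:
        self.queries = 0
        self._rows = {"sensor.test1": (1, {"unit_of_measurement": "kWh"})}

    def get_metadata(self, statistic_ids):
        self.queries += 1
        return {sid: self._rows[sid] for sid in statistic_ids if sid in self._rows}


def platform_compile_statistics(store: FakeMetadataStore) -> PlatformCompiledStatistics:
    # The platform already needs the metadata while compiling (e.g. to look up
    # the latest short-term statistics), so it fetches the map once and returns
    # it alongside the stats instead of discarding it.
    metadata = store.get_metadata(["sensor.test1"])
    stats = [{"meta": {"statistic_id": "sensor.test1"}, "stat": {"sum": 2.0}}]
    return PlatformCompiledStatistics(stats, metadata)


def recorder_compile_statistics(store: FakeMetadataStore) -> None:
    compiled = platform_compile_statistics(store)
    # Before the change the recorder re-queried metadata at this point; now it
    # reuses compiled.current_metadata, so only one query hits the store.
    for stats in compiled.platform_stats:
        metadata_id, _meta = compiled.current_metadata[stats["meta"]["statistic_id"]]
        assert metadata_id == 1  # would be handed on to _insert_statistics


store = FakeMetadataStore()
recorder_compile_statistics(store)
assert store.queries == 1  # a single metadata round trip for the whole cycle
```

Note the backward-compatible design visible in the diff itself: `get_latest_short_term_statistics` falls back to `get_metadata_with_session` when no `metadata` argument is supplied, so passing the prefetched map is purely an optimization and existing call sites keep working unchanged.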