Avoid fetching metadata multiple times during stat compile (#70397)

commit 3737b58e85 (parent be0fbba523)
J. Nick Koston, 2022-04-22 00:25:42 -10:00, committed by GitHub
4 changed files with 56 additions and 26 deletions
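
The gist of the change: each platform's compile_statistics already queries statistic metadata while it works, so it now hands that metadata back to the recorder instead of the recorder re-querying the database for the same rows. Below is a minimal, self-contained sketch of that pattern, with toy types standing in for the recorder's real ones (illustration only, not the Home Assistant API):

import dataclasses
from typing import Any


@dataclasses.dataclass
class PlatformCompiledStatistics:
    """Stats plus the metadata the platform already looked up."""

    platform_stats: list[dict[str, Any]]
    current_metadata: dict[str, tuple[int, dict[str, Any]]]


def compile_all(platforms: list[Any]) -> PlatformCompiledStatistics:
    """Merge per-platform results without re-fetching metadata."""
    platform_stats: list[dict[str, Any]] = []
    current_metadata: dict[str, tuple[int, dict[str, Any]]] = {}
    for platform in platforms:
        compiled = platform.compile_statistics()
        platform_stats.extend(compiled.platform_stats)
        # Reuse the metadata the platform loaded during its own compile.
        current_metadata.update(compiled.current_metadata)
    return PlatformCompiledStatistics(platform_stats, current_metadata)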

View File

@@ -157,6 +157,14 @@ DISPLAY_UNIT_TO_STATISTIC_UNIT_CONVERSIONS: dict[
 _LOGGER = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class PlatformCompiledStatistics:
+    """Compiled Statistics from a platform."""
+
+    platform_stats: list[StatisticResult]
+    current_metadata: dict[str, tuple[int, StatisticMetaData]]
+
+
 def split_statistic_id(entity_id: str) -> list[str]:
     """Split a state entity ID into domain and object ID."""
     return entity_id.split(":", 1)
@@ -550,28 +558,32 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
     _LOGGER.debug("Compiling statistics for %s-%s", start, end)
     platform_stats: list[StatisticResult] = []
+    current_metadata: dict[str, tuple[int, StatisticMetaData]] = {}
     # Collect statistics from all platforms implementing support
     for domain, platform in instance.hass.data[DOMAIN].items():
         if not hasattr(platform, "compile_statistics"):
             continue
-        platform_stat = platform.compile_statistics(instance.hass, start, end)
-        _LOGGER.debug(
-            "Statistics for %s during %s-%s: %s", domain, start, end, platform_stat
+        compiled: PlatformCompiledStatistics = platform.compile_statistics(
+            instance.hass, start, end
         )
-        platform_stats.extend(platform_stat)
+        _LOGGER.debug(
+            "Statistics for %s during %s-%s: %s",
+            domain,
+            start,
+            end,
+            compiled.platform_stats,
+        )
+        platform_stats.extend(compiled.platform_stats)
+        current_metadata.update(compiled.current_metadata)
 
     # Insert collected statistics in the database
     with session_scope(
         session=instance.get_session(),  # type: ignore[misc]
         exception_filter=_filter_unique_constraint_integrity_error(instance),
     ) as session:
-        statistic_ids = [stats["meta"]["statistic_id"] for stats in platform_stats]
-        old_metadata_dict = get_metadata_with_session(
-            instance.hass, session, statistic_ids=statistic_ids
-        )
         for stats in platform_stats:
             metadata_id = _update_or_add_metadata(
-                session, stats["meta"], old_metadata_dict
+                session, stats["meta"], current_metadata
             )
             _insert_statistics(
                 session,
@@ -1102,14 +1114,19 @@ def get_last_short_term_statistics(
 def get_latest_short_term_statistics(
-    hass: HomeAssistant, statistic_ids: list[str]
+    hass: HomeAssistant,
+    statistic_ids: list[str],
+    metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
 ) -> dict[str, list[dict]]:
     """Return the latest short term statistics for a list of statistic_ids."""
     # This function doesn't use a baked query, we instead rely on the
     # "Transparent SQL Compilation Caching" feature introduced in SQLAlchemy 1.4
     with session_scope(hass=hass) as session:
         # Fetch metadata for the given statistic_ids
-        metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids)
+        if not metadata:
+            metadata = get_metadata_with_session(
+                hass, session, statistic_ids=statistic_ids
+            )
         if not metadata:
             return {}
         metadata_ids = [
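
The optional metadata parameter above follows a lazy-fetch default: callers that already hold the metadata pass it in and skip the query; everyone else gets the old behavior. A sketch of that shape against a toy in-memory "database" (illustrative names, not the real recorder helpers):

from typing import Any

FAKE_DB: dict[str, tuple[int, dict[str, Any]]] = {
    "sensor.test1": (1, {"unit_of_measurement": "kWh"}),
}


def fetch_metadata(statistic_ids: list[str]) -> dict[str, tuple[int, dict[str, Any]]]:
    """Stand-in for get_metadata_with_session(): one database round trip."""
    return {sid: FAKE_DB[sid] for sid in statistic_ids if sid in FAKE_DB}


def get_latest(
    statistic_ids: list[str],
    metadata: dict[str, tuple[int, dict[str, Any]]] | None = None,
) -> dict[str, int]:
    """Stand-in for get_latest_short_term_statistics()."""
    if not metadata:  # only query when the caller did not supply metadata
        metadata = fetch_metadata(statistic_ids)
    if not metadata:
        return {}
    return {sid: meta_id for sid, (meta_id, _meta) in metadata.items()}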

View File

@@ -387,14 +387,14 @@ def _last_reset_as_utc_isoformat(last_reset_s: Any, entity_id: str) -> str | None
 def compile_statistics(
     hass: HomeAssistant, start: datetime.datetime, end: datetime.datetime
-) -> list[StatisticResult]:
+) -> statistics.PlatformCompiledStatistics:
     """Compile statistics for all entities during start-end.
 
     Note: This will query the database and must not be run in the event loop
     """
     with recorder_util.session_scope(hass=hass) as session:
-        result = _compile_statistics(hass, session, start, end)
-    return result
+        compiled = _compile_statistics(hass, session, start, end)
+    return compiled
 
 
 def _compile_statistics(  # noqa: C901
@@ -402,7 +402,7 @@ def _compile_statistics(  # noqa: C901
     session: Session,
     start: datetime.datetime,
     end: datetime.datetime,
-) -> list[StatisticResult]:
+) -> statistics.PlatformCompiledStatistics:
     """Compile statistics for all entities during start-end."""
     result: list[StatisticResult] = []
@@ -473,7 +473,9 @@ def _compile_statistics(  # noqa: C901
         if "sum" in wanted_statistics[entity_id]:
             to_query.append(entity_id)
 
-    last_stats = statistics.get_latest_short_term_statistics(hass, to_query)
+    last_stats = statistics.get_latest_short_term_statistics(
+        hass, to_query, metadata=old_metadatas
+    )
     for (  # pylint: disable=too-many-nested-blocks
         entity_id,
         unit,
@@ -609,7 +611,7 @@ def _compile_statistics(  # noqa: C901
         result.append({"meta": meta, "stat": stat})
 
-    return result
+    return statistics.PlatformCompiledStatistics(result, old_metadatas)
 
 
 def list_statistic_ids(

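Putting the sensor-platform side together: _compile_statistics fetches old_metadatas once, reuses it for the latest-stats lookup, and returns it to the recorder inside the container. An illustrative call flow, reusing the toy helpers from the sketches above rather than the real module:

def compile_platform_stats(entity_ids: list[str]) -> PlatformCompiledStatistics:
    """Compile stats with a single up-front metadata fetch."""
    old_metadatas = fetch_metadata(entity_ids)  # fetched exactly once
    # Reused here instead of a second query inside get_latest().
    last_stats = get_latest(entity_ids, metadata=old_metadatas)
    result = [
        {"meta": meta, "stat": {"last": last_stats.get(sid)}}
        for sid, (_meta_id, meta) in old_metadatas.items()
    ]
    # Handed back so the recorder's compile step can reuse it as well.
    return PlatformCompiledStatistics(result, old_metadatas)
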
View File

@@ -106,7 +106,7 @@ async def test_cost_sensor_price_entity_total_increasing(
     """Test energy cost price from total_increasing type sensor entity."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
@@ -311,7 +311,7 @@ async def test_cost_sensor_price_entity_total(
     """Test energy cost price from total type sensor entity."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
@@ -518,7 +518,7 @@ async def test_cost_sensor_price_entity_total_no_reset(
     """Test energy cost price from total type sensor entity with no last_reset."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
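
These test helpers change because compile_statistics now returns the container rather than a bare list, so anything that asserted on the list must unwrap .platform_stats. The same adaptation against the toy sketch above:

def _compile_statistics_list(entity_ids: list[str]) -> list[dict]:
    # Unwrap the container so callers expecting a plain list keep working.
    return compile_platform_stats(entity_ids).platform_stats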

View File

@@ -124,6 +124,11 @@ def test_compile_hourly_statistics(hass_recorder):
     stats = get_latest_short_term_statistics(hass, ["sensor.test1"])
     assert stats == {"sensor.test1": [{**expected_2, "statistic_id": "sensor.test1"}]}
 
+    metadata = get_metadata(hass, statistic_ids=["sensor.test1"])
+
+    stats = get_latest_short_term_statistics(hass, ["sensor.test1"], metadata=metadata)
+    assert stats == {"sensor.test1": [{**expected_2, "statistic_id": "sensor.test1"}]}
+
     stats = get_last_short_term_statistics(hass, 2, "sensor.test1", True)
     assert stats == {"sensor.test1": expected_stats1[::-1]}
@@ -156,11 +161,16 @@ def mock_sensor_statistics():
         }
 
     def get_fake_stats(_hass, start, _end):
-        return [
-            sensor_stats("sensor.test1", start),
-            sensor_stats("sensor.test2", start),
-            sensor_stats("sensor.test3", start),
-        ]
+        return statistics.PlatformCompiledStatistics(
+            [
+                sensor_stats("sensor.test1", start),
+                sensor_stats("sensor.test2", start),
+                sensor_stats("sensor.test3", start),
+            ],
+            get_metadata(
+                _hass, statistic_ids=["sensor.test1", "sensor.test2", "sensor.test3"]
+            ),
+        )
 
     with patch(
         "homeassistant.components.sensor.recorder.compile_statistics",
@@ -327,7 +337,8 @@ def test_statistics_duplicated(hass_recorder, caplog):
     assert "Statistics already compiled" not in caplog.text
 
     with patch(
-        "homeassistant.components.sensor.recorder.compile_statistics"
+        "homeassistant.components.sensor.recorder.compile_statistics",
+        return_value=statistics.PlatformCompiledStatistics([], {}),
     ) as compile_statistics:
         recorder.do_adhoc_statistics(start=zero)
         wait_recording_done(hass)
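
The patched return value in test_statistics_duplicated changes for the same reason: the recorder now reads .platform_stats and .current_metadata off whatever compile_statistics returns, so a bare MagicMock no longer satisfies the contract. A toy demonstration with the sketch types from above:

from unittest.mock import MagicMock

platform = MagicMock()
platform.compile_statistics.return_value = PlatformCompiledStatistics([], {})
merged = compile_all([platform])
assert merged.platform_stats == [] and merged.current_metadata == {}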