Add typing to statistics results (#89118)

This commit is contained in:
J. Nick Koston 2023-03-14 09:06:56 -10:00 committed by GitHub
parent 9d2c62095f
commit a6d6807dd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 42 deletions

View File

@ -13,6 +13,7 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.components import recorder, websocket_api
from homeassistant.components.recorder.statistics import StatisticsRow
from homeassistant.const import UnitOfEnergy
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.integration_platform import (
@ -277,7 +278,7 @@ async def ws_get_fossil_energy_consumption(
)
def _combine_sum_statistics(
stats: dict[str, list[dict[str, Any]]], statistic_ids: list[str]
stats: dict[str, list[StatisticsRow]], statistic_ids: list[str]
) -> dict[float, float]:
"""Combine multiple statistics, returns a dict indexed by start time."""
result: defaultdict[float, float] = defaultdict(float)
@ -313,11 +314,10 @@ async def ws_get_fossil_energy_consumption(
if not stat_list:
return result
prev_stat: dict[str, Any] = stat_list[0]
fake_stat = {"start": stat_list[-1]["start"] + period.total_seconds()}
# Loop over the hourly deltas + a fake entry to end the period
for statistic in chain(
stat_list, ({"start": stat_list[-1]["start"] + period.total_seconds()},)
):
for statistic in chain(stat_list, (fake_stat,)):
if not same_period(prev_stat["start"], statistic["start"]):
start, _ = period_start_end(prev_stat["start"])
# The previous statistic was the last entry of the period
@ -338,10 +338,13 @@ async def ws_get_fossil_energy_consumption(
statistics, msg["energy_statistic_ids"]
)
energy_deltas = _calculate_deltas(merged_energy_statistics)
indexed_co2_statistics = {
period["start"]: period["mean"]
for period in statistics.get(msg["co2_statistic_id"], {})
}
indexed_co2_statistics = cast(
dict[float, float],
{
period["start"]: period["mean"]
for period in statistics.get(msg["co2_statistic_id"], {})
},
)
# Calculate amount of fossil based energy, assume 100% fossil if missing
fossil_energy = [

View File

@ -14,7 +14,7 @@ from operator import itemgetter
import os
import re
from statistics import mean
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text
from sqlalchemy.engine import Engine
@ -166,6 +166,24 @@ STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = {
_LOGGER = logging.getLogger(__name__)
class BaseStatisticsRow(TypedDict, total=False):
"""A processed row of statistic data."""
start: float
class StatisticsRow(BaseStatisticsRow, total=False):
"""A processed row of statistic data."""
end: float
last_reset: float | None
state: float | None
sum: float | None
min: float | None
max: float | None
mean: float | None
def _get_unit_class(unit: str | None) -> str | None:
"""Get corresponding unit class from from the statistics unit."""
if converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(unit):
@ -1048,14 +1066,14 @@ def list_statistic_ids(
def _reduce_statistics(
stats: dict[str, list[dict[str, Any]]],
stats: dict[str, list[StatisticsRow]],
same_period: Callable[[float, float], bool],
period_start_end: Callable[[float], tuple[float, float]],
period: timedelta,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]:
) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to daily or monthly statistics."""
result: dict[str, list[dict[str, Any]]] = defaultdict(list)
result: dict[str, list[StatisticsRow]] = defaultdict(list)
period_seconds = period.total_seconds()
_want_mean = "mean" in types
_want_min = "min" in types
@ -1067,16 +1085,15 @@ def _reduce_statistics(
max_values: list[float] = []
mean_values: list[float] = []
min_values: list[float] = []
prev_stat: dict[str, Any] = stat_list[0]
prev_stat: StatisticsRow = stat_list[0]
fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds}
# Loop over the hourly statistics + a fake entry to end the period
for statistic in chain(
stat_list, ({"start": stat_list[-1]["start"] + period_seconds},)
):
for statistic in chain(stat_list, (fake_entry,)):
if not same_period(prev_stat["start"], statistic["start"]):
start, end = period_start_end(prev_stat["start"])
# The previous statistic was the last entry of the period
row: dict[str, Any] = {
row: StatisticsRow = {
"start": start,
"end": end,
}
@ -1146,9 +1163,9 @@ def reduce_day_ts_factory() -> (
def _reduce_statistics_per_day(
stats: dict[str, list[dict[str, Any]]],
stats: dict[str, list[StatisticsRow]],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]:
) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to daily statistics."""
_same_day_ts, _day_start_end_ts = reduce_day_ts_factory()
return _reduce_statistics(
@ -1196,9 +1213,9 @@ def reduce_week_ts_factory() -> (
def _reduce_statistics_per_week(
stats: dict[str, list[dict[str, Any]]],
stats: dict[str, list[StatisticsRow]],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]:
) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to weekly statistics."""
_same_week_ts, _week_start_end_ts = reduce_week_ts_factory()
return _reduce_statistics(
@ -1248,9 +1265,9 @@ def reduce_month_ts_factory() -> (
def _reduce_statistics_per_month(
stats: dict[str, list[dict[str, Any]]],
stats: dict[str, list[StatisticsRow]],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]:
) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to monthly statistics."""
_same_month_ts, _month_start_end_ts = reduce_month_ts_factory()
return _reduce_statistics(
@ -1724,7 +1741,7 @@ def _statistics_during_period_with_session(
period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]:
) -> dict[str, list[StatisticsRow]]:
"""Return statistic data points during UTC period start_time - end_time.
If end_time is omitted, returns statistics newer than or equal to start_time.
@ -1808,7 +1825,7 @@ def statistics_during_period(
period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]:
) -> dict[str, list[StatisticsRow]]:
"""Return statistic data points during UTC period start_time - end_time.
If end_time is omitted, returns statistics newer than or equal to start_time.
@ -1863,7 +1880,7 @@ def _get_last_statistics(
convert_units: bool,
table: type[StatisticsBase],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]:
) -> dict[str, list[StatisticsRow]]:
"""Return the last number_of_stats statistics for a given statistic_id."""
statistic_ids = [statistic_id]
with session_scope(hass=hass, read_only=True) as session:
@ -1902,7 +1919,7 @@ def get_last_statistics(
statistic_id: str,
convert_units: bool,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]:
) -> dict[str, list[StatisticsRow]]:
"""Return the last number_of_stats statistics for a statistic_id."""
return _get_last_statistics(
hass, number_of_stats, statistic_id, convert_units, Statistics, types
@ -1915,7 +1932,7 @@ def get_last_short_term_statistics(
statistic_id: str,
convert_units: bool,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]:
) -> dict[str, list[StatisticsRow]]:
"""Return the last number_of_stats short term statistics for a statistic_id."""
return _get_last_statistics(
hass, number_of_stats, statistic_id, convert_units, StatisticsShortTerm, types
@ -1951,7 +1968,7 @@ def get_latest_short_term_statistics(
statistic_ids: list[str],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
) -> dict[str, list[dict]]:
) -> dict[str, list[StatisticsRow]]:
"""Return the latest short term statistics for a list of statistic_ids."""
with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_ids
@ -2054,10 +2071,10 @@ def _sorted_statistics_to_dict(
start_time: datetime | None,
units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]:
) -> dict[str, list[StatisticsRow]]:
"""Convert SQL results into JSON friendly data structure."""
assert stats, "stats must not be empty" # Guard against implementation error
result: dict = defaultdict(list)
result: dict[str, list[StatisticsRow]] = defaultdict(list)
metadata = dict(_metadata.values())
need_stat_at_start_time: set[int] = set()
start_time_ts = start_time.timestamp() if start_time else None
@ -2123,7 +2140,7 @@ def _sorted_statistics_to_dict(
# attribute lookups, and dict lookups as much as possible.
#
for db_state in stats_list:
row: dict[str, Any] = {
row: StatisticsRow = {
"start": (start_ts := db_state[start_ts_idx]),
"end": start_ts + table_duration_seconds,
}

View File

@ -529,11 +529,11 @@ def _compile_statistics( # noqa: C901
if entity_id in last_stats:
# We have compiled history for this sensor before,
# use that as a starting point.
last_reset = old_last_reset = _timestamp_to_isoformat_or_none(
last_stats[entity_id][0]["last_reset"]
)
new_state = old_state = last_stats[entity_id][0]["state"]
_sum = last_stats[entity_id][0]["sum"] or 0.0
last_stat = last_stats[entity_id][0]
last_reset = _timestamp_to_isoformat_or_none(last_stat["last_reset"])
old_last_reset = last_reset
new_state = old_state = last_stat["state"]
_sum = last_stat["sum"] or 0.0
for fstate, state in fstates:
reset = False
@ -596,7 +596,7 @@ def _compile_statistics( # noqa: C901
if reset:
# The sensor has been reset, update the sum
if old_state is not None:
if old_state is not None and new_state is not None:
_sum += new_state - old_state
# ..and update the starting point
new_state = fstate

View File

@ -6,7 +6,7 @@ import datetime
from datetime import timedelta
import logging
from random import randrange
from typing import Any
from typing import Any, cast
import aiohttp
import tibber
@ -614,7 +614,7 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]):
5 * 365 * 24, production=is_production
)
_sum = 0
_sum = 0.0
last_stats_time = None
else:
# hourly_consumption/production_data contains the last 30 days
@ -641,8 +641,9 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]):
None,
{"sum"},
)
_sum = stat[statistic_id][0]["sum"]
last_stats_time = stat[statistic_id][0]["start"]
first_stat = stat[statistic_id][0]
_sum = cast(float, first_stat["sum"])
last_stats_time = first_stat["start"]
statistics = []