Fix circular mean by always storing and using the weighted one (#142208)

* Fix circular mean by always storing and using the weighted one

* fix

* Fix test
This commit is contained in:
Robert Resch 2025-04-04 21:19:15 +02:00 committed by Franck Nijhof
parent 86eee4f041
commit e9abdab1f5
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
3 changed files with 76 additions and 60 deletions

View File

@ -139,14 +139,13 @@ def query_circular_mean(table: type[StatisticsBase]) -> tuple[Label, Label]:
# in Python. # in Python.
# https://en.wikipedia.org/wiki/Circular_mean # https://en.wikipedia.org/wiki/Circular_mean
radians = func.radians(table.mean) radians = func.radians(table.mean)
weighted_sum_sin = func.sum(func.sin(radians) * table.mean_weight)
weighted_sum_cos = func.sum(func.cos(radians) * table.mean_weight)
weight = func.sqrt( weight = func.sqrt(
func.power(func.sum(func.sin(radians) * table.mean_weight), 2) func.power(weighted_sum_sin, 2) + func.power(weighted_sum_cos, 2)
+ func.power(func.sum(func.cos(radians) * table.mean_weight), 2)
) )
return ( return (
func.degrees( func.degrees(func.atan2(weighted_sum_sin, weighted_sum_cos)).label("mean"),
func.atan2(func.sum(func.sin(radians)), func.sum(func.cos(radians)))
).label("mean"),
weight.label("mean_weight"), weight.label("mean_weight"),
) )
@ -240,18 +239,20 @@ DEG_TO_RAD = math.pi / 180
RAD_TO_DEG = 180 / math.pi RAD_TO_DEG = 180 / math.pi
def weighted_circular_mean(values: Iterable[tuple[float, float]]) -> float: def weighted_circular_mean(
"""Return the weighted circular mean of the values.""" values: Iterable[tuple[float, float]],
sin_sum = sum(math.sin(x * DEG_TO_RAD) * weight for x, weight in values) ) -> tuple[float, float]:
cos_sum = sum(math.cos(x * DEG_TO_RAD) * weight for x, weight in values) """Return the weighted circular mean and the weight of the values."""
return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360 weighted_sin_sum, weighted_cos_sum = 0.0, 0.0
for x, weight in values:
rad_x = x * DEG_TO_RAD
weighted_sin_sum += math.sin(rad_x) * weight
weighted_cos_sum += math.cos(rad_x) * weight
return (
def circular_mean(values: list[float]) -> float: (RAD_TO_DEG * math.atan2(weighted_sin_sum, weighted_cos_sum)) % 360,
"""Return the circular mean of the values.""" math.sqrt(weighted_sin_sum**2 + weighted_cos_sum**2),
sin_sum = sum(math.sin(x * DEG_TO_RAD) for x in values) )
cos_sum = sum(math.cos(x * DEG_TO_RAD) for x in values)
return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -300,6 +301,7 @@ class StatisticsRow(BaseStatisticsRow, total=False):
min: float | None min: float | None
max: float | None max: float | None
mean: float | None mean: float | None
mean_weight: float | None
change: float | None change: float | None
@ -1023,7 +1025,7 @@ def _reduce_statistics(
_want_sum = "sum" in types _want_sum = "sum" in types
for statistic_id, stat_list in stats.items(): for statistic_id, stat_list in stats.items():
max_values: list[float] = [] max_values: list[float] = []
mean_values: list[float] = [] mean_values: list[tuple[float, float]] = []
min_values: list[float] = [] min_values: list[float] = []
prev_stat: StatisticsRow = stat_list[0] prev_stat: StatisticsRow = stat_list[0]
fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds} fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds}
@ -1039,12 +1041,15 @@ def _reduce_statistics(
} }
if _want_mean: if _want_mean:
row["mean"] = None row["mean"] = None
row["mean_weight"] = None
if mean_values: if mean_values:
match metadata[statistic_id][1]["mean_type"]: match metadata[statistic_id][1]["mean_type"]:
case StatisticMeanType.ARITHMETIC: case StatisticMeanType.ARITHMETIC:
row["mean"] = mean(mean_values) row["mean"] = mean([x[0] for x in mean_values])
case StatisticMeanType.CIRCULAR: case StatisticMeanType.CIRCULAR:
row["mean"] = circular_mean(mean_values) row["mean"], row["mean_weight"] = (
weighted_circular_mean(mean_values)
)
mean_values.clear() mean_values.clear()
if _want_min: if _want_min:
row["min"] = min(min_values) if min_values else None row["min"] = min(min_values) if min_values else None
@ -1063,7 +1068,8 @@ def _reduce_statistics(
max_values.append(_max) max_values.append(_max)
if _want_mean: if _want_mean:
if (_mean := statistic.get("mean")) is not None: if (_mean := statistic.get("mean")) is not None:
mean_values.append(_mean) _mean_weight = statistic.get("mean_weight") or 0.0
mean_values.append((_mean, _mean_weight))
if _want_min and (_min := statistic.get("min")) is not None: if _want_min and (_min := statistic.get("min")) is not None:
min_values.append(_min) min_values.append(_min)
prev_stat = statistic prev_stat = statistic
@ -1385,7 +1391,7 @@ def _get_max_mean_min_statistic(
match metadata[1]["mean_type"]: match metadata[1]["mean_type"]:
case StatisticMeanType.CIRCULAR: case StatisticMeanType.CIRCULAR:
if circular_means := max_mean_min["circular_means"]: if circular_means := max_mean_min["circular_means"]:
mean_value = weighted_circular_mean(circular_means) mean_value = weighted_circular_mean(circular_means)[0]
case StatisticMeanType.ARITHMETIC: case StatisticMeanType.ARITHMETIC:
if (mean_value := max_mean_min.get("mean_acc")) is not None and ( if (mean_value := max_mean_min.get("mean_acc")) is not None and (
duration := max_mean_min.get("duration") duration := max_mean_min.get("duration")
@ -1739,12 +1745,12 @@ def statistic_during_period(
_type_column_mapping = { _type_column_mapping = {
"last_reset": "last_reset_ts", "last_reset": ("last_reset_ts",),
"max": "max", "max": ("max",),
"mean": "mean", "mean": ("mean", "mean_weight"),
"min": "min", "min": ("min",),
"state": "state", "state": ("state",),
"sum": "sum", "sum": ("sum",),
} }
@ -1756,7 +1762,8 @@ def _generate_select_columns_for_types_stmt(
track_on: list[str | None] = [ track_on: list[str | None] = [
table.__tablename__, # type: ignore[attr-defined] table.__tablename__, # type: ignore[attr-defined]
] ]
for key, column in _type_column_mapping.items(): for key, type_columns in _type_column_mapping.items():
for column in type_columns:
if key in types: if key in types:
columns = columns.add_columns(getattr(table, column)) columns = columns.add_columns(getattr(table, column))
track_on.append(column) track_on.append(column)
@ -1944,6 +1951,12 @@ def _statistics_during_period_with_session(
hass, session, start_time, units, _types, table, metadata, result hass, session, start_time, units, _types, table, metadata, result
) )
# filter out mean_weight as it is only needed to reduce statistics
# and not needed in the result
for stats_rows in result.values():
for row in stats_rows:
row.pop("mean_weight", None)
# Return statistics combined with metadata # Return statistics combined with metadata
return result return result
@ -2391,7 +2404,12 @@ def _sorted_statistics_to_dict(
field_map["last_reset"] = field_map.pop("last_reset_ts") field_map["last_reset"] = field_map.pop("last_reset_ts")
sum_idx = field_map["sum"] if "sum" in types else None sum_idx = field_map["sum"] if "sum" in types else None
sum_only = len(types) == 1 and sum_idx is not None sum_only = len(types) == 1 and sum_idx is not None
row_mapping = tuple((key, field_map[key]) for key in types if key in field_map) row_mapping = tuple(
(column, field_map[column])
for key in types
for column in ({key, *_type_column_mapping.get(key, ())})
if column in field_map
)
# Append all statistic entries, and optionally do unit conversion # Append all statistic entries, and optionally do unit conversion
table_duration_seconds = table.duration.total_seconds() table_duration_seconds = table.duration.total_seconds()
for meta_id, db_rows in stats_by_meta_id.items(): for meta_id, db_rows in stats_by_meta_id.items():

View File

@ -160,7 +160,7 @@ def _time_weighted_arithmetic_mean(
def _time_weighted_circular_mean( def _time_weighted_circular_mean(
fstates: list[tuple[float, State]], start: datetime.datetime, end: datetime.datetime fstates: list[tuple[float, State]], start: datetime.datetime, end: datetime.datetime
) -> float: ) -> tuple[float, float]:
"""Calculate a time weighted circular mean. """Calculate a time weighted circular mean.
The circular mean is calculated by weighting the states by duration in seconds between The circular mean is calculated by weighting the states by duration in seconds between
@ -623,7 +623,7 @@ def compile_statistics( # noqa: C901
valid_float_states, start, end valid_float_states, start, end
) )
case StatisticMeanType.CIRCULAR: case StatisticMeanType.CIRCULAR:
stat["mean"] = _time_weighted_circular_mean( stat["mean"], stat["mean_weight"] = _time_weighted_circular_mean(
valid_float_states, start, end valid_float_states, start, end
) )

View File

@ -4508,23 +4508,19 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
duration += dur duration += dur
return total / duration return total / duration
def _time_weighted_circular_mean(values: list[tuple[float, int]]): def _weighted_circular_mean(
values: Iterable[tuple[float, float]],
) -> tuple[float, float]:
sin_sum = 0 sin_sum = 0
cos_sum = 0 cos_sum = 0
for x, dur in values: for x, weight in values:
sin_sum += math.sin(x * DEG_TO_RAD) * dur sin_sum += math.sin(x * DEG_TO_RAD) * weight
cos_sum += math.cos(x * DEG_TO_RAD) * dur cos_sum += math.cos(x * DEG_TO_RAD) * weight
return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360 return (
(RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360,
def _circular_mean(values: list[float]) -> float: math.sqrt(sin_sum**2 + cos_sum**2),
sin_sum = 0 )
cos_sum = 0
for x in values:
sin_sum += math.sin(x * DEG_TO_RAD)
cos_sum += math.cos(x * DEG_TO_RAD)
return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
def _min(seq, last_state): def _min(seq, last_state):
if last_state is None: if last_state is None:
@ -4631,7 +4627,7 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
values = [(seq, durations[j]) for j, seq in enumerate(seq)] values = [(seq, durations[j]) for j, seq in enumerate(seq)]
if (state := last_states["sensor.test5"]) is not None: if (state := last_states["sensor.test5"]) is not None:
values.append((state, 5)) values.append((state, 5))
expected_means["sensor.test5"].append(_time_weighted_circular_mean(values)) expected_means["sensor.test5"].append(_weighted_circular_mean(values))
last_states["sensor.test5"] = seq[-1] last_states["sensor.test5"] = seq[-1]
start += timedelta(minutes=5) start += timedelta(minutes=5)
@ -4733,15 +4729,17 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
start = zero start = zero
end = zero + timedelta(minutes=5) end = zero + timedelta(minutes=5)
for i in range(24): for i in range(24):
for entity_id in ( for entity_id, mean_extractor in (
"sensor.test1", ("sensor.test1", lambda x: x),
"sensor.test2", ("sensor.test2", lambda x: x),
"sensor.test3", ("sensor.test3", lambda x: x),
"sensor.test4", ("sensor.test4", lambda x: x),
"sensor.test5", ("sensor.test5", lambda x: x[0]),
): ):
expected_average = ( expected_average = (
expected_means[entity_id][i] if entity_id in expected_means else None mean_extractor(expected_means[entity_id][i])
if entity_id in expected_means
else None
) )
expected_minimum = ( expected_minimum = (
expected_minima[entity_id][i] if entity_id in expected_minima else None expected_minima[entity_id][i] if entity_id in expected_minima else None
@ -4772,7 +4770,7 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
assert stats == expected_stats assert stats == expected_stats
def verify_stats( def verify_stats(
period: Literal["5minute", "day", "hour", "week", "month"], period: Literal["hour", "day", "week", "month"],
start: datetime, start: datetime,
next_datetime: Callable[[datetime], datetime], next_datetime: Callable[[datetime], datetime],
) -> None: ) -> None:
@ -4791,7 +4789,7 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
("sensor.test2", mean), ("sensor.test2", mean),
("sensor.test3", mean), ("sensor.test3", mean),
("sensor.test4", mean), ("sensor.test4", mean),
("sensor.test5", _circular_mean), ("sensor.test5", lambda x: _weighted_circular_mean(x)[0]),
): ):
expected_average = ( expected_average = (
mean_fn(expected_means[entity_id][i * 12 : (i + 1) * 12]) mean_fn(expected_means[entity_id][i * 12 : (i + 1) * 12])