Fix circular mean by always storing and using the weighted one (#142208)

* Fix circular mean by always storing and using the weighted one

* fix

* Fix test
Robert Resch 2025-04-04 21:19:15 +02:00 committed by Franck Nijhof
parent 86eee4f041
commit e9abdab1f5
3 changed files with 76 additions and 60 deletions
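For context on the change: the circular mean of a set of angles is the direction of the summed unit vectors, and the length of that summed vector is the natural weight of the result. The fix stores this weight (mean_weight) next to every circular mean so that short-term means can later be reduced into long-term means without re-reading the raw samples. A minimal sketch of the idea in plain Python, independent of the Home Assistant code below:

    import math

    def circular_mean_and_weight(samples: list[tuple[float, float]]) -> tuple[float, float]:
        """Return (mean in degrees, weight) for (degrees, weight) samples."""
        s = sum(math.sin(math.radians(deg)) * w for deg, w in samples)
        c = sum(math.cos(math.radians(deg)) * w for deg, w in samples)
        # The resultant length is the weight; weight * (sin(mean), cos(mean)) == (s, c),
        # which is why storing it allows exact re-aggregation later.
        return math.degrees(math.atan2(s, c)) % 360, math.sqrt(s**2 + c**2)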

View File

@@ -139,14 +139,13 @@ def query_circular_mean(table: type[StatisticsBase]) -> tuple[Label, Label]:
     # in Python.
     # https://en.wikipedia.org/wiki/Circular_mean
     radians = func.radians(table.mean)
+    weighted_sum_sin = func.sum(func.sin(radians) * table.mean_weight)
+    weighted_sum_cos = func.sum(func.cos(radians) * table.mean_weight)
     weight = func.sqrt(
-        func.power(func.sum(func.sin(radians) * table.mean_weight), 2)
-        + func.power(func.sum(func.cos(radians) * table.mean_weight), 2)
+        func.power(weighted_sum_sin, 2) + func.power(weighted_sum_cos, 2)
     )
     return (
-        func.degrees(
-            func.atan2(func.sum(func.sin(radians)), func.sum(func.cos(radians)))
-        ).label("mean"),
+        func.degrees(func.atan2(weighted_sum_sin, weighted_sum_cos)).label("mean"),
         weight.label("mean_weight"),
     )
@@ -240,18 +239,20 @@ DEG_TO_RAD = math.pi / 180
 RAD_TO_DEG = 180 / math.pi
-def weighted_circular_mean(values: Iterable[tuple[float, float]]) -> float:
-    """Return the weighted circular mean of the values."""
-    sin_sum = sum(math.sin(x * DEG_TO_RAD) * weight for x, weight in values)
-    cos_sum = sum(math.cos(x * DEG_TO_RAD) * weight for x, weight in values)
-    return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
+def weighted_circular_mean(
+    values: Iterable[tuple[float, float]],
+) -> tuple[float, float]:
+    """Return the weighted circular mean and the weight of the values."""
+    weighted_sin_sum, weighted_cos_sum = 0.0, 0.0
+    for x, weight in values:
+        rad_x = x * DEG_TO_RAD
+        weighted_sin_sum += math.sin(rad_x) * weight
+        weighted_cos_sum += math.cos(rad_x) * weight
-def circular_mean(values: list[float]) -> float:
-    """Return the circular mean of the values."""
-    sin_sum = sum(math.sin(x * DEG_TO_RAD) for x in values)
-    cos_sum = sum(math.cos(x * DEG_TO_RAD) for x in values)
-    return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
+    return (
+        (RAD_TO_DEG * math.atan2(weighted_sin_sum, weighted_cos_sum)) % 360,
+        math.sqrt(weighted_sin_sum**2 + weighted_cos_sum**2),
+    )
 _LOGGER = logging.getLogger(__name__)
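A quick usage sketch of the new signature (illustrative numbers, not taken from this PR):

    mean_deg, weight = weighted_circular_mean([(90.0, 2.0), (0.0, 1.0)])
    # mean_deg ≈ 63.43 (atan2(2, 1) in degrees), weight ≈ 2.24 (sqrt(2**2 + 1**2))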
@@ -300,6 +301,7 @@ class StatisticsRow(BaseStatisticsRow, total=False):
     min: float | None
     max: float | None
     mean: float | None
+    mean_weight: float | None
     change: float | None
@@ -1023,7 +1025,7 @@ def _reduce_statistics(
     _want_sum = "sum" in types
     for statistic_id, stat_list in stats.items():
         max_values: list[float] = []
-        mean_values: list[float] = []
+        mean_values: list[tuple[float, float]] = []
         min_values: list[float] = []
         prev_stat: StatisticsRow = stat_list[0]
         fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds}
@@ -1039,12 +1041,15 @@
                 }
                 if _want_mean:
                     row["mean"] = None
+                    row["mean_weight"] = None
                     if mean_values:
                         match metadata[statistic_id][1]["mean_type"]:
                             case StatisticMeanType.ARITHMETIC:
-                                row["mean"] = mean(mean_values)
+                                row["mean"] = mean([x[0] for x in mean_values])
                             case StatisticMeanType.CIRCULAR:
-                                row["mean"] = circular_mean(mean_values)
+                                row["mean"], row["mean_weight"] = (
+                                    weighted_circular_mean(mean_values)
+                                )
                     mean_values.clear()
                 if _want_min:
                     row["min"] = min(min_values) if min_values else None
@@ -1063,7 +1068,8 @@
                 max_values.append(_max)
             if _want_mean:
                 if (_mean := statistic.get("mean")) is not None:
-                    mean_values.append(_mean)
+                    _mean_weight = statistic.get("mean_weight") or 0.0
+                    mean_values.append((_mean, _mean_weight))
             if _want_min and (_min := statistic.get("min")) is not None:
                 min_values.append(_min)
             prev_stat = statistic
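The reason _reduce_statistics now carries (mean, weight) pairs: combining per-period circular means together with their weights reproduces the mean over all underlying samples, while combining the bare means does not. A small self-contained check of that property, assuming the weighted_circular_mean shown above (illustrative only, not part of this PR):

    samples = [(350.0, 1.0), (10.0, 1.0), (10.0, 1.0), (90.0, 3.0)]
    period_a = weighted_circular_mean(samples[:2])   # (mean, weight) stored for period A
    period_b = weighted_circular_mean(samples[2:])   # (mean, weight) stored for period B

    # Reducing from the stored (mean, weight) pairs matches the mean over all samples,
    # whereas an unweighted mean of the two period means would generally differ.
    combined = weighted_circular_mean([period_a, period_b])
    direct = weighted_circular_mean(samples)
    assert abs(combined[0] - direct[0]) < 1e-9
    assert abs(combined[1] - direct[1]) < 1e-9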
@@ -1385,7 +1391,7 @@ def _get_max_mean_min_statistic(
         match metadata[1]["mean_type"]:
             case StatisticMeanType.CIRCULAR:
                 if circular_means := max_mean_min["circular_means"]:
-                    mean_value = weighted_circular_mean(circular_means)
+                    mean_value = weighted_circular_mean(circular_means)[0]
             case StatisticMeanType.ARITHMETIC:
                 if (mean_value := max_mean_min.get("mean_acc")) is not None and (
                     duration := max_mean_min.get("duration")
@@ -1739,12 +1745,12 @@ def statistic_during_period(
 _type_column_mapping = {
-    "last_reset": "last_reset_ts",
-    "max": "max",
-    "mean": "mean",
-    "min": "min",
-    "state": "state",
-    "sum": "sum",
+    "last_reset": ("last_reset_ts",),
+    "max": ("max",),
+    "mean": ("mean", "mean_weight"),
+    "min": ("min",),
+    "state": ("state",),
+    "sum": ("sum",),
 }
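With the tuple-valued mapping, a single requested statistic type can now pull in more than one backing column. An illustration of the lookup (not code from the PR): requesting "mean" selects both the mean and mean_weight columns.

    requested = {"mean", "min"}
    selected = [
        column
        for key, columns in _type_column_mapping.items()
        if key in requested
        for column in columns
    ]
    # -> ["mean", "mean_weight", "min"]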
@@ -1756,12 +1762,13 @@ def _generate_select_columns_for_types_stmt(
     track_on: list[str | None] = [
         table.__tablename__,  # type: ignore[attr-defined]
     ]
-    for key, column in _type_column_mapping.items():
-        if key in types:
-            columns = columns.add_columns(getattr(table, column))
-            track_on.append(column)
-        else:
-            track_on.append(None)
+    for key, type_columns in _type_column_mapping.items():
+        for column in type_columns:
+            if key in types:
+                columns = columns.add_columns(getattr(table, column))
+                track_on.append(column)
+            else:
+                track_on.append(None)
     return lambda_stmt(lambda: columns, track_on=track_on)
@@ -1944,6 +1951,12 @@ def _statistics_during_period_with_session(
             hass, session, start_time, units, _types, table, metadata, result
         )
+    # filter out mean_weight as it is only needed to reduce statistics
+    # and not needed in the result
+    for stats_rows in result.values():
+        for row in stats_rows:
+            row.pop("mean_weight", None)
     # Return statistics combined with metadata
     return result
@@ -2391,7 +2404,12 @@ def _sorted_statistics_to_dict(
     field_map["last_reset"] = field_map.pop("last_reset_ts")
     sum_idx = field_map["sum"] if "sum" in types else None
     sum_only = len(types) == 1 and sum_idx is not None
-    row_mapping = tuple((key, field_map[key]) for key in types if key in field_map)
+    row_mapping = tuple(
+        (column, field_map[column])
+        for key in types
+        for column in ({key, *_type_column_mapping.get(key, ())})
+        if column in field_map
+    )
     # Append all statistic entries, and optionally do unit conversion
     table_duration_seconds = table.duration.total_seconds()
     for meta_id, db_rows in stats_by_meta_id.items():

View File

@@ -160,7 +160,7 @@ def _time_weighted_arithmetic_mean(
 def _time_weighted_circular_mean(
     fstates: list[tuple[float, State]], start: datetime.datetime, end: datetime.datetime
-) -> float:
+) -> tuple[float, float]:
     """Calculate a time weighted circular mean.
     The circular mean is calculated by weighting the states by duration in seconds between
@@ -623,7 +623,7 @@ def compile_statistics( # noqa: C901
                         valid_float_states, start, end
                     )
                 case StatisticMeanType.CIRCULAR:
-                    stat["mean"] = _time_weighted_circular_mean(
+                    stat["mean"], stat["mean_weight"] = _time_weighted_circular_mean(
                         valid_float_states, start, end
                     )
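The body of _time_weighted_circular_mean is not shown in this hunk; only its return type changes to include the weight. Roughly speaking, the number of seconds each state was active acts as that state's weight. A hand-wavy sketch of that idea (simplified, hypothetical helper; not the actual sensor recorder code):

    import math

    def time_weighted_circular_mean_sketch(
        states: list[tuple[float, float]],  # (degrees, start_timestamp) pairs, oldest first
        end_ts: float,
    ) -> tuple[float, float]:
        sin_sum = cos_sum = 0.0
        for (value, start_ts), (_, next_ts) in zip(states, states[1:] + [(0.0, end_ts)]):
            duration = next_ts - start_ts  # seconds this value was active -> its weight
            sin_sum += math.sin(math.radians(value)) * duration
            cos_sum += math.cos(math.radians(value)) * duration
        return math.degrees(math.atan2(sin_sum, cos_sum)) % 360, math.sqrt(sin_sum**2 + cos_sum**2)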

View File

@@ -4508,23 +4508,19 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
             duration += dur
         return total / duration
-    def _time_weighted_circular_mean(values: list[tuple[float, int]]):
+    def _weighted_circular_mean(
+        values: Iterable[tuple[float, float]],
+    ) -> tuple[float, float]:
         sin_sum = 0
         cos_sum = 0
-        for x, dur in values:
-            sin_sum += math.sin(x * DEG_TO_RAD) * dur
-            cos_sum += math.cos(x * DEG_TO_RAD) * dur
+        for x, weight in values:
+            sin_sum += math.sin(x * DEG_TO_RAD) * weight
+            cos_sum += math.cos(x * DEG_TO_RAD) * weight
-        return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
-    def _circular_mean(values: list[float]) -> float:
-        sin_sum = 0
-        cos_sum = 0
-        for x in values:
-            sin_sum += math.sin(x * DEG_TO_RAD)
-            cos_sum += math.cos(x * DEG_TO_RAD)
-        return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
+        return (
+            (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360,
+            math.sqrt(sin_sum**2 + cos_sum**2),
+        )
def _min(seq, last_state):
if last_state is None:
@@ -4631,7 +4627,7 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
         values = [(seq, durations[j]) for j, seq in enumerate(seq)]
         if (state := last_states["sensor.test5"]) is not None:
             values.append((state, 5))
-        expected_means["sensor.test5"].append(_time_weighted_circular_mean(values))
+        expected_means["sensor.test5"].append(_weighted_circular_mean(values))
         last_states["sensor.test5"] = seq[-1]
         start += timedelta(minutes=5)
@@ -4733,15 +4729,17 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
     start = zero
     end = zero + timedelta(minutes=5)
     for i in range(24):
-        for entity_id in (
-            "sensor.test1",
-            "sensor.test2",
-            "sensor.test3",
-            "sensor.test4",
-            "sensor.test5",
+        for entity_id, mean_extractor in (
+            ("sensor.test1", lambda x: x),
+            ("sensor.test2", lambda x: x),
+            ("sensor.test3", lambda x: x),
+            ("sensor.test4", lambda x: x),
+            ("sensor.test5", lambda x: x[0]),
         ):
             expected_average = (
-                expected_means[entity_id][i] if entity_id in expected_means else None
+                mean_extractor(expected_means[entity_id][i])
+                if entity_id in expected_means
+                else None
             )
             expected_minimum = (
                 expected_minima[entity_id][i] if entity_id in expected_minima else None
@@ -4772,7 +4770,7 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
     assert stats == expected_stats
     def verify_stats(
-        period: Literal["5minute", "day", "hour", "week", "month"],
+        period: Literal["hour", "day", "week", "month"],
         start: datetime,
         next_datetime: Callable[[datetime], datetime],
     ) -> None:
@@ -4791,7 +4789,7 @@ async def test_compile_statistics_hourly_daily_monthly_summary(
             ("sensor.test2", mean),
            ("sensor.test3", mean),
             ("sensor.test4", mean),
-            ("sensor.test5", _circular_mean),
+            ("sensor.test5", lambda x: _weighted_circular_mean(x)[0]),
         ):
             expected_average = (
                 mean_fn(expected_means[entity_id][i * 12 : (i + 1) * 12])