Add type hints to Filter (#86165)

This commit is contained in:
epenet 2023-01-19 08:09:18 +01:00 committed by GitHub
parent 4b6157cd9b
commit 74096b87eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,7 +8,7 @@ from functools import partial
import logging
from numbers import Number
import statistics
from typing import Any
from typing import Any, cast
import voluptuous as vol
@ -302,14 +302,14 @@ class SensorFilter(SensorEntity):
for filt in self._filters:
if (
filt.window_unit == WINDOW_SIZE_UNIT_NUMBER_EVENTS
and largest_window_items < filt.window_size
and largest_window_items < (size := cast(int, filt.window_size))
):
largest_window_items = filt.window_size
largest_window_items = size
elif (
filt.window_unit == WINDOW_SIZE_UNIT_TIME
and largest_window_time < filt.window_size
and largest_window_time < (val := cast(timedelta, filt.window_size))
):
largest_window_time = filt.window_size
largest_window_time = val
# Retrieve the largest window_size of each type
if largest_window_items > 0:
@ -386,11 +386,11 @@ class FilterState:
value = round(float(self.state), precision)
self.state = int(value) if precision == 0 else value
def __str__(self):
def __str__(self) -> str:
"""Return state as the string representation of FilterState."""
return str(self.state)
def __repr__(self):
def __repr__(self) -> str:
"""Return timestamp and state as the representation of FilterState."""
return f"{self.timestamp} : {self.state}"
@ -412,7 +412,7 @@ class Filter:
:param entity: used for debugging only
"""
if isinstance(window_size, int):
self.states: deque = deque(maxlen=window_size)
self.states: deque[FilterState] = deque(maxlen=window_size)
self.window_unit = WINDOW_SIZE_UNIT_NUMBER_EVENTS
else:
self.states = deque(maxlen=0)
@ -426,25 +426,25 @@ class Filter:
self._only_numbers = True
@property
def window_size(self):
def window_size(self) -> int | timedelta:
"""Return window size."""
return self._window_size
@property
def name(self):
def name(self) -> str:
"""Return filter name."""
return self._name
@property
def skip_processing(self):
def skip_processing(self) -> bool:
"""Return whether the current filter_state should be skipped."""
return self._skip_processing
def _filter_state(self, new_state):
def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement filter."""
raise NotImplementedError()
def filter_state(self, new_state):
def filter_state(self, new_state: State) -> State:
"""Implement a common interface for filters."""
fstate = FilterState(new_state)
if self._only_numbers and not isinstance(fstate.state, Number):
@ -485,7 +485,7 @@ class RangeFilter(Filter, SensorEntity):
self._upper_bound = upper_bound
self._stats_internal: Counter = Counter()
def _filter_state(self, new_state):
def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the range filter."""
if self._upper_bound is not None and new_state.state > self._upper_bound:
@ -534,7 +534,7 @@ class OutlierFilter(Filter, SensorEntity):
self._stats_internal: Counter = Counter()
self._store_raw = True
def _filter_state(self, new_state):
def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the outlier filter."""
median = statistics.median([s.state for s in self.states]) if self.states else 0
@ -566,7 +566,7 @@ class LowPassFilter(Filter, SensorEntity):
super().__init__(FILTER_NAME_LOWPASS, window_size, precision, entity)
self._time_constant = time_constant
def _filter_state(self, new_state):
def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the low pass filter."""
if not self.states:
@ -601,10 +601,10 @@ class TimeSMAFilter(Filter, SensorEntity):
"""
super().__init__(FILTER_NAME_TIME_SMA, window_size, precision, entity)
self._time_window = window_size
self.last_leak = None
self.last_leak: FilterState | None = None
self.queue = deque[FilterState]()
def _leak(self, left_boundary):
def _leak(self, left_boundary: datetime) -> None:
"""Remove timeouted elements."""
while self.queue:
if self.queue[0].timestamp + self._time_window <= left_boundary:
@ -612,13 +612,13 @@ class TimeSMAFilter(Filter, SensorEntity):
else:
return
def _filter_state(self, new_state):
def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the Simple Moving Average filter."""
self._leak(new_state.timestamp)
self.queue.append(copy(new_state))
moving_sum = 0
moving_sum: float = 0
start = new_state.timestamp - self._time_window
prev_state = self.last_leak if self.last_leak is not None else self.queue[0]
for state in self.queue:
@ -643,7 +643,7 @@ class ThrottleFilter(Filter, SensorEntity):
super().__init__(FILTER_NAME_THROTTLE, window_size, precision, entity)
self._only_numbers = False
def _filter_state(self, new_state):
def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the throttle filter."""
if not self.states or len(self.states) == self.states.maxlen:
self.states.clear()
@ -665,10 +665,10 @@ class TimeThrottleFilter(Filter, SensorEntity):
"""Initialize Filter."""
super().__init__(FILTER_NAME_TIME_THROTTLE, window_size, precision, entity)
self._time_window = window_size
self._last_emitted_at = None
self._last_emitted_at: datetime | None = None
self._only_numbers = False
def _filter_state(self, new_state):
def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the filter."""
window_start = new_state.timestamp - self._time_window
if not self._last_emitted_at or self._last_emitted_at <= window_start: