Add type hints to Filter (#86165)

This commit is contained in:
epenet 2023-01-19 08:09:18 +01:00 committed by GitHub
parent 4b6157cd9b
commit 74096b87eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,7 +8,7 @@ from functools import partial
import logging import logging
from numbers import Number from numbers import Number
import statistics import statistics
from typing import Any from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -302,14 +302,14 @@ class SensorFilter(SensorEntity):
for filt in self._filters: for filt in self._filters:
if ( if (
filt.window_unit == WINDOW_SIZE_UNIT_NUMBER_EVENTS filt.window_unit == WINDOW_SIZE_UNIT_NUMBER_EVENTS
and largest_window_items < filt.window_size and largest_window_items < (size := cast(int, filt.window_size))
): ):
largest_window_items = filt.window_size largest_window_items = size
elif ( elif (
filt.window_unit == WINDOW_SIZE_UNIT_TIME filt.window_unit == WINDOW_SIZE_UNIT_TIME
and largest_window_time < filt.window_size and largest_window_time < (val := cast(timedelta, filt.window_size))
): ):
largest_window_time = filt.window_size largest_window_time = val
# Retrieve the largest window_size of each type # Retrieve the largest window_size of each type
if largest_window_items > 0: if largest_window_items > 0:
@ -386,11 +386,11 @@ class FilterState:
value = round(float(self.state), precision) value = round(float(self.state), precision)
self.state = int(value) if precision == 0 else value self.state = int(value) if precision == 0 else value
def __str__(self): def __str__(self) -> str:
"""Return state as the string representation of FilterState.""" """Return state as the string representation of FilterState."""
return str(self.state) return str(self.state)
def __repr__(self): def __repr__(self) -> str:
"""Return timestamp and state as the representation of FilterState.""" """Return timestamp and state as the representation of FilterState."""
return f"{self.timestamp} : {self.state}" return f"{self.timestamp} : {self.state}"
@ -412,7 +412,7 @@ class Filter:
:param entity: used for debugging only :param entity: used for debugging only
""" """
if isinstance(window_size, int): if isinstance(window_size, int):
self.states: deque = deque(maxlen=window_size) self.states: deque[FilterState] = deque(maxlen=window_size)
self.window_unit = WINDOW_SIZE_UNIT_NUMBER_EVENTS self.window_unit = WINDOW_SIZE_UNIT_NUMBER_EVENTS
else: else:
self.states = deque(maxlen=0) self.states = deque(maxlen=0)
@ -426,25 +426,25 @@ class Filter:
self._only_numbers = True self._only_numbers = True
@property @property
def window_size(self): def window_size(self) -> int | timedelta:
"""Return window size.""" """Return window size."""
return self._window_size return self._window_size
@property @property
def name(self): def name(self) -> str:
"""Return filter name.""" """Return filter name."""
return self._name return self._name
@property @property
def skip_processing(self): def skip_processing(self) -> bool:
"""Return whether the current filter_state should be skipped.""" """Return whether the current filter_state should be skipped."""
return self._skip_processing return self._skip_processing
def _filter_state(self, new_state): def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement filter.""" """Implement filter."""
raise NotImplementedError() raise NotImplementedError()
def filter_state(self, new_state): def filter_state(self, new_state: State) -> State:
"""Implement a common interface for filters.""" """Implement a common interface for filters."""
fstate = FilterState(new_state) fstate = FilterState(new_state)
if self._only_numbers and not isinstance(fstate.state, Number): if self._only_numbers and not isinstance(fstate.state, Number):
@ -485,7 +485,7 @@ class RangeFilter(Filter, SensorEntity):
self._upper_bound = upper_bound self._upper_bound = upper_bound
self._stats_internal: Counter = Counter() self._stats_internal: Counter = Counter()
def _filter_state(self, new_state): def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the range filter.""" """Implement the range filter."""
if self._upper_bound is not None and new_state.state > self._upper_bound: if self._upper_bound is not None and new_state.state > self._upper_bound:
@ -534,7 +534,7 @@ class OutlierFilter(Filter, SensorEntity):
self._stats_internal: Counter = Counter() self._stats_internal: Counter = Counter()
self._store_raw = True self._store_raw = True
def _filter_state(self, new_state): def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the outlier filter.""" """Implement the outlier filter."""
median = statistics.median([s.state for s in self.states]) if self.states else 0 median = statistics.median([s.state for s in self.states]) if self.states else 0
@ -566,7 +566,7 @@ class LowPassFilter(Filter, SensorEntity):
super().__init__(FILTER_NAME_LOWPASS, window_size, precision, entity) super().__init__(FILTER_NAME_LOWPASS, window_size, precision, entity)
self._time_constant = time_constant self._time_constant = time_constant
def _filter_state(self, new_state): def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the low pass filter.""" """Implement the low pass filter."""
if not self.states: if not self.states:
@ -601,10 +601,10 @@ class TimeSMAFilter(Filter, SensorEntity):
""" """
super().__init__(FILTER_NAME_TIME_SMA, window_size, precision, entity) super().__init__(FILTER_NAME_TIME_SMA, window_size, precision, entity)
self._time_window = window_size self._time_window = window_size
self.last_leak = None self.last_leak: FilterState | None = None
self.queue = deque[FilterState]() self.queue = deque[FilterState]()
def _leak(self, left_boundary): def _leak(self, left_boundary: datetime) -> None:
"""Remove timeouted elements.""" """Remove timeouted elements."""
while self.queue: while self.queue:
if self.queue[0].timestamp + self._time_window <= left_boundary: if self.queue[0].timestamp + self._time_window <= left_boundary:
@ -612,13 +612,13 @@ class TimeSMAFilter(Filter, SensorEntity):
else: else:
return return
def _filter_state(self, new_state): def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the Simple Moving Average filter.""" """Implement the Simple Moving Average filter."""
self._leak(new_state.timestamp) self._leak(new_state.timestamp)
self.queue.append(copy(new_state)) self.queue.append(copy(new_state))
moving_sum = 0 moving_sum: float = 0
start = new_state.timestamp - self._time_window start = new_state.timestamp - self._time_window
prev_state = self.last_leak if self.last_leak is not None else self.queue[0] prev_state = self.last_leak if self.last_leak is not None else self.queue[0]
for state in self.queue: for state in self.queue:
@ -643,7 +643,7 @@ class ThrottleFilter(Filter, SensorEntity):
super().__init__(FILTER_NAME_THROTTLE, window_size, precision, entity) super().__init__(FILTER_NAME_THROTTLE, window_size, precision, entity)
self._only_numbers = False self._only_numbers = False
def _filter_state(self, new_state): def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the throttle filter.""" """Implement the throttle filter."""
if not self.states or len(self.states) == self.states.maxlen: if not self.states or len(self.states) == self.states.maxlen:
self.states.clear() self.states.clear()
@ -665,10 +665,10 @@ class TimeThrottleFilter(Filter, SensorEntity):
"""Initialize Filter.""" """Initialize Filter."""
super().__init__(FILTER_NAME_TIME_THROTTLE, window_size, precision, entity) super().__init__(FILTER_NAME_TIME_THROTTLE, window_size, precision, entity)
self._time_window = window_size self._time_window = window_size
self._last_emitted_at = None self._last_emitted_at: datetime | None = None
self._only_numbers = False self._only_numbers = False
def _filter_state(self, new_state): def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the filter.""" """Implement the filter."""
window_start = new_state.timestamp - self._time_window window_start = new_state.timestamp - self._time_window
if not self._last_emitted_at or self._last_emitted_at <= window_start: if not self._last_emitted_at or self._last_emitted_at <= window_start: