Improve typing and code quality in beyesian (#79603)

* strict typing

* Detail implication

* adds newline

* don't change indenting

* really dont change indenting

* Update homeassistant/components/bayesian/binary_sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* typing in async_setup_platform() + remove arg

* less ambiguity

* mypy thinks Literal[False] otherwise

* clearer log

* don't use `and` assignments

* observations not values

* clarify can be None

* observation can't be none

* assert we have at least one

* make it clearer where we're using UUIDs

* remove unnecessary bool

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Unnecessary None handling

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Better type setting

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Reccomended changes.

* remove if statement not needed

* Not strict until _TrackTemplateResultInfo fixed

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
HarvsG 2022-10-07 21:23:25 +01:00 committed by GitHub
parent a18a0b39dd
commit 9d351a3c10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 67 deletions

View File

@ -2,12 +2,18 @@
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable
import logging import logging
from typing import Any from typing import Any
from uuid import UUID
import voluptuous as vol import voluptuous as vol
from homeassistant.components.binary_sensor import PLATFORM_SCHEMA, BinarySensorEntity from homeassistant.components.binary_sensor import (
PLATFORM_SCHEMA,
BinarySensorDeviceClass,
BinarySensorEntity,
)
from homeassistant.const import ( from homeassistant.const import (
CONF_ABOVE, CONF_ABOVE,
CONF_BELOW, CONF_BELOW,
@ -20,18 +26,19 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import ConditionError, TemplateError from homeassistant.exceptions import ConditionError, TemplateError
from homeassistant.helpers import condition from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
TrackTemplate, TrackTemplate,
TrackTemplateResult,
async_track_state_change_event, async_track_state_change_event,
async_track_template_result, async_track_template_result,
) )
from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.template import result_as_boolean from homeassistant.helpers.template import Template, result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DOMAIN, PLATFORMS from . import DOMAIN, PLATFORMS
@ -107,7 +114,9 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
) )
def update_probability(prior, prob_given_true, prob_given_false): def update_probability(
prior: float, prob_given_true: float, prob_given_false: float
) -> float:
"""Update probability using Bayes' rule.""" """Update probability using Bayes' rule."""
numerator = prob_given_true * prior numerator = prob_given_true * prior
denominator = numerator + prob_given_false * (1 - prior) denominator = numerator + prob_given_false * (1 - prior)
@ -123,18 +132,18 @@ async def async_setup_platform(
"""Set up the Bayesian Binary sensor.""" """Set up the Bayesian Binary sensor."""
await async_setup_reload_service(hass, DOMAIN, PLATFORMS) await async_setup_reload_service(hass, DOMAIN, PLATFORMS)
name = config[CONF_NAME] name: str = config[CONF_NAME]
observations = config[CONF_OBSERVATIONS] observations: list[ConfigType] = config[CONF_OBSERVATIONS]
prior = config[CONF_PRIOR] prior: float = config[CONF_PRIOR]
probability_threshold = config[CONF_PROBABILITY_THRESHOLD] probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
device_class = config.get(CONF_DEVICE_CLASS) device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS)
# Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas. # Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas.
broken_observations: list[dict[str, Any]] = [] broken_observations: list[dict[str, Any]] = []
for observation in observations: for observation in observations:
if CONF_P_GIVEN_F not in observation: if CONF_P_GIVEN_F not in observation:
text: str = f"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}" text: str = f"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}"
raise_no_prob_given_false(hass, observation, text) raise_no_prob_given_false(hass, text)
_LOGGER.error("Missing prob_given_false YAML entry for %s", text) _LOGGER.error("Missing prob_given_false YAML entry for %s", text)
broken_observations.append(observation) broken_observations.append(observation)
observations = [x for x in observations if x not in broken_observations] observations = [x for x in observations if x not in broken_observations]
@ -153,7 +162,14 @@ class BayesianBinarySensor(BinarySensorEntity):
_attr_should_poll = False _attr_should_poll = False
def __init__(self, name, prior, observations, probability_threshold, device_class): def __init__(
self,
name: str,
prior: float,
observations: list[ConfigType],
probability_threshold: float,
device_class: BinarySensorDeviceClass | None,
) -> None:
"""Initialize the Bayesian sensor.""" """Initialize the Bayesian sensor."""
self._attr_name = name self._attr_name = name
self._observations = [ self._observations = [
@ -173,17 +189,17 @@ class BayesianBinarySensor(BinarySensorEntity):
self._probability_threshold = probability_threshold self._probability_threshold = probability_threshold
self._attr_device_class = device_class self._attr_device_class = device_class
self._attr_is_on = False self._attr_is_on = False
self._callbacks = [] self._callbacks: list = []
self.prior = prior self.prior = prior
self.probability = prior self.probability = prior
self.current_observations = OrderedDict({}) self.current_observations: OrderedDict[UUID, Observation] = OrderedDict({})
self.observations_by_entity = self._build_observations_by_entity() self.observations_by_entity = self._build_observations_by_entity()
self.observations_by_template = self._build_observations_by_template() self.observations_by_template = self._build_observations_by_template()
self.observation_handlers = { self.observation_handlers: dict[str, Callable[[Observation], bool | None]] = {
"numeric_state": self._process_numeric_state, "numeric_state": self._process_numeric_state,
"state": self._process_state, "state": self._process_state,
"multi_state": self._process_multi_state, "multi_state": self._process_multi_state,
@ -205,7 +221,7 @@ class BayesianBinarySensor(BinarySensorEntity):
""" """
@callback @callback
def async_threshold_sensor_state_listener(event): def async_threshold_sensor_state_listener(event: Event) -> None:
""" """
Handle sensor state changes. Handle sensor state changes.
@ -213,7 +229,7 @@ class BayesianBinarySensor(BinarySensorEntity):
then calculate the new probability. then calculate the new probability.
""" """
entity = event.data.get("entity_id") entity: str = event.data[CONF_ENTITY_ID]
self.current_observations.update(self._record_entity_observations(entity)) self.current_observations.update(self._record_entity_observations(entity))
self.async_set_context(event.context) self.async_set_context(event.context)
@ -228,11 +244,15 @@ class BayesianBinarySensor(BinarySensorEntity):
) )
@callback @callback
def _async_template_result_changed(event, updates): def _async_template_result_changed(
event: Event | None, updates: list[TrackTemplateResult]
) -> None:
track_template_result = updates.pop() track_template_result = updates.pop()
template = track_template_result.template template = track_template_result.template
result = track_template_result.result result = track_template_result.result
entity = event and event.data.get("entity_id") entity: str | None = (
None if event is None else event.data.get(CONF_ENTITY_ID)
)
if isinstance(result, TemplateError): if isinstance(result, TemplateError):
_LOGGER.error( _LOGGER.error(
"TemplateError('%s') " "TemplateError('%s') "
@ -252,7 +272,7 @@ class BayesianBinarySensor(BinarySensorEntity):
# in some cases a template may update because of the absence of an entity # in some cases a template may update because of the absence of an entity
if entity is not None: if entity is not None:
observation.entity_id = str(entity) observation.entity_id = entity
self.current_observations[observation.id] = observation self.current_observations[observation.id] = observation
@ -273,7 +293,7 @@ class BayesianBinarySensor(BinarySensorEntity):
self.current_observations.update(self._initialize_current_observations()) self.current_observations.update(self._initialize_current_observations())
self.probability = self._calculate_new_probability() self.probability = self._calculate_new_probability()
self._attr_is_on = bool(self.probability >= self._probability_threshold) self._attr_is_on = self.probability >= self._probability_threshold
# detect mirrored entries # detect mirrored entries
for entity, observations in self.observations_by_entity.items(): for entity, observations in self.observations_by_entity.items():
@ -281,9 +301,9 @@ class BayesianBinarySensor(BinarySensorEntity):
self.hass, observations, text=f"{self._attr_name}/{entity}" self.hass, observations, text=f"{self._attr_name}/{entity}"
) )
all_template_observations = [] all_template_observations: list[Observation] = []
for value in self.observations_by_template.values(): for observations in self.observations_by_template.values():
all_template_observations.append(value[0]) all_template_observations.append(observations[0])
if len(all_template_observations) == 2: if len(all_template_observations) == 2:
raise_mirrored_entries( raise_mirrored_entries(
self.hass, self.hass,
@ -292,62 +312,63 @@ class BayesianBinarySensor(BinarySensorEntity):
) )
@callback @callback
def _recalculate_and_write_state(self): def _recalculate_and_write_state(self) -> None:
self.probability = self._calculate_new_probability() self.probability = self._calculate_new_probability()
self._attr_is_on = bool(self.probability >= self._probability_threshold) self._attr_is_on = bool(self.probability >= self._probability_threshold)
self.async_write_ha_state() self.async_write_ha_state()
def _initialize_current_observations(self): def _initialize_current_observations(self) -> OrderedDict[UUID, Observation]:
local_observations = OrderedDict({}) local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
for entity in self.observations_by_entity: for entity in self.observations_by_entity:
local_observations.update(self._record_entity_observations(entity)) local_observations.update(self._record_entity_observations(entity))
return local_observations return local_observations
def _record_entity_observations(self, entity): def _record_entity_observations(
local_observations = OrderedDict({}) self, entity: str
) -> OrderedDict[UUID, Observation]:
local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
for observation in self.observations_by_entity[entity]: for observation in self.observations_by_entity[entity]:
platform = observation.platform platform = observation.platform
observed = self.observation_handlers[platform](observation) observation.observed = self.observation_handlers[platform](observation)
observation.observed = observed
local_observations[observation.id] = observation local_observations[observation.id] = observation
return local_observations return local_observations
def _calculate_new_probability(self): def _calculate_new_probability(self) -> float:
prior = self.prior prior = self.prior
for observation in self.current_observations.values(): for observation in self.current_observations.values():
if observation is not None:
if observation.observed is True: if observation.observed is True:
prior = update_probability( prior = update_probability(
prior, prior,
observation.prob_given_true, observation.prob_given_true,
observation.prob_given_false, observation.prob_given_false,
) )
elif observation.observed is False: continue
if observation.observed is False:
prior = update_probability( prior = update_probability(
prior, prior,
1 - observation.prob_given_true, 1 - observation.prob_given_true,
1 - observation.prob_given_false, 1 - observation.prob_given_false,
) )
elif observation.observed is None: continue
# observation.observed is None
if observation.entity_id is not None: if observation.entity_id is not None:
_LOGGER.debug( _LOGGER.debug(
"Observation for entity '%s' returned None, it will not be used for Bayesian updating", "Observation for entity '%s' returned None, it will not be used for Bayesian updating",
observation.entity_id, observation.entity_id,
) )
else: continue
_LOGGER.debug( _LOGGER.debug(
"Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating", "Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating",
) )
# the prior has been updated and is now the posterior
return prior return prior
def _build_observations_by_entity(self): def _build_observations_by_entity(self) -> dict[str, list[Observation]]:
""" """
Build and return data structure of the form below. Build and return data structure of the form below.
@ -378,7 +399,7 @@ class BayesianBinarySensor(BinarySensorEntity):
return observations_by_entity return observations_by_entity
def _build_observations_by_template(self): def _build_observations_by_template(self) -> dict[Template, list[Observation]]:
""" """
Build and return data structure of the form below. Build and return data structure of the form below.
@ -392,7 +413,7 @@ class BayesianBinarySensor(BinarySensorEntity):
for all relevant observations to be looked up via their `template`. for all relevant observations to be looked up via their `template`.
""" """
observations_by_template = {} observations_by_template: dict[Template, list[Observation]] = {}
for observation in self._observations: for observation in self._observations:
if observation.value_template is None: if observation.value_template is None:
continue continue
@ -402,7 +423,7 @@ class BayesianBinarySensor(BinarySensorEntity):
return observations_by_template return observations_by_template
def _process_numeric_state(self, entity_observation): def _process_numeric_state(self, entity_observation: Observation) -> bool | None:
"""Return True if numeric condition is met, return False if not, return None otherwise.""" """Return True if numeric condition is met, return False if not, return None otherwise."""
entity = entity_observation.entity_id entity = entity_observation.entity_id
@ -420,7 +441,7 @@ class BayesianBinarySensor(BinarySensorEntity):
except ConditionError: except ConditionError:
return None return None
def _process_state(self, entity_observation): def _process_state(self, entity_observation: Observation) -> bool | None:
"""Return True if state conditions are met, return False if they are not. """Return True if state conditions are met, return False if they are not.
Returns None if the state is unavailable. Returns None if the state is unavailable.
@ -436,7 +457,7 @@ class BayesianBinarySensor(BinarySensorEntity):
except ConditionError: except ConditionError:
return None return None
def _process_multi_state(self, entity_observation): def _process_multi_state(self, entity_observation: Observation) -> bool | None:
"""Return True if state conditions are met, otherwise return None. """Return True if state conditions are met, otherwise return None.
Never return False as all other states should have their own probabilities configured. Never return False as all other states should have their own probabilities configured.
@ -452,7 +473,7 @@ class BayesianBinarySensor(BinarySensorEntity):
return None return None
@property @property
def extra_state_attributes(self): def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes of the sensor.""" """Return the state attributes of the sensor."""
return { return {

View File

@ -18,7 +18,10 @@ from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE
@dataclass @dataclass
class Observation: class Observation:
"""Representation of a sensor or template observation.""" """Representation of a sensor or template observation.
Either entity_id or value_template should be non-None.
"""
entity_id: str | None entity_id: str | None
platform: str platform: str
@ -29,7 +32,7 @@ class Observation:
below: float | None below: float | None
value_template: Template | None value_template: Template | None
observed: bool | None = None observed: bool | None = None
id: str = field(default_factory=lambda: str(uuid.uuid4())) id: uuid.UUID = field(default_factory=uuid.uuid4)
def to_dict(self) -> dict[str, str | float | bool | None]: def to_dict(self) -> dict[str, str | float | bool | None]:
"""Represent Class as a Dict for easier serialization.""" """Represent Class as a Dict for easier serialization."""

View File

@ -5,9 +5,12 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import issue_registry from homeassistant.helpers import issue_registry
from . import DOMAIN from . import DOMAIN
from .helpers import Observation
def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") -> None: def raise_mirrored_entries(
hass: HomeAssistant, observations: list[Observation], text: str = ""
) -> None:
"""If there are mirrored entries, the user is probably using a workaround for a patched bug.""" """If there are mirrored entries, the user is probably using a workaround for a patched bug."""
if len(observations) != 2: if len(observations) != 2:
return return
@ -26,7 +29,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") ->
# Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas. # Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas.
def raise_no_prob_given_false(hass: HomeAssistant, observation, text: str) -> None: def raise_no_prob_given_false(hass: HomeAssistant, text: str) -> None:
"""In previous 2022.9 and earlier, prob_given_false was optional and had a default version.""" """In previous 2022.9 and earlier, prob_given_false was optional and had a default version."""
issue_registry.async_create_issue( issue_registry.async_create_issue(
hass, hass,