Refactor bayesian observations using dataclass (#79590)

* refactor

* remove some changes

* remove typehint

* improve codestyle

* move docstring to comment

* < 88 chars

* avoid short var names

* more readable

* fix rename

* Update homeassistant/components/bayesian/helpers.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update homeassistant/components/bayesian/binary_sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update homeassistant/components/bayesian/binary_sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* no intermediate

* comment why set before list

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
HarvsG 2022-10-04 16:16:39 +01:00 committed by GitHub
parent 56dd0a6867
commit dd1463da28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 105 deletions

View File

@ -35,24 +35,24 @@ from homeassistant.helpers.template import result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DOMAIN, PLATFORMS from . import DOMAIN, PLATFORMS
from .const import (
ATTR_OBSERVATIONS,
ATTR_OCCURRED_OBSERVATION_ENTITIES,
ATTR_PROBABILITY,
ATTR_PROBABILITY_THRESHOLD,
CONF_OBSERVATIONS,
CONF_P_GIVEN_F,
CONF_P_GIVEN_T,
CONF_PRIOR,
CONF_PROBABILITY_THRESHOLD,
CONF_TEMPLATE,
CONF_TO_STATE,
DEFAULT_NAME,
DEFAULT_PROBABILITY_THRESHOLD,
)
from .helpers import Observation
from .repairs import raise_mirrored_entries, raise_no_prob_given_false from .repairs import raise_mirrored_entries, raise_no_prob_given_false
ATTR_OBSERVATIONS = "observations"
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
ATTR_PROBABILITY = "probability"
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
CONF_OBSERVATIONS = "observations"
CONF_PRIOR = "prior"
CONF_TEMPLATE = "template"
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
CONF_P_GIVEN_F = "prob_given_false"
CONF_P_GIVEN_T = "prob_given_true"
CONF_TO_STATE = "to_state"
DEFAULT_NAME = "Bayesian Binary Sensor"
DEFAULT_PROBABILITY_THRESHOLD = 0.5
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -156,7 +156,20 @@ class BayesianBinarySensor(BinarySensorEntity):
def __init__(self, name, prior, observations, probability_threshold, device_class): def __init__(self, name, prior, observations, probability_threshold, device_class):
"""Initialize the Bayesian sensor.""" """Initialize the Bayesian sensor."""
self._attr_name = name self._attr_name = name
self._observations = observations self._observations = [
Observation(
entity_id=observation.get(CONF_ENTITY_ID),
platform=observation[CONF_PLATFORM],
prob_given_false=observation[CONF_P_GIVEN_F],
prob_given_true=observation[CONF_P_GIVEN_T],
observed=None,
to_state=observation.get(CONF_TO_STATE),
above=observation.get(CONF_ABOVE),
below=observation.get(CONF_BELOW),
value_template=observation.get(CONF_VALUE_TEMPLATE),
)
for observation in observations
]
self._probability_threshold = probability_threshold self._probability_threshold = probability_threshold
self._attr_device_class = device_class self._attr_device_class = device_class
self._attr_is_on = False self._attr_is_on = False
@ -230,13 +243,18 @@ class BayesianBinarySensor(BinarySensorEntity):
self.entity_id, self.entity_id,
) )
observation = None observed = None
else: else:
observation = result_as_boolean(result) observed = result_as_boolean(result)
for obs in self.observations_by_template[template]: for observation in self.observations_by_template[template]:
obs_entry = {"entity_id": entity, "observation": observation, **obs} observation.observed = observed
self.current_observations[obs["id"]] = obs_entry
# in some cases a template may update because of the absence of an entity
if entity is not None:
observation.entity_id = str(entity)
self.current_observations[observation.id] = observation
if event: if event:
self.async_set_context(event.context) self.async_set_context(event.context)
@ -270,7 +288,7 @@ class BayesianBinarySensor(BinarySensorEntity):
raise_mirrored_entries( raise_mirrored_entries(
self.hass, self.hass,
all_template_observations, all_template_observations,
text=f"{self._attr_name}/{all_template_observations[0]['value_template']}", text=f"{self._attr_name}/{all_template_observations[0].value_template}",
) )
@callback @callback
@ -289,42 +307,38 @@ class BayesianBinarySensor(BinarySensorEntity):
def _record_entity_observations(self, entity): def _record_entity_observations(self, entity):
local_observations = OrderedDict({}) local_observations = OrderedDict({})
for entity_obs in self.observations_by_entity[entity]: for observation in self.observations_by_entity[entity]:
platform = entity_obs["platform"] platform = observation.platform
observation = self.observation_handlers[platform](entity_obs) observed = self.observation_handlers[platform](observation)
observation.observed = observed
obs_entry = { local_observations[observation.id] = observation
"entity_id": entity,
"observation": observation,
**entity_obs,
}
local_observations[entity_obs["id"]] = obs_entry
return local_observations return local_observations
def _calculate_new_probability(self): def _calculate_new_probability(self):
prior = self.prior prior = self.prior
for obs in self.current_observations.values(): for observation in self.current_observations.values():
if obs is not None: if observation is not None:
if obs["observation"] is True: if observation.observed is True:
prior = update_probability( prior = update_probability(
prior, prior,
obs["prob_given_true"], observation.prob_given_true,
obs["prob_given_false"], observation.prob_given_false,
) )
elif obs["observation"] is False: elif observation.observed is False:
prior = update_probability( prior = update_probability(
prior, prior,
1 - obs["prob_given_true"], 1 - observation.prob_given_true,
1 - obs["prob_given_false"], 1 - observation.prob_given_false,
) )
elif obs["observation"] is None: elif observation.observed is None:
if obs["entity_id"] is not None: if observation.entity_id is not None:
_LOGGER.debug( _LOGGER.debug(
"Observation for entity '%s' returned None, it will not be used for Bayesian updating", "Observation for entity '%s' returned None, it will not be used for Bayesian updating",
obs["entity_id"], observation.entity_id,
) )
else: else:
_LOGGER.debug( _LOGGER.debug(
@ -338,8 +352,8 @@ class BayesianBinarySensor(BinarySensorEntity):
Build and return data structure of the form below. Build and return data structure of the form below.
{ {
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}], "sensor.sensor1": [Observation, Observation],
"sensor.sensor2": [{"id": 2, ...}], "sensor.sensor2": [Observation],
... ...
} }
@ -347,21 +361,20 @@ class BayesianBinarySensor(BinarySensorEntity):
for all relevant observations to be looked up via their `entity_id`. for all relevant observations to be looked up via their `entity_id`.
""" """
observations_by_entity: dict[str, list[OrderedDict]] = {} observations_by_entity: dict[str, list[Observation]] = {}
for i, obs in enumerate(self._observations): for observation in self._observations:
obs["id"] = i
if "entity_id" not in obs: if (key := observation.entity_id) is None:
continue continue
observations_by_entity.setdefault(obs["entity_id"], []).append(obs) observations_by_entity.setdefault(key, []).append(observation)
for li_of_dicts in observations_by_entity.values(): for entity_observations in observations_by_entity.values():
if len(li_of_dicts) == 1: if len(entity_observations) == 1:
continue continue
for ord_dict in li_of_dicts: for observation in entity_observations:
if ord_dict["platform"] != "state": if observation.platform != "state":
continue continue
ord_dict["platform"] = "multi_state" observation.platform = "multi_state"
return observations_by_entity return observations_by_entity
@ -370,8 +383,8 @@ class BayesianBinarySensor(BinarySensorEntity):
Build and return data structure of the form below. Build and return data structure of the form below.
{ {
"template": [{"id": 0, ...}, {"id": 1, ...}], "template": [Observation, Observation],
"template2": [{"id": 2, ...}], "template2": [Observation],
... ...
} }
@ -380,20 +393,18 @@ class BayesianBinarySensor(BinarySensorEntity):
""" """
observations_by_template = {} observations_by_template = {}
for ind, obs in enumerate(self._observations): for observation in self._observations:
obs["id"] = ind if observation.value_template is None:
if "value_template" not in obs:
continue continue
template = obs.get(CONF_VALUE_TEMPLATE) template = observation.value_template
observations_by_template.setdefault(template, []).append(obs) observations_by_template.setdefault(template, []).append(observation)
return observations_by_template return observations_by_template
def _process_numeric_state(self, entity_observation): def _process_numeric_state(self, entity_observation):
"""Return True if numeric condition is met, return False if not, return None otherwise.""" """Return True if numeric condition is met, return False if not, return None otherwise."""
entity = entity_observation["entity_id"] entity = entity_observation.entity_id
try: try:
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]): if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
@ -401,61 +412,67 @@ class BayesianBinarySensor(BinarySensorEntity):
return condition.async_numeric_state( return condition.async_numeric_state(
self.hass, self.hass,
entity, entity,
entity_observation.get("below"), entity_observation.below,
entity_observation.get("above"), entity_observation.above,
None, None,
entity_observation, entity_observation.to_dict(),
) )
except ConditionError: except ConditionError:
return None return None
def _process_state(self, entity_observation): def _process_state(self, entity_observation):
"""Return True if state conditions are met.""" """Return True if state conditions are met, return False if they are not.
entity = entity_observation["entity_id"]
Returns None if the state is unavailable.
"""
entity = entity_observation.entity_id
try: try:
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]): if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
return None return None
return condition.state( return condition.state(self.hass, entity, entity_observation.to_state)
self.hass, entity, entity_observation.get("to_state")
)
except ConditionError: except ConditionError:
return None return None
def _process_multi_state(self, entity_observation): def _process_multi_state(self, entity_observation):
"""Return True if state conditions are met.""" """Return True if state conditions are met, otherwise return None.
entity = entity_observation["entity_id"]
Never return False as all other states should have their own probabilities configured.
"""
entity = entity_observation.entity_id
try: try:
if condition.state(self.hass, entity, entity_observation.get("to_state")): if condition.state(self.hass, entity, entity_observation.to_state):
return True return True
except ConditionError: except ConditionError:
return None return None
return None
@property @property
def extra_state_attributes(self): def extra_state_attributes(self):
"""Return the state attributes of the sensor.""" """Return the state attributes of the sensor."""
attr_observations_list = [
obs.copy() for obs in self.current_observations.values() if obs is not None
]
for item in attr_observations_list:
item.pop("value_template", None)
return { return {
ATTR_OBSERVATIONS: attr_observations_list,
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
{
obs.get("entity_id")
for obs in self.current_observations.values()
if obs is not None
and obs.get("entity_id") is not None
and obs.get("observation") is not None
}
),
ATTR_PROBABILITY: round(self.probability, 2), ATTR_PROBABILITY: round(self.probability, 2),
ATTR_PROBABILITY_THRESHOLD: self._probability_threshold, ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
# An entity can be in more than one observation so set then list to deduplicate
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
{
observation.entity_id
for observation in self.current_observations.values()
if observation is not None
and observation.entity_id is not None
and observation.observed is not None
}
),
ATTR_OBSERVATIONS: [
observation.to_dict()
for observation in self.current_observations.values()
if observation is not None
],
} }
async def async_update(self) -> None: async def async_update(self) -> None:

View File

@ -0,0 +1,17 @@
"""Consts for using in modules."""
# Keys of the extra state attributes returned by the Bayesian binary sensor.
ATTR_OBSERVATIONS = "observations"
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
ATTR_PROBABILITY = "probability"
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
# Configuration keys. NOTE(review): presumably these mirror the platform's
# YAML schema; the schema itself is not defined in this module - confirm there.
CONF_OBSERVATIONS = "observations"
CONF_PRIOR = "prior"
CONF_TEMPLATE = "template"
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
# Per-observation keys: looked up on each observation's config mapping and
# reused as keys when an Observation is serialized to a dict.
CONF_P_GIVEN_F = "prob_given_false"
CONF_P_GIVEN_T = "prob_given_true"
CONF_TO_STATE = "to_state"
# Fallback values (presumably applied when the matching option is omitted).
DEFAULT_NAME = "Bayesian Binary Sensor"
DEFAULT_PROBABILITY_THRESHOLD = 0.5

View File

@ -0,0 +1,69 @@
"""Helpers to deal with bayesian observations."""
from __future__ import annotations
from dataclasses import dataclass, field
import uuid
from homeassistant.const import (
CONF_ABOVE,
CONF_BELOW,
CONF_ENTITY_ID,
CONF_PLATFORM,
CONF_VALUE_TEMPLATE,
)
from homeassistant.helpers.template import Template
from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE
@dataclass
class Observation:
    """One piece of evidence fed into the Bayesian sensor.

    An observation is either bound to an entity (``entity_id`` together with
    ``to_state`` or ``above``/``below``) or to a template (``value_template``).
    ``observed`` holds the latest evaluation result: True, False, or None when
    no usable result is available.
    """

    entity_id: str | None
    platform: str
    prob_given_true: float
    prob_given_false: float
    to_state: str | None
    above: float | None
    below: float | None
    value_template: Template | None
    observed: bool | None = None
    # Each instance gets its own random UUID4 string as identifier.
    id: str = field(default_factory=lambda: str(uuid.uuid4()))

    @property
    def template(self) -> str | None:
        """Return the raw template string, or None for non-template observations."""
        if self.value_template is None:
            return None
        return self.value_template.template

    def to_dict(self) -> dict[str, str | float | bool | None]:
        """Return a plain dict of this observation with unset (None) fields omitted.

        The ``id`` field is not part of the result.
        """
        # Built by hand rather than with dataclasses.asdict(): asdict() can't
        # serialize Templates and ignores properties such as ``template``.
        serialized = {
            CONF_PLATFORM: self.platform,
            CONF_ENTITY_ID: self.entity_id,
            CONF_VALUE_TEMPLATE: self.template,
            CONF_TO_STATE: self.to_state,
            CONF_ABOVE: self.above,
            CONF_BELOW: self.below,
            CONF_P_GIVEN_T: self.prob_given_true,
            CONF_P_GIVEN_F: self.prob_given_false,
            "observed": self.observed,
        }
        return {key: value for key, value in serialized.items() if value is not None}

    def is_mirror(self, other: Observation) -> bool:
        """Detect whether the given observation is a mirror of this one.

        Two observations mirror each other when they share a platform and both
        their ``prob_given_true`` and their ``prob_given_false`` values add up
        to 1 (after rounding the sum to one decimal place).
        """
        if self.platform != other.platform:
            return False
        if round(self.prob_given_true + other.prob_given_true, 1) != 1:
            return False
        return round(self.prob_given_false + other.prob_given_false, 1) == 1

View File

@ -11,20 +11,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") ->
"""If there are mirrored entries, the user is probably using a workaround for a patched bug.""" """If there are mirrored entries, the user is probably using a workaround for a patched bug."""
if len(observations) != 2: if len(observations) != 2:
return return
true_sums_1: bool = ( if observations[0].is_mirror(observations[1]):
round(
observations[0]["prob_given_true"] + observations[1]["prob_given_true"], 1
)
== 1.0
)
false_sums_1: bool = (
round(
observations[0]["prob_given_false"] + observations[1]["prob_given_false"], 1
)
== 1.0
)
same_states: bool = observations[0]["platform"] == observations[1]["platform"]
if true_sums_1 & false_sums_1 & same_states:
issue_registry.async_create_issue( issue_registry.async_create_issue(
hass, hass,
DOMAIN, DOMAIN,