mirror of
https://github.com/home-assistant/core.git
synced 2025-04-27 02:37:50 +00:00
Refactor bayesian observations using dataclass (#79590)
* refactor * remove some changes * remove typehint * improve codestyle * move docstring to comment * < 88 chars * avoid short var names * more readable * fix rename * Update homeassistant/components/bayesian/helpers.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Update homeassistant/components/bayesian/binary_sensor.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Update homeassistant/components/bayesian/binary_sensor.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * no intermediate * comment why set before list Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
parent
56dd0a6867
commit
dd1463da28
@ -35,24 +35,24 @@ from homeassistant.helpers.template import result_as_boolean
|
|||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from . import DOMAIN, PLATFORMS
|
from . import DOMAIN, PLATFORMS
|
||||||
|
from .const import (
|
||||||
|
ATTR_OBSERVATIONS,
|
||||||
|
ATTR_OCCURRED_OBSERVATION_ENTITIES,
|
||||||
|
ATTR_PROBABILITY,
|
||||||
|
ATTR_PROBABILITY_THRESHOLD,
|
||||||
|
CONF_OBSERVATIONS,
|
||||||
|
CONF_P_GIVEN_F,
|
||||||
|
CONF_P_GIVEN_T,
|
||||||
|
CONF_PRIOR,
|
||||||
|
CONF_PROBABILITY_THRESHOLD,
|
||||||
|
CONF_TEMPLATE,
|
||||||
|
CONF_TO_STATE,
|
||||||
|
DEFAULT_NAME,
|
||||||
|
DEFAULT_PROBABILITY_THRESHOLD,
|
||||||
|
)
|
||||||
|
from .helpers import Observation
|
||||||
from .repairs import raise_mirrored_entries, raise_no_prob_given_false
|
from .repairs import raise_mirrored_entries, raise_no_prob_given_false
|
||||||
|
|
||||||
ATTR_OBSERVATIONS = "observations"
|
|
||||||
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
|
|
||||||
ATTR_PROBABILITY = "probability"
|
|
||||||
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
|
|
||||||
|
|
||||||
CONF_OBSERVATIONS = "observations"
|
|
||||||
CONF_PRIOR = "prior"
|
|
||||||
CONF_TEMPLATE = "template"
|
|
||||||
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
|
|
||||||
CONF_P_GIVEN_F = "prob_given_false"
|
|
||||||
CONF_P_GIVEN_T = "prob_given_true"
|
|
||||||
CONF_TO_STATE = "to_state"
|
|
||||||
|
|
||||||
DEFAULT_NAME = "Bayesian Binary Sensor"
|
|
||||||
DEFAULT_PROBABILITY_THRESHOLD = 0.5
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -156,7 +156,20 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
def __init__(self, name, prior, observations, probability_threshold, device_class):
|
def __init__(self, name, prior, observations, probability_threshold, device_class):
|
||||||
"""Initialize the Bayesian sensor."""
|
"""Initialize the Bayesian sensor."""
|
||||||
self._attr_name = name
|
self._attr_name = name
|
||||||
self._observations = observations
|
self._observations = [
|
||||||
|
Observation(
|
||||||
|
entity_id=observation.get(CONF_ENTITY_ID),
|
||||||
|
platform=observation[CONF_PLATFORM],
|
||||||
|
prob_given_false=observation[CONF_P_GIVEN_F],
|
||||||
|
prob_given_true=observation[CONF_P_GIVEN_T],
|
||||||
|
observed=None,
|
||||||
|
to_state=observation.get(CONF_TO_STATE),
|
||||||
|
above=observation.get(CONF_ABOVE),
|
||||||
|
below=observation.get(CONF_BELOW),
|
||||||
|
value_template=observation.get(CONF_VALUE_TEMPLATE),
|
||||||
|
)
|
||||||
|
for observation in observations
|
||||||
|
]
|
||||||
self._probability_threshold = probability_threshold
|
self._probability_threshold = probability_threshold
|
||||||
self._attr_device_class = device_class
|
self._attr_device_class = device_class
|
||||||
self._attr_is_on = False
|
self._attr_is_on = False
|
||||||
@ -230,13 +243,18 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
self.entity_id,
|
self.entity_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
observation = None
|
observed = None
|
||||||
else:
|
else:
|
||||||
observation = result_as_boolean(result)
|
observed = result_as_boolean(result)
|
||||||
|
|
||||||
for obs in self.observations_by_template[template]:
|
for observation in self.observations_by_template[template]:
|
||||||
obs_entry = {"entity_id": entity, "observation": observation, **obs}
|
observation.observed = observed
|
||||||
self.current_observations[obs["id"]] = obs_entry
|
|
||||||
|
# in some cases a template may update because of the absence of an entity
|
||||||
|
if entity is not None:
|
||||||
|
observation.entity_id = str(entity)
|
||||||
|
|
||||||
|
self.current_observations[observation.id] = observation
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
self.async_set_context(event.context)
|
self.async_set_context(event.context)
|
||||||
@ -270,7 +288,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
raise_mirrored_entries(
|
raise_mirrored_entries(
|
||||||
self.hass,
|
self.hass,
|
||||||
all_template_observations,
|
all_template_observations,
|
||||||
text=f"{self._attr_name}/{all_template_observations[0]['value_template']}",
|
text=f"{self._attr_name}/{all_template_observations[0].value_template}",
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -289,42 +307,38 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
def _record_entity_observations(self, entity):
|
def _record_entity_observations(self, entity):
|
||||||
local_observations = OrderedDict({})
|
local_observations = OrderedDict({})
|
||||||
|
|
||||||
for entity_obs in self.observations_by_entity[entity]:
|
for observation in self.observations_by_entity[entity]:
|
||||||
platform = entity_obs["platform"]
|
platform = observation.platform
|
||||||
|
|
||||||
observation = self.observation_handlers[platform](entity_obs)
|
observed = self.observation_handlers[platform](observation)
|
||||||
|
observation.observed = observed
|
||||||
|
|
||||||
obs_entry = {
|
local_observations[observation.id] = observation
|
||||||
"entity_id": entity,
|
|
||||||
"observation": observation,
|
|
||||||
**entity_obs,
|
|
||||||
}
|
|
||||||
local_observations[entity_obs["id"]] = obs_entry
|
|
||||||
|
|
||||||
return local_observations
|
return local_observations
|
||||||
|
|
||||||
def _calculate_new_probability(self):
|
def _calculate_new_probability(self):
|
||||||
prior = self.prior
|
prior = self.prior
|
||||||
|
|
||||||
for obs in self.current_observations.values():
|
for observation in self.current_observations.values():
|
||||||
if obs is not None:
|
if observation is not None:
|
||||||
if obs["observation"] is True:
|
if observation.observed is True:
|
||||||
prior = update_probability(
|
prior = update_probability(
|
||||||
prior,
|
prior,
|
||||||
obs["prob_given_true"],
|
observation.prob_given_true,
|
||||||
obs["prob_given_false"],
|
observation.prob_given_false,
|
||||||
)
|
)
|
||||||
elif obs["observation"] is False:
|
elif observation.observed is False:
|
||||||
prior = update_probability(
|
prior = update_probability(
|
||||||
prior,
|
prior,
|
||||||
1 - obs["prob_given_true"],
|
1 - observation.prob_given_true,
|
||||||
1 - obs["prob_given_false"],
|
1 - observation.prob_given_false,
|
||||||
)
|
)
|
||||||
elif obs["observation"] is None:
|
elif observation.observed is None:
|
||||||
if obs["entity_id"] is not None:
|
if observation.entity_id is not None:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
|
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
|
||||||
obs["entity_id"],
|
observation.entity_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
@ -338,8 +352,8 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
Build and return data structure of the form below.
|
Build and return data structure of the form below.
|
||||||
|
|
||||||
{
|
{
|
||||||
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
|
"sensor.sensor1": [Observation, Observation],
|
||||||
"sensor.sensor2": [{"id": 2, ...}],
|
"sensor.sensor2": [Observation],
|
||||||
...
|
...
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -347,21 +361,20 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
for all relevant observations to be looked up via their `entity_id`.
|
for all relevant observations to be looked up via their `entity_id`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
observations_by_entity: dict[str, list[OrderedDict]] = {}
|
observations_by_entity: dict[str, list[Observation]] = {}
|
||||||
for i, obs in enumerate(self._observations):
|
for observation in self._observations:
|
||||||
obs["id"] = i
|
|
||||||
|
|
||||||
if "entity_id" not in obs:
|
if (key := observation.entity_id) is None:
|
||||||
continue
|
continue
|
||||||
observations_by_entity.setdefault(obs["entity_id"], []).append(obs)
|
observations_by_entity.setdefault(key, []).append(observation)
|
||||||
|
|
||||||
for li_of_dicts in observations_by_entity.values():
|
for entity_observations in observations_by_entity.values():
|
||||||
if len(li_of_dicts) == 1:
|
if len(entity_observations) == 1:
|
||||||
continue
|
continue
|
||||||
for ord_dict in li_of_dicts:
|
for observation in entity_observations:
|
||||||
if ord_dict["platform"] != "state":
|
if observation.platform != "state":
|
||||||
continue
|
continue
|
||||||
ord_dict["platform"] = "multi_state"
|
observation.platform = "multi_state"
|
||||||
|
|
||||||
return observations_by_entity
|
return observations_by_entity
|
||||||
|
|
||||||
@ -370,8 +383,8 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
Build and return data structure of the form below.
|
Build and return data structure of the form below.
|
||||||
|
|
||||||
{
|
{
|
||||||
"template": [{"id": 0, ...}, {"id": 1, ...}],
|
"template": [Observation, Observation],
|
||||||
"template2": [{"id": 2, ...}],
|
"template2": [Observation],
|
||||||
...
|
...
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -380,20 +393,18 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
observations_by_template = {}
|
observations_by_template = {}
|
||||||
for ind, obs in enumerate(self._observations):
|
for observation in self._observations:
|
||||||
obs["id"] = ind
|
if observation.value_template is None:
|
||||||
|
|
||||||
if "value_template" not in obs:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
template = obs.get(CONF_VALUE_TEMPLATE)
|
template = observation.value_template
|
||||||
observations_by_template.setdefault(template, []).append(obs)
|
observations_by_template.setdefault(template, []).append(observation)
|
||||||
|
|
||||||
return observations_by_template
|
return observations_by_template
|
||||||
|
|
||||||
def _process_numeric_state(self, entity_observation):
|
def _process_numeric_state(self, entity_observation):
|
||||||
"""Return True if numeric condition is met, return False if not, return None otherwise."""
|
"""Return True if numeric condition is met, return False if not, return None otherwise."""
|
||||||
entity = entity_observation["entity_id"]
|
entity = entity_observation.entity_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
||||||
@ -401,61 +412,67 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||||||
return condition.async_numeric_state(
|
return condition.async_numeric_state(
|
||||||
self.hass,
|
self.hass,
|
||||||
entity,
|
entity,
|
||||||
entity_observation.get("below"),
|
entity_observation.below,
|
||||||
entity_observation.get("above"),
|
entity_observation.above,
|
||||||
None,
|
None,
|
||||||
entity_observation,
|
entity_observation.to_dict(),
|
||||||
)
|
)
|
||||||
except ConditionError:
|
except ConditionError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _process_state(self, entity_observation):
|
def _process_state(self, entity_observation):
|
||||||
"""Return True if state conditions are met."""
|
"""Return True if state conditions are met, return False if they are not.
|
||||||
entity = entity_observation["entity_id"]
|
|
||||||
|
Returns None if the state is unavailable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
entity = entity_observation.entity_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return condition.state(
|
return condition.state(self.hass, entity, entity_observation.to_state)
|
||||||
self.hass, entity, entity_observation.get("to_state")
|
|
||||||
)
|
|
||||||
except ConditionError:
|
except ConditionError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _process_multi_state(self, entity_observation):
|
def _process_multi_state(self, entity_observation):
|
||||||
"""Return True if state conditions are met."""
|
"""Return True if state conditions are met, otherwise return None.
|
||||||
entity = entity_observation["entity_id"]
|
|
||||||
|
Never return False as all other states should have their own probabilities configured.
|
||||||
|
"""
|
||||||
|
|
||||||
|
entity = entity_observation.entity_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if condition.state(self.hass, entity, entity_observation.get("to_state")):
|
if condition.state(self.hass, entity, entity_observation.to_state):
|
||||||
return True
|
return True
|
||||||
except ConditionError:
|
except ConditionError:
|
||||||
return None
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extra_state_attributes(self):
|
def extra_state_attributes(self):
|
||||||
"""Return the state attributes of the sensor."""
|
"""Return the state attributes of the sensor."""
|
||||||
attr_observations_list = [
|
|
||||||
obs.copy() for obs in self.current_observations.values() if obs is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
for item in attr_observations_list:
|
|
||||||
item.pop("value_template", None)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
ATTR_OBSERVATIONS: attr_observations_list,
|
|
||||||
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
|
|
||||||
{
|
|
||||||
obs.get("entity_id")
|
|
||||||
for obs in self.current_observations.values()
|
|
||||||
if obs is not None
|
|
||||||
and obs.get("entity_id") is not None
|
|
||||||
and obs.get("observation") is not None
|
|
||||||
}
|
|
||||||
),
|
|
||||||
ATTR_PROBABILITY: round(self.probability, 2),
|
ATTR_PROBABILITY: round(self.probability, 2),
|
||||||
ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
|
ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
|
||||||
|
# An entity can be in more than one observation so set then list to deduplicate
|
||||||
|
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
|
||||||
|
{
|
||||||
|
observation.entity_id
|
||||||
|
for observation in self.current_observations.values()
|
||||||
|
if observation is not None
|
||||||
|
and observation.entity_id is not None
|
||||||
|
and observation.observed is not None
|
||||||
|
}
|
||||||
|
),
|
||||||
|
ATTR_OBSERVATIONS: [
|
||||||
|
observation.to_dict()
|
||||||
|
for observation in self.current_observations.values()
|
||||||
|
if observation is not None
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def async_update(self) -> None:
|
async def async_update(self) -> None:
|
||||||
|
17
homeassistant/components/bayesian/const.py
Normal file
17
homeassistant/components/bayesian/const.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""Consts for using in modules."""
|
||||||
|
|
||||||
|
ATTR_OBSERVATIONS = "observations"
|
||||||
|
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
|
||||||
|
ATTR_PROBABILITY = "probability"
|
||||||
|
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
|
||||||
|
|
||||||
|
CONF_OBSERVATIONS = "observations"
|
||||||
|
CONF_PRIOR = "prior"
|
||||||
|
CONF_TEMPLATE = "template"
|
||||||
|
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
|
||||||
|
CONF_P_GIVEN_F = "prob_given_false"
|
||||||
|
CONF_P_GIVEN_T = "prob_given_true"
|
||||||
|
CONF_TO_STATE = "to_state"
|
||||||
|
|
||||||
|
DEFAULT_NAME = "Bayesian Binary Sensor"
|
||||||
|
DEFAULT_PROBABILITY_THRESHOLD = 0.5
|
69
homeassistant/components/bayesian/helpers.py
Normal file
69
homeassistant/components/bayesian/helpers.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
"""Helpers to deal with bayesian observations."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from homeassistant.const import (
|
||||||
|
CONF_ABOVE,
|
||||||
|
CONF_BELOW,
|
||||||
|
CONF_ENTITY_ID,
|
||||||
|
CONF_PLATFORM,
|
||||||
|
CONF_VALUE_TEMPLATE,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers.template import Template
|
||||||
|
|
||||||
|
from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Observation:
|
||||||
|
"""Representation of a sensor or template observation."""
|
||||||
|
|
||||||
|
entity_id: str | None
|
||||||
|
platform: str
|
||||||
|
prob_given_true: float
|
||||||
|
prob_given_false: float
|
||||||
|
to_state: str | None
|
||||||
|
above: float | None
|
||||||
|
below: float | None
|
||||||
|
value_template: Template | None
|
||||||
|
observed: bool | None = None
|
||||||
|
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, str | float | bool | None]:
|
||||||
|
"""Represent Class as a Dict for easier serialization."""
|
||||||
|
|
||||||
|
# Needed because dataclasses asdict() can't serialize Templates and ignores Properties.
|
||||||
|
dic = {
|
||||||
|
CONF_PLATFORM: self.platform,
|
||||||
|
CONF_ENTITY_ID: self.entity_id,
|
||||||
|
CONF_VALUE_TEMPLATE: self.template,
|
||||||
|
CONF_TO_STATE: self.to_state,
|
||||||
|
CONF_ABOVE: self.above,
|
||||||
|
CONF_BELOW: self.below,
|
||||||
|
CONF_P_GIVEN_T: self.prob_given_true,
|
||||||
|
CONF_P_GIVEN_F: self.prob_given_false,
|
||||||
|
"observed": self.observed,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value in dic.copy().items():
|
||||||
|
if value is None:
|
||||||
|
del dic[key]
|
||||||
|
|
||||||
|
return dic
|
||||||
|
|
||||||
|
def is_mirror(self, other: Observation) -> bool:
|
||||||
|
"""Dectects whether given observation is a mirror of this one."""
|
||||||
|
return (
|
||||||
|
self.platform == other.platform
|
||||||
|
and round(self.prob_given_true + other.prob_given_true, 1) == 1
|
||||||
|
and round(self.prob_given_false + other.prob_given_false, 1) == 1
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def template(self) -> str | None:
|
||||||
|
"""Not all observations have templates and we want to get template strings."""
|
||||||
|
if self.value_template is not None:
|
||||||
|
return self.value_template.template
|
||||||
|
return None
|
@ -11,20 +11,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") ->
|
|||||||
"""If there are mirrored entries, the user is probably using a workaround for a patched bug."""
|
"""If there are mirrored entries, the user is probably using a workaround for a patched bug."""
|
||||||
if len(observations) != 2:
|
if len(observations) != 2:
|
||||||
return
|
return
|
||||||
true_sums_1: bool = (
|
if observations[0].is_mirror(observations[1]):
|
||||||
round(
|
|
||||||
observations[0]["prob_given_true"] + observations[1]["prob_given_true"], 1
|
|
||||||
)
|
|
||||||
== 1.0
|
|
||||||
)
|
|
||||||
false_sums_1: bool = (
|
|
||||||
round(
|
|
||||||
observations[0]["prob_given_false"] + observations[1]["prob_given_false"], 1
|
|
||||||
)
|
|
||||||
== 1.0
|
|
||||||
)
|
|
||||||
same_states: bool = observations[0]["platform"] == observations[1]["platform"]
|
|
||||||
if true_sums_1 & false_sums_1 & same_states:
|
|
||||||
issue_registry.async_create_issue(
|
issue_registry.async_create_issue(
|
||||||
hass,
|
hass,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user