mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Best effort state initialization of bayesian binary sensor (#30962)
* Best effort state initialization of bayesian binary sensor. Why: * https://github.com/home-assistant/home-assistant/issues/30119 This change addresses the need by: * Running the main update logic eagerly for each entity being observed on `async_added_to_hass`. * Test of the new behavior. * Refactor in order to reduce number of methods with side effects that mutate instance attributes. * Improve test coverage Why: * Because for some reason my commits decreased test coverage. This change addresses the need by: * Adding coverage for the case where a device returns `STATE_UNKNOWN` * Adding coverage for configurations with templates * rebase and ensure upstream tests passed * Delete commented code from addressing merge conflict.
This commit is contained in:
parent
f2f03cf552
commit
dd1608db0d
@ -1,6 +1,5 @@
|
||||
"""Use Bayesian Inference to trigger a binary sensor."""
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -88,10 +87,10 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||
)
|
||||
|
||||
|
||||
def update_probability(prior, prob_true, prob_false):
|
||||
def update_probability(prior, prob_given_true, prob_given_false):
|
||||
"""Update probability using Bayes' rule."""
|
||||
numerator = prob_true * prior
|
||||
denominator = numerator + prob_false * (1 - prior)
|
||||
numerator = prob_given_true * prior
|
||||
denominator = numerator + prob_given_false * (1 - prior)
|
||||
probability = numerator / denominator
|
||||
return probability
|
||||
|
||||
@ -127,84 +126,124 @@ class BayesianBinarySensor(BinarySensorDevice):
|
||||
self.prior = prior
|
||||
self.probability = prior
|
||||
|
||||
self.current_obs = OrderedDict({})
|
||||
self.entity_obs_dict = []
|
||||
self.current_observations = OrderedDict({})
|
||||
|
||||
for obs in self._observations:
|
||||
if "entity_id" in obs:
|
||||
self.entity_obs_dict.append([obs.get("entity_id")])
|
||||
if "value_template" in obs:
|
||||
self.entity_obs_dict.append(
|
||||
list(obs.get(CONF_VALUE_TEMPLATE).extract_entities())
|
||||
)
|
||||
self.observations_by_entity = self._build_observations_by_entity()
|
||||
|
||||
to_observe = set()
|
||||
for obs in self._observations:
|
||||
if "entity_id" in obs:
|
||||
to_observe.update(set([obs.get("entity_id")]))
|
||||
if "value_template" in obs:
|
||||
to_observe.update(set(obs.get(CONF_VALUE_TEMPLATE).extract_entities()))
|
||||
self.entity_obs = {key: [] for key in to_observe}
|
||||
|
||||
for ind, obs in enumerate(self._observations):
|
||||
obs["id"] = ind
|
||||
if "entity_id" in obs:
|
||||
self.entity_obs[obs["entity_id"]].append(obs)
|
||||
if "value_template" in obs:
|
||||
for ent in obs.get(CONF_VALUE_TEMPLATE).extract_entities():
|
||||
self.entity_obs[ent].append(obs)
|
||||
|
||||
self.watchers = {
|
||||
self.observation_handlers = {
|
||||
"numeric_state": self._process_numeric_state,
|
||||
"state": self._process_state,
|
||||
"template": self._process_template,
|
||||
}
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Call when entity about to be added."""
|
||||
"""
|
||||
Call when entity about to be added.
|
||||
|
||||
All relevant update logic for instance attributes occurs within this closure.
|
||||
Other methods in this class are designed to avoid directly modifying instance
|
||||
attributes, by instead focusing on returning relevant data back to this method.
|
||||
|
||||
The goal of this method is to ensure that `self.current_observations` and `self.probability`
|
||||
are set on a best-effort basis when this entity is register with hass.
|
||||
|
||||
In addition, this method must register the state listener defined within, which
|
||||
will be called any time a relevant entity changes its state.
|
||||
"""
|
||||
|
||||
@callback
|
||||
def async_threshold_sensor_state_listener(entity, old_state, new_state):
|
||||
"""Handle sensor state changes."""
|
||||
def async_threshold_sensor_state_listener(entity, _old_state, new_state):
|
||||
"""
|
||||
Handle sensor state changes.
|
||||
|
||||
When a state changes, we must update our list of current observations,
|
||||
then calculate the new probability.
|
||||
"""
|
||||
if new_state.state == STATE_UNKNOWN:
|
||||
return
|
||||
|
||||
entity_obs_list = self.entity_obs[entity]
|
||||
|
||||
for entity_obs in entity_obs_list:
|
||||
platform = entity_obs["platform"]
|
||||
|
||||
self.watchers[platform](entity_obs)
|
||||
|
||||
prior = self.prior
|
||||
for obs in self.current_obs.values():
|
||||
prior = update_probability(prior, obs["prob_true"], obs["prob_false"])
|
||||
self.probability = prior
|
||||
self.current_observations.update(self._record_entity_observations(entity))
|
||||
self.probability = self._calculate_new_probability()
|
||||
|
||||
self.hass.async_add_job(self.async_update_ha_state, True)
|
||||
|
||||
self.current_observations.update(self._initialize_current_observations())
|
||||
self.probability = self._calculate_new_probability()
|
||||
async_track_state_change(
|
||||
self.hass, self.entity_obs, async_threshold_sensor_state_listener
|
||||
self.hass,
|
||||
self.observations_by_entity,
|
||||
async_threshold_sensor_state_listener,
|
||||
)
|
||||
|
||||
def _update_current_obs(self, entity_observation, should_trigger):
|
||||
"""Update current observation."""
|
||||
obs_id = entity_observation["id"]
|
||||
def _initialize_current_observations(self):
|
||||
local_observations = OrderedDict({})
|
||||
for entity in self.observations_by_entity:
|
||||
local_observations.update(self._record_entity_observations(entity))
|
||||
return local_observations
|
||||
|
||||
if should_trigger:
|
||||
prob_true = entity_observation["prob_given_true"]
|
||||
prob_false = entity_observation.get("prob_given_false", 1 - prob_true)
|
||||
def _record_entity_observations(self, entity):
|
||||
local_observations = OrderedDict({})
|
||||
entity_obs_list = self.observations_by_entity[entity]
|
||||
|
||||
self.current_obs[obs_id] = {
|
||||
"prob_true": prob_true,
|
||||
"prob_false": prob_false,
|
||||
}
|
||||
for entity_obs in entity_obs_list:
|
||||
platform = entity_obs["platform"]
|
||||
|
||||
else:
|
||||
self.current_obs.pop(obs_id, None)
|
||||
should_trigger = self.observation_handlers[platform](entity_obs)
|
||||
|
||||
if should_trigger:
|
||||
obs_entry = {"entity_id": entity, **entity_obs}
|
||||
else:
|
||||
obs_entry = None
|
||||
|
||||
local_observations[entity_obs["id"]] = obs_entry
|
||||
|
||||
return local_observations
|
||||
|
||||
def _calculate_new_probability(self):
|
||||
prior = self.prior
|
||||
|
||||
for obs in self.current_observations.values():
|
||||
if obs is not None:
|
||||
prior = update_probability(
|
||||
prior,
|
||||
obs["prob_given_true"],
|
||||
obs.get("prob_given_false", 1 - obs["prob_given_true"]),
|
||||
)
|
||||
|
||||
return prior
|
||||
|
||||
def _build_observations_by_entity(self):
|
||||
"""
|
||||
Build and return data structure of the form below.
|
||||
|
||||
{
|
||||
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
|
||||
"sensor.sensor2": [{"id": 2, ...}],
|
||||
...
|
||||
}
|
||||
|
||||
Each "observation" must be recognized uniquely, and it should be possible
|
||||
for all relevant observations to be looked up via their `entity_id`.
|
||||
"""
|
||||
|
||||
observations_by_entity = {}
|
||||
for ind, obs in enumerate(self._observations):
|
||||
obs["id"] = ind
|
||||
|
||||
if "entity_id" in obs:
|
||||
entity_ids = [obs["entity_id"]]
|
||||
elif "value_template" in obs:
|
||||
entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()
|
||||
|
||||
for e_id in entity_ids:
|
||||
obs_list = observations_by_entity.get(e_id, [])
|
||||
obs_list.append(obs)
|
||||
observations_by_entity[e_id] = obs_list
|
||||
|
||||
return observations_by_entity
|
||||
|
||||
def _process_numeric_state(self, entity_observation):
|
||||
"""Add entity to current_obs if numeric state conditions are met."""
|
||||
"""Return True if numeric condition is met."""
|
||||
entity = entity_observation["entity_id"]
|
||||
|
||||
should_trigger = condition.async_numeric_state(
|
||||
@ -215,27 +254,26 @@ class BayesianBinarySensor(BinarySensorDevice):
|
||||
None,
|
||||
entity_observation,
|
||||
)
|
||||
|
||||
self._update_current_obs(entity_observation, should_trigger)
|
||||
return should_trigger
|
||||
|
||||
def _process_state(self, entity_observation):
|
||||
"""Add entity to current observations if state conditions are met."""
|
||||
"""Return True if state conditions are met."""
|
||||
entity = entity_observation["entity_id"]
|
||||
|
||||
should_trigger = condition.state(
|
||||
self.hass, entity, entity_observation.get("to_state")
|
||||
)
|
||||
|
||||
self._update_current_obs(entity_observation, should_trigger)
|
||||
return should_trigger
|
||||
|
||||
def _process_template(self, entity_observation):
|
||||
"""Add entity to current_obs if template is true."""
|
||||
"""Return True if template condition is True."""
|
||||
template = entity_observation.get(CONF_VALUE_TEMPLATE)
|
||||
template.hass = self.hass
|
||||
should_trigger = condition.async_template(
|
||||
self.hass, template, entity_observation
|
||||
)
|
||||
self._update_current_obs(entity_observation, should_trigger)
|
||||
return should_trigger
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -260,13 +298,15 @@ class BayesianBinarySensor(BinarySensorDevice):
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
"""Return the state attributes of the sensor."""
|
||||
print(self.current_observations)
|
||||
print(self.observations_by_entity)
|
||||
return {
|
||||
ATTR_OBSERVATIONS: list(self.current_obs.values()),
|
||||
ATTR_OBSERVATIONS: list(self.current_observations.values()),
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
|
||||
set(
|
||||
chain.from_iterable(
|
||||
self.entity_obs_dict[obs] for obs in self.current_obs.keys()
|
||||
)
|
||||
obs.get("entity_id")
|
||||
for obs in self.current_observations.values()
|
||||
if obs is not None
|
||||
)
|
||||
),
|
||||
ATTR_PROBABILITY: round(self.probability, 2),
|
||||
|
@ -2,6 +2,7 @@
|
||||
import unittest
|
||||
|
||||
from homeassistant.components.bayesian import binary_sensor as bayesian
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.setup import setup_component
|
||||
|
||||
from tests.common import get_test_home_assistant
|
||||
@ -18,6 +19,65 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
"""Stop everything that was started."""
|
||||
self.hass.stop()
|
||||
|
||||
def test_load_values_when_added_to_hass(self):
|
||||
"""Test that sensor initializes with observations of relevant entities."""
|
||||
|
||||
config = {
|
||||
"binary_sensor": {
|
||||
"name": "Test_Binary",
|
||||
"platform": "bayesian",
|
||||
"observations": [
|
||||
{
|
||||
"platform": "state",
|
||||
"entity_id": "sensor.test_monitored",
|
||||
"to_state": "off",
|
||||
"prob_given_true": 0.8,
|
||||
"prob_given_false": 0.4,
|
||||
}
|
||||
],
|
||||
"prior": 0.2,
|
||||
"probability_threshold": 0.32,
|
||||
}
|
||||
}
|
||||
|
||||
self.hass.states.set("sensor.test_monitored", "off")
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert setup_component(self.hass, "binary_sensor", config)
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
|
||||
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
|
||||
|
||||
def test_unknown_state_does_not_influence_probability(self):
|
||||
"""Test that an unknown state does not change the output probability."""
|
||||
|
||||
config = {
|
||||
"binary_sensor": {
|
||||
"name": "Test_Binary",
|
||||
"platform": "bayesian",
|
||||
"observations": [
|
||||
{
|
||||
"platform": "state",
|
||||
"entity_id": "sensor.test_monitored",
|
||||
"to_state": "off",
|
||||
"prob_given_true": 0.8,
|
||||
"prob_given_false": 0.4,
|
||||
}
|
||||
],
|
||||
"prior": 0.2,
|
||||
"probability_threshold": 0.32,
|
||||
}
|
||||
}
|
||||
|
||||
self.hass.states.set("sensor.test_monitored", STATE_UNKNOWN)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert setup_component(self.hass, "binary_sensor", config)
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
assert state.attributes.get("observations") == [None]
|
||||
|
||||
def test_sensor_numeric_state(self):
|
||||
"""Test sensor on numeric state platform observations."""
|
||||
config = {
|
||||
@ -52,7 +112,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
|
||||
assert [] == state.attributes.get("observations")
|
||||
assert [None, None] == state.attributes.get("observations")
|
||||
assert 0.2 == state.attributes.get("probability")
|
||||
|
||||
assert state.state == "off"
|
||||
@ -66,10 +126,9 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
assert [
|
||||
{"prob_false": 0.4, "prob_true": 0.6},
|
||||
{"prob_false": 0.1, "prob_true": 0.9},
|
||||
] == state.attributes.get("observations")
|
||||
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.6
|
||||
assert state.attributes.get("observations")[1]["prob_given_true"] == 0.9
|
||||
assert state.attributes.get("observations")[1]["prob_given_false"] == 0.1
|
||||
assert round(abs(0.77 - state.attributes.get("probability")), 7) == 0
|
||||
|
||||
assert state.state == "on"
|
||||
@ -118,7 +177,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
|
||||
assert [] == state.attributes.get("observations")
|
||||
assert [None] == state.attributes.get("observations")
|
||||
assert 0.2 == state.attributes.get("probability")
|
||||
|
||||
assert state.state == "off"
|
||||
@ -131,9 +190,62 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
assert [{"prob_true": 0.8, "prob_false": 0.4}] == state.attributes.get(
|
||||
"observations"
|
||||
)
|
||||
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
|
||||
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
|
||||
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
|
||||
|
||||
assert state.state == "on"
|
||||
|
||||
self.hass.states.set("sensor.test_monitored", "off")
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set("sensor.test_monitored", "on")
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
assert round(abs(0.2 - state.attributes.get("probability")), 7) == 0
|
||||
|
||||
assert state.state == "off"
|
||||
|
||||
def test_sensor_value_template(self):
|
||||
"""Test sensor on template platform observations."""
|
||||
config = {
|
||||
"binary_sensor": {
|
||||
"name": "Test_Binary",
|
||||
"platform": "bayesian",
|
||||
"observations": [
|
||||
{
|
||||
"platform": "template",
|
||||
"value_template": "{{states('sensor.test_monitored') == 'off'}}",
|
||||
"prob_given_true": 0.8,
|
||||
"prob_given_false": 0.4,
|
||||
}
|
||||
],
|
||||
"prior": 0.2,
|
||||
"probability_threshold": 0.32,
|
||||
}
|
||||
}
|
||||
|
||||
assert setup_component(self.hass, "binary_sensor", config)
|
||||
|
||||
self.hass.states.set("sensor.test_monitored", "on")
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
|
||||
assert [None] == state.attributes.get("observations")
|
||||
assert 0.2 == state.attributes.get("probability")
|
||||
|
||||
assert state.state == "off"
|
||||
|
||||
self.hass.states.set("sensor.test_monitored", "off")
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set("sensor.test_monitored", "on")
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set("sensor.test_monitored", "off")
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
|
||||
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
|
||||
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
|
||||
|
||||
assert state.state == "on"
|
||||
@ -210,7 +322,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
|
||||
assert [] == state.attributes.get("observations")
|
||||
assert [None, None] == state.attributes.get("observations")
|
||||
assert 0.2 == state.attributes.get("probability")
|
||||
|
||||
assert state.state == "off"
|
||||
@ -223,9 +335,9 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get("binary_sensor.test_binary")
|
||||
assert [{"prob_true": 0.8, "prob_false": 0.4}] == state.attributes.get(
|
||||
"observations"
|
||||
)
|
||||
|
||||
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
|
||||
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
|
||||
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
|
||||
|
||||
assert state.state == "on"
|
||||
@ -242,20 +354,20 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
|
||||
def test_probability_updates(self):
|
||||
"""Test probability update function."""
|
||||
prob_true = [0.3, 0.6, 0.8]
|
||||
prob_false = [0.7, 0.4, 0.2]
|
||||
prob_given_true = [0.3, 0.6, 0.8]
|
||||
prob_given_false = [0.7, 0.4, 0.2]
|
||||
prior = 0.5
|
||||
|
||||
for pt, pf in zip(prob_true, prob_false):
|
||||
for pt, pf in zip(prob_given_true, prob_given_false):
|
||||
prior = bayesian.update_probability(prior, pt, pf)
|
||||
|
||||
assert round(abs(0.720000 - prior), 7) == 0
|
||||
|
||||
prob_true = [0.8, 0.3, 0.9]
|
||||
prob_false = [0.6, 0.4, 0.2]
|
||||
prob_given_true = [0.8, 0.3, 0.9]
|
||||
prob_given_false = [0.6, 0.4, 0.2]
|
||||
prior = 0.7
|
||||
|
||||
for pt, pf in zip(prob_true, prob_false):
|
||||
for pt, pf in zip(prob_given_true, prob_given_false):
|
||||
prior = bayesian.update_probability(prior, pt, pf)
|
||||
|
||||
assert round(abs(0.9130434782608695 - prior), 7) == 0
|
||||
@ -271,7 +383,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||
"platform": "state",
|
||||
"entity_id": "sensor.test_monitored",
|
||||
"to_state": "off",
|
||||
"prob_given_true": 0.8,
|
||||
"prob_given_true": 0.9,
|
||||
"prob_given_false": 0.4,
|
||||
},
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user