Best effort state initialization of bayesian binary sensor (#30962)

* Best effort state initialization of bayesian binary sensor.

Why:

* https://github.com/home-assistant/home-assistant/issues/30119

This change addresses the need by:

* Running the main update logic eagerly for each entity being observed
  on `async_added_to_hass`.
* Test of the new behavior.

* Refactor in order to reduce number of methods with side effects that
mutate instance attributes.

* Improve test coverage

Why:

* Because for some reason my commits decreased test coverage.

This change addresses the need by:

* Adding coverage for the case where a device returns `STATE_UNKNOWN`
* Adding coverage for configurations with templates

* rebase and ensure upstream tests passed

* Delete commented code from addressing merge conflict.
This commit is contained in:
Jeff McGehee 2020-03-31 12:41:29 -04:00 committed by GitHub
parent f2f03cf552
commit dd1608db0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 240 additions and 88 deletions

View File

@ -1,6 +1,5 @@
"""Use Bayesian Inference to trigger a binary sensor.""" """Use Bayesian Inference to trigger a binary sensor."""
from collections import OrderedDict from collections import OrderedDict
from itertools import chain
import voluptuous as vol import voluptuous as vol
@ -88,10 +87,10 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
) )
def update_probability(prior, prob_true, prob_false): def update_probability(prior, prob_given_true, prob_given_false):
"""Update probability using Bayes' rule.""" """Update probability using Bayes' rule."""
numerator = prob_true * prior numerator = prob_given_true * prior
denominator = numerator + prob_false * (1 - prior) denominator = numerator + prob_given_false * (1 - prior)
probability = numerator / denominator probability = numerator / denominator
return probability return probability
@ -127,84 +126,124 @@ class BayesianBinarySensor(BinarySensorDevice):
self.prior = prior self.prior = prior
self.probability = prior self.probability = prior
self.current_obs = OrderedDict({}) self.current_observations = OrderedDict({})
self.entity_obs_dict = []
for obs in self._observations: self.observations_by_entity = self._build_observations_by_entity()
if "entity_id" in obs:
self.entity_obs_dict.append([obs.get("entity_id")])
if "value_template" in obs:
self.entity_obs_dict.append(
list(obs.get(CONF_VALUE_TEMPLATE).extract_entities())
)
to_observe = set() self.observation_handlers = {
for obs in self._observations:
if "entity_id" in obs:
to_observe.update(set([obs.get("entity_id")]))
if "value_template" in obs:
to_observe.update(set(obs.get(CONF_VALUE_TEMPLATE).extract_entities()))
self.entity_obs = {key: [] for key in to_observe}
for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "entity_id" in obs:
self.entity_obs[obs["entity_id"]].append(obs)
if "value_template" in obs:
for ent in obs.get(CONF_VALUE_TEMPLATE).extract_entities():
self.entity_obs[ent].append(obs)
self.watchers = {
"numeric_state": self._process_numeric_state, "numeric_state": self._process_numeric_state,
"state": self._process_state, "state": self._process_state,
"template": self._process_template, "template": self._process_template,
} }
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Call when entity about to be added.""" """
Call when entity about to be added.
All relevant update logic for instance attributes occurs within this closure.
Other methods in this class are designed to avoid directly modifying instance
attributes, by instead focusing on returning relevant data back to this method.
The goal of this method is to ensure that `self.current_observations` and `self.probability`
are set on a best-effort basis when this entity is register with hass.
In addition, this method must register the state listener defined within, which
will be called any time a relevant entity changes its state.
"""
@callback @callback
def async_threshold_sensor_state_listener(entity, old_state, new_state): def async_threshold_sensor_state_listener(entity, _old_state, new_state):
"""Handle sensor state changes.""" """
Handle sensor state changes.
When a state changes, we must update our list of current observations,
then calculate the new probability.
"""
if new_state.state == STATE_UNKNOWN: if new_state.state == STATE_UNKNOWN:
return return
entity_obs_list = self.entity_obs[entity] self.current_observations.update(self._record_entity_observations(entity))
self.probability = self._calculate_new_probability()
for entity_obs in entity_obs_list:
platform = entity_obs["platform"]
self.watchers[platform](entity_obs)
prior = self.prior
for obs in self.current_obs.values():
prior = update_probability(prior, obs["prob_true"], obs["prob_false"])
self.probability = prior
self.hass.async_add_job(self.async_update_ha_state, True) self.hass.async_add_job(self.async_update_ha_state, True)
self.current_observations.update(self._initialize_current_observations())
self.probability = self._calculate_new_probability()
async_track_state_change( async_track_state_change(
self.hass, self.entity_obs, async_threshold_sensor_state_listener self.hass,
self.observations_by_entity,
async_threshold_sensor_state_listener,
) )
def _update_current_obs(self, entity_observation, should_trigger): def _initialize_current_observations(self):
"""Update current observation.""" local_observations = OrderedDict({})
obs_id = entity_observation["id"] for entity in self.observations_by_entity:
local_observations.update(self._record_entity_observations(entity))
return local_observations
if should_trigger: def _record_entity_observations(self, entity):
prob_true = entity_observation["prob_given_true"] local_observations = OrderedDict({})
prob_false = entity_observation.get("prob_given_false", 1 - prob_true) entity_obs_list = self.observations_by_entity[entity]
self.current_obs[obs_id] = { for entity_obs in entity_obs_list:
"prob_true": prob_true, platform = entity_obs["platform"]
"prob_false": prob_false,
}
else: should_trigger = self.observation_handlers[platform](entity_obs)
self.current_obs.pop(obs_id, None)
if should_trigger:
obs_entry = {"entity_id": entity, **entity_obs}
else:
obs_entry = None
local_observations[entity_obs["id"]] = obs_entry
return local_observations
def _calculate_new_probability(self):
prior = self.prior
for obs in self.current_observations.values():
if obs is not None:
prior = update_probability(
prior,
obs["prob_given_true"],
obs.get("prob_given_false", 1 - obs["prob_given_true"]),
)
return prior
def _build_observations_by_entity(self):
"""
Build and return data structure of the form below.
{
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
"sensor.sensor2": [{"id": 2, ...}],
...
}
Each "observation" must be recognized uniquely, and it should be possible
for all relevant observations to be looked up via their `entity_id`.
"""
observations_by_entity = {}
for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "entity_id" in obs:
entity_ids = [obs["entity_id"]]
elif "value_template" in obs:
entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()
for e_id in entity_ids:
obs_list = observations_by_entity.get(e_id, [])
obs_list.append(obs)
observations_by_entity[e_id] = obs_list
return observations_by_entity
def _process_numeric_state(self, entity_observation): def _process_numeric_state(self, entity_observation):
"""Add entity to current_obs if numeric state conditions are met.""" """Return True if numeric condition is met."""
entity = entity_observation["entity_id"] entity = entity_observation["entity_id"]
should_trigger = condition.async_numeric_state( should_trigger = condition.async_numeric_state(
@ -215,27 +254,26 @@ class BayesianBinarySensor(BinarySensorDevice):
None, None,
entity_observation, entity_observation,
) )
return should_trigger
self._update_current_obs(entity_observation, should_trigger)
def _process_state(self, entity_observation): def _process_state(self, entity_observation):
"""Add entity to current observations if state conditions are met.""" """Return True if state conditions are met."""
entity = entity_observation["entity_id"] entity = entity_observation["entity_id"]
should_trigger = condition.state( should_trigger = condition.state(
self.hass, entity, entity_observation.get("to_state") self.hass, entity, entity_observation.get("to_state")
) )
self._update_current_obs(entity_observation, should_trigger) return should_trigger
def _process_template(self, entity_observation): def _process_template(self, entity_observation):
"""Add entity to current_obs if template is true.""" """Return True if template condition is True."""
template = entity_observation.get(CONF_VALUE_TEMPLATE) template = entity_observation.get(CONF_VALUE_TEMPLATE)
template.hass = self.hass template.hass = self.hass
should_trigger = condition.async_template( should_trigger = condition.async_template(
self.hass, template, entity_observation self.hass, template, entity_observation
) )
self._update_current_obs(entity_observation, should_trigger) return should_trigger
@property @property
def name(self): def name(self):
@ -260,13 +298,15 @@ class BayesianBinarySensor(BinarySensorDevice):
@property @property
def device_state_attributes(self): def device_state_attributes(self):
"""Return the state attributes of the sensor.""" """Return the state attributes of the sensor."""
print(self.current_observations)
print(self.observations_by_entity)
return { return {
ATTR_OBSERVATIONS: list(self.current_obs.values()), ATTR_OBSERVATIONS: list(self.current_observations.values()),
ATTR_OCCURRED_OBSERVATION_ENTITIES: list( ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
set( set(
chain.from_iterable( obs.get("entity_id")
self.entity_obs_dict[obs] for obs in self.current_obs.keys() for obs in self.current_observations.values()
) if obs is not None
) )
), ),
ATTR_PROBABILITY: round(self.probability, 2), ATTR_PROBABILITY: round(self.probability, 2),

View File

@ -2,6 +2,7 @@
import unittest import unittest
from homeassistant.components.bayesian import binary_sensor as bayesian from homeassistant.components.bayesian import binary_sensor as bayesian
from homeassistant.const import STATE_UNKNOWN
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
from tests.common import get_test_home_assistant from tests.common import get_test_home_assistant
@ -18,6 +19,65 @@ class TestBayesianBinarySensor(unittest.TestCase):
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop() self.hass.stop()
def test_load_values_when_added_to_hass(self):
"""Test that sensor initializes with observations of relevant entities."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "state",
"entity_id": "sensor.test_monitored",
"to_state": "off",
"prob_given_true": 0.8,
"prob_given_false": 0.4,
}
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
assert setup_component(self.hass, "binary_sensor", config)
state = self.hass.states.get("binary_sensor.test_binary")
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
def test_unknown_state_does_not_influence_probability(self):
"""Test that an unknown state does not change the output probability."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "state",
"entity_id": "sensor.test_monitored",
"to_state": "off",
"prob_given_true": 0.8,
"prob_given_false": 0.4,
}
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
self.hass.states.set("sensor.test_monitored", STATE_UNKNOWN)
self.hass.block_till_done()
assert setup_component(self.hass, "binary_sensor", config)
state = self.hass.states.get("binary_sensor.test_binary")
assert state.attributes.get("observations") == [None]
def test_sensor_numeric_state(self): def test_sensor_numeric_state(self):
"""Test sensor on numeric state platform observations.""" """Test sensor on numeric state platform observations."""
config = { config = {
@ -52,7 +112,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
state = self.hass.states.get("binary_sensor.test_binary") state = self.hass.states.get("binary_sensor.test_binary")
assert [] == state.attributes.get("observations") assert [None, None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability") assert 0.2 == state.attributes.get("probability")
assert state.state == "off" assert state.state == "off"
@ -66,10 +126,9 @@ class TestBayesianBinarySensor(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary") state = self.hass.states.get("binary_sensor.test_binary")
assert [ assert state.attributes.get("observations")[0]["prob_given_true"] == 0.6
{"prob_false": 0.4, "prob_true": 0.6}, assert state.attributes.get("observations")[1]["prob_given_true"] == 0.9
{"prob_false": 0.1, "prob_true": 0.9}, assert state.attributes.get("observations")[1]["prob_given_false"] == 0.1
] == state.attributes.get("observations")
assert round(abs(0.77 - state.attributes.get("probability")), 7) == 0 assert round(abs(0.77 - state.attributes.get("probability")), 7) == 0
assert state.state == "on" assert state.state == "on"
@ -118,7 +177,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
state = self.hass.states.get("binary_sensor.test_binary") state = self.hass.states.get("binary_sensor.test_binary")
assert [] == state.attributes.get("observations") assert [None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability") assert 0.2 == state.attributes.get("probability")
assert state.state == "off" assert state.state == "off"
@ -131,9 +190,62 @@ class TestBayesianBinarySensor(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary") state = self.hass.states.get("binary_sensor.test_binary")
assert [{"prob_true": 0.8, "prob_false": 0.4}] == state.attributes.get( assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
"observations" assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
) assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
assert state.state == "on"
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
self.hass.states.set("sensor.test_monitored", "on")
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary")
assert round(abs(0.2 - state.attributes.get("probability")), 7) == 0
assert state.state == "off"
def test_sensor_value_template(self):
"""Test sensor on template platform observations."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "template",
"value_template": "{{states('sensor.test_monitored') == 'off'}}",
"prob_given_true": 0.8,
"prob_given_false": 0.4,
}
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
assert setup_component(self.hass, "binary_sensor", config)
self.hass.states.set("sensor.test_monitored", "on")
state = self.hass.states.get("binary_sensor.test_binary")
assert [None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability")
assert state.state == "off"
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
self.hass.states.set("sensor.test_monitored", "on")
self.hass.block_till_done()
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary")
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0 assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
assert state.state == "on" assert state.state == "on"
@ -210,7 +322,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
state = self.hass.states.get("binary_sensor.test_binary") state = self.hass.states.get("binary_sensor.test_binary")
assert [] == state.attributes.get("observations") assert [None, None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability") assert 0.2 == state.attributes.get("probability")
assert state.state == "off" assert state.state == "off"
@ -223,9 +335,9 @@ class TestBayesianBinarySensor(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary") state = self.hass.states.get("binary_sensor.test_binary")
assert [{"prob_true": 0.8, "prob_false": 0.4}] == state.attributes.get(
"observations" assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
) assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0 assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
assert state.state == "on" assert state.state == "on"
@ -242,20 +354,20 @@ class TestBayesianBinarySensor(unittest.TestCase):
def test_probability_updates(self): def test_probability_updates(self):
"""Test probability update function.""" """Test probability update function."""
prob_true = [0.3, 0.6, 0.8] prob_given_true = [0.3, 0.6, 0.8]
prob_false = [0.7, 0.4, 0.2] prob_given_false = [0.7, 0.4, 0.2]
prior = 0.5 prior = 0.5
for pt, pf in zip(prob_true, prob_false): for pt, pf in zip(prob_given_true, prob_given_false):
prior = bayesian.update_probability(prior, pt, pf) prior = bayesian.update_probability(prior, pt, pf)
assert round(abs(0.720000 - prior), 7) == 0 assert round(abs(0.720000 - prior), 7) == 0
prob_true = [0.8, 0.3, 0.9] prob_given_true = [0.8, 0.3, 0.9]
prob_false = [0.6, 0.4, 0.2] prob_given_false = [0.6, 0.4, 0.2]
prior = 0.7 prior = 0.7
for pt, pf in zip(prob_true, prob_false): for pt, pf in zip(prob_given_true, prob_given_false):
prior = bayesian.update_probability(prior, pt, pf) prior = bayesian.update_probability(prior, pt, pf)
assert round(abs(0.9130434782608695 - prior), 7) == 0 assert round(abs(0.9130434782608695 - prior), 7) == 0
@ -271,7 +383,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
"platform": "state", "platform": "state",
"entity_id": "sensor.test_monitored", "entity_id": "sensor.test_monitored",
"to_state": "off", "to_state": "off",
"prob_given_true": 0.8, "prob_given_true": 0.9,
"prob_given_false": 0.4, "prob_given_false": 0.4,
}, },
{ {