Best effort state initialization of bayesian binary sensor (#30962)

* Best effort state initialization of bayesian binary sensor.

Why:

* https://github.com/home-assistant/home-assistant/issues/30119

This change addresses the need by:

* Running the main update logic eagerly for each entity being observed
  on `async_added_to_hass`.
* Test of the new behavior.

* Refactor in order to reduce number of methods with side effects that
mutate instance attributes.

* Improve test coverage

Why:

* Because for some reason my commits decreased test coverage.

This change addresses the need by:

* Adding coverage for the case where a device returns `STATE_UNKNOWN`
* Adding coverage for configurations with templates

* rebase and ensure upstream tests passed

* Delete commented code from addressing merge conflict.
This commit is contained in:
Jeff McGehee 2020-03-31 12:41:29 -04:00 committed by GitHub
parent f2f03cf552
commit dd1608db0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 240 additions and 88 deletions

View File

@ -1,6 +1,5 @@
"""Use Bayesian Inference to trigger a binary sensor."""
from collections import OrderedDict
from itertools import chain
import voluptuous as vol
@ -88,10 +87,10 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
def update_probability(prior, prob_true, prob_false):
def update_probability(prior, prob_given_true, prob_given_false):
"""Update probability using Bayes' rule."""
numerator = prob_true * prior
denominator = numerator + prob_false * (1 - prior)
numerator = prob_given_true * prior
denominator = numerator + prob_given_false * (1 - prior)
probability = numerator / denominator
return probability
@ -127,84 +126,124 @@ class BayesianBinarySensor(BinarySensorDevice):
self.prior = prior
self.probability = prior
self.current_obs = OrderedDict({})
self.entity_obs_dict = []
self.current_observations = OrderedDict({})
for obs in self._observations:
if "entity_id" in obs:
self.entity_obs_dict.append([obs.get("entity_id")])
if "value_template" in obs:
self.entity_obs_dict.append(
list(obs.get(CONF_VALUE_TEMPLATE).extract_entities())
)
self.observations_by_entity = self._build_observations_by_entity()
to_observe = set()
for obs in self._observations:
if "entity_id" in obs:
to_observe.update(set([obs.get("entity_id")]))
if "value_template" in obs:
to_observe.update(set(obs.get(CONF_VALUE_TEMPLATE).extract_entities()))
self.entity_obs = {key: [] for key in to_observe}
for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "entity_id" in obs:
self.entity_obs[obs["entity_id"]].append(obs)
if "value_template" in obs:
for ent in obs.get(CONF_VALUE_TEMPLATE).extract_entities():
self.entity_obs[ent].append(obs)
self.watchers = {
self.observation_handlers = {
"numeric_state": self._process_numeric_state,
"state": self._process_state,
"template": self._process_template,
}
async def async_added_to_hass(self):
"""Call when entity about to be added."""
"""
Call when entity about to be added.
All relevant update logic for instance attributes occurs within this closure.
Other methods in this class are designed to avoid directly modifying instance
attributes, by instead focusing on returning relevant data back to this method.
The goal of this method is to ensure that `self.current_observations` and `self.probability`
are set on a best-effort basis when this entity is register with hass.
In addition, this method must register the state listener defined within, which
will be called any time a relevant entity changes its state.
"""
@callback
def async_threshold_sensor_state_listener(entity, old_state, new_state):
"""Handle sensor state changes."""
def async_threshold_sensor_state_listener(entity, _old_state, new_state):
"""
Handle sensor state changes.
When a state changes, we must update our list of current observations,
then calculate the new probability.
"""
if new_state.state == STATE_UNKNOWN:
return
entity_obs_list = self.entity_obs[entity]
for entity_obs in entity_obs_list:
platform = entity_obs["platform"]
self.watchers[platform](entity_obs)
prior = self.prior
for obs in self.current_obs.values():
prior = update_probability(prior, obs["prob_true"], obs["prob_false"])
self.probability = prior
self.current_observations.update(self._record_entity_observations(entity))
self.probability = self._calculate_new_probability()
self.hass.async_add_job(self.async_update_ha_state, True)
self.current_observations.update(self._initialize_current_observations())
self.probability = self._calculate_new_probability()
async_track_state_change(
self.hass, self.entity_obs, async_threshold_sensor_state_listener
self.hass,
self.observations_by_entity,
async_threshold_sensor_state_listener,
)
def _update_current_obs(self, entity_observation, should_trigger):
"""Update current observation."""
obs_id = entity_observation["id"]
def _initialize_current_observations(self):
local_observations = OrderedDict({})
for entity in self.observations_by_entity:
local_observations.update(self._record_entity_observations(entity))
return local_observations
if should_trigger:
prob_true = entity_observation["prob_given_true"]
prob_false = entity_observation.get("prob_given_false", 1 - prob_true)
def _record_entity_observations(self, entity):
local_observations = OrderedDict({})
entity_obs_list = self.observations_by_entity[entity]
self.current_obs[obs_id] = {
"prob_true": prob_true,
"prob_false": prob_false,
}
for entity_obs in entity_obs_list:
platform = entity_obs["platform"]
else:
self.current_obs.pop(obs_id, None)
should_trigger = self.observation_handlers[platform](entity_obs)
if should_trigger:
obs_entry = {"entity_id": entity, **entity_obs}
else:
obs_entry = None
local_observations[entity_obs["id"]] = obs_entry
return local_observations
def _calculate_new_probability(self):
prior = self.prior
for obs in self.current_observations.values():
if obs is not None:
prior = update_probability(
prior,
obs["prob_given_true"],
obs.get("prob_given_false", 1 - obs["prob_given_true"]),
)
return prior
def _build_observations_by_entity(self):
"""
Build and return data structure of the form below.
{
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
"sensor.sensor2": [{"id": 2, ...}],
...
}
Each "observation" must be recognized uniquely, and it should be possible
for all relevant observations to be looked up via their `entity_id`.
"""
observations_by_entity = {}
for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "entity_id" in obs:
entity_ids = [obs["entity_id"]]
elif "value_template" in obs:
entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()
for e_id in entity_ids:
obs_list = observations_by_entity.get(e_id, [])
obs_list.append(obs)
observations_by_entity[e_id] = obs_list
return observations_by_entity
def _process_numeric_state(self, entity_observation):
"""Add entity to current_obs if numeric state conditions are met."""
"""Return True if numeric condition is met."""
entity = entity_observation["entity_id"]
should_trigger = condition.async_numeric_state(
@ -215,27 +254,26 @@ class BayesianBinarySensor(BinarySensorDevice):
None,
entity_observation,
)
self._update_current_obs(entity_observation, should_trigger)
return should_trigger
def _process_state(self, entity_observation):
"""Add entity to current observations if state conditions are met."""
"""Return True if state conditions are met."""
entity = entity_observation["entity_id"]
should_trigger = condition.state(
self.hass, entity, entity_observation.get("to_state")
)
self._update_current_obs(entity_observation, should_trigger)
return should_trigger
def _process_template(self, entity_observation):
"""Add entity to current_obs if template is true."""
"""Return True if template condition is True."""
template = entity_observation.get(CONF_VALUE_TEMPLATE)
template.hass = self.hass
should_trigger = condition.async_template(
self.hass, template, entity_observation
)
self._update_current_obs(entity_observation, should_trigger)
return should_trigger
@property
def name(self):
@ -260,13 +298,15 @@ class BayesianBinarySensor(BinarySensorDevice):
@property
def device_state_attributes(self):
"""Return the state attributes of the sensor."""
print(self.current_observations)
print(self.observations_by_entity)
return {
ATTR_OBSERVATIONS: list(self.current_obs.values()),
ATTR_OBSERVATIONS: list(self.current_observations.values()),
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
set(
chain.from_iterable(
self.entity_obs_dict[obs] for obs in self.current_obs.keys()
)
obs.get("entity_id")
for obs in self.current_observations.values()
if obs is not None
)
),
ATTR_PROBABILITY: round(self.probability, 2),

View File

@ -2,6 +2,7 @@
import unittest
from homeassistant.components.bayesian import binary_sensor as bayesian
from homeassistant.const import STATE_UNKNOWN
from homeassistant.setup import setup_component
from tests.common import get_test_home_assistant
@ -18,6 +19,65 @@ class TestBayesianBinarySensor(unittest.TestCase):
"""Stop everything that was started."""
self.hass.stop()
def test_load_values_when_added_to_hass(self):
"""Test that sensor initializes with observations of relevant entities."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "state",
"entity_id": "sensor.test_monitored",
"to_state": "off",
"prob_given_true": 0.8,
"prob_given_false": 0.4,
}
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
assert setup_component(self.hass, "binary_sensor", config)
state = self.hass.states.get("binary_sensor.test_binary")
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
def test_unknown_state_does_not_influence_probability(self):
"""Test that an unknown state does not change the output probability."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "state",
"entity_id": "sensor.test_monitored",
"to_state": "off",
"prob_given_true": 0.8,
"prob_given_false": 0.4,
}
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
self.hass.states.set("sensor.test_monitored", STATE_UNKNOWN)
self.hass.block_till_done()
assert setup_component(self.hass, "binary_sensor", config)
state = self.hass.states.get("binary_sensor.test_binary")
assert state.attributes.get("observations") == [None]
def test_sensor_numeric_state(self):
"""Test sensor on numeric state platform observations."""
config = {
@ -52,7 +112,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
state = self.hass.states.get("binary_sensor.test_binary")
assert [] == state.attributes.get("observations")
assert [None, None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability")
assert state.state == "off"
@ -66,10 +126,9 @@ class TestBayesianBinarySensor(unittest.TestCase):
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary")
assert [
{"prob_false": 0.4, "prob_true": 0.6},
{"prob_false": 0.1, "prob_true": 0.9},
] == state.attributes.get("observations")
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.6
assert state.attributes.get("observations")[1]["prob_given_true"] == 0.9
assert state.attributes.get("observations")[1]["prob_given_false"] == 0.1
assert round(abs(0.77 - state.attributes.get("probability")), 7) == 0
assert state.state == "on"
@ -118,7 +177,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
state = self.hass.states.get("binary_sensor.test_binary")
assert [] == state.attributes.get("observations")
assert [None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability")
assert state.state == "off"
@ -131,9 +190,62 @@ class TestBayesianBinarySensor(unittest.TestCase):
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary")
assert [{"prob_true": 0.8, "prob_false": 0.4}] == state.attributes.get(
"observations"
)
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
assert state.state == "on"
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
self.hass.states.set("sensor.test_monitored", "on")
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary")
assert round(abs(0.2 - state.attributes.get("probability")), 7) == 0
assert state.state == "off"
def test_sensor_value_template(self):
"""Test sensor on template platform observations."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "template",
"value_template": "{{states('sensor.test_monitored') == 'off'}}",
"prob_given_true": 0.8,
"prob_given_false": 0.4,
}
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
assert setup_component(self.hass, "binary_sensor", config)
self.hass.states.set("sensor.test_monitored", "on")
state = self.hass.states.get("binary_sensor.test_binary")
assert [None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability")
assert state.state == "off"
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
self.hass.states.set("sensor.test_monitored", "on")
self.hass.block_till_done()
self.hass.states.set("sensor.test_monitored", "off")
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary")
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
assert state.state == "on"
@ -210,7 +322,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
state = self.hass.states.get("binary_sensor.test_binary")
assert [] == state.attributes.get("observations")
assert [None, None] == state.attributes.get("observations")
assert 0.2 == state.attributes.get("probability")
assert state.state == "off"
@ -223,9 +335,9 @@ class TestBayesianBinarySensor(unittest.TestCase):
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_binary")
assert [{"prob_true": 0.8, "prob_false": 0.4}] == state.attributes.get(
"observations"
)
assert state.attributes.get("observations")[0]["prob_given_true"] == 0.8
assert state.attributes.get("observations")[0]["prob_given_false"] == 0.4
assert round(abs(0.33 - state.attributes.get("probability")), 7) == 0
assert state.state == "on"
@ -242,20 +354,20 @@ class TestBayesianBinarySensor(unittest.TestCase):
def test_probability_updates(self):
"""Test probability update function."""
prob_true = [0.3, 0.6, 0.8]
prob_false = [0.7, 0.4, 0.2]
prob_given_true = [0.3, 0.6, 0.8]
prob_given_false = [0.7, 0.4, 0.2]
prior = 0.5
for pt, pf in zip(prob_true, prob_false):
for pt, pf in zip(prob_given_true, prob_given_false):
prior = bayesian.update_probability(prior, pt, pf)
assert round(abs(0.720000 - prior), 7) == 0
prob_true = [0.8, 0.3, 0.9]
prob_false = [0.6, 0.4, 0.2]
prob_given_true = [0.8, 0.3, 0.9]
prob_given_false = [0.6, 0.4, 0.2]
prior = 0.7
for pt, pf in zip(prob_true, prob_false):
for pt, pf in zip(prob_given_true, prob_given_false):
prior = bayesian.update_probability(prior, pt, pf)
assert round(abs(0.9130434782608695 - prior), 7) == 0
@ -271,7 +383,7 @@ class TestBayesianBinarySensor(unittest.TestCase):
"platform": "state",
"entity_id": "sensor.test_monitored",
"to_state": "off",
"prob_given_true": 0.8,
"prob_given_true": 0.9,
"prob_given_false": 0.4,
},
{