mirror of
https://github.com/home-assistant/core.git
synced 2025-10-17 23:59:38 +00:00
Baysesian Config Flow (#122552)
Co-authored-by: G Johansson <goran.johansson@shiftit.se> Co-authored-by: Norbert Rittel <norbert@rittel.de> Co-authored-by: Erik Montnemery <erik@montnemery.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from homeassistant.components.binary_sensor import (
|
||||
BinarySensorDeviceClass,
|
||||
BinarySensorEntity,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import (
|
||||
CONF_ABOVE,
|
||||
CONF_BELOW,
|
||||
@@ -32,7 +33,10 @@ from homeassistant.const import (
|
||||
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConditionError, TemplateError
|
||||
from homeassistant.helpers import condition, config_validation as cv
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.entity_platform import (
|
||||
AddConfigEntryEntitiesCallback,
|
||||
AddEntitiesCallback,
|
||||
)
|
||||
from homeassistant.helpers.event import (
|
||||
TrackTemplate,
|
||||
TrackTemplateResult,
|
||||
@@ -44,7 +48,6 @@ from homeassistant.helpers.reload import async_setup_reload_service
|
||||
from homeassistant.helpers.template import Template, result_as_boolean
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import DOMAIN, PLATFORMS
|
||||
from .const import (
|
||||
ATTR_OBSERVATIONS,
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES,
|
||||
@@ -60,6 +63,8 @@ from .const import (
|
||||
CONF_TO_STATE,
|
||||
DEFAULT_NAME,
|
||||
DEFAULT_PROBABILITY_THRESHOLD,
|
||||
DOMAIN,
|
||||
PLATFORMS,
|
||||
)
|
||||
from .helpers import Observation
|
||||
from .issues import raise_mirrored_entries, raise_no_prob_given_false
|
||||
@@ -67,7 +72,13 @@ from .issues import raise_mirrored_entries, raise_no_prob_given_false
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
|
||||
def above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate above and below options.
|
||||
|
||||
If the observation is of type/platform NUMERIC_STATE, then ensure that the
|
||||
value given for 'above' is not greater than that for 'below'. Also check
|
||||
that at least one of the two is specified.
|
||||
"""
|
||||
if config[CONF_PLATFORM] == CONF_NUMERIC_STATE:
|
||||
above = config.get(CONF_ABOVE)
|
||||
below = config.get(CONF_BELOW)
|
||||
@@ -76,9 +87,7 @@ def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
|
||||
"For bayesian numeric state for entity: %s at least one of 'above' or 'below' must be specified",
|
||||
config[CONF_ENTITY_ID],
|
||||
)
|
||||
raise vol.Invalid(
|
||||
"For bayesian numeric state at least one of 'above' or 'below' must be specified."
|
||||
)
|
||||
raise vol.Invalid("above_or_below")
|
||||
if above is not None and below is not None:
|
||||
if above > below:
|
||||
_LOGGER.error(
|
||||
@@ -86,7 +95,7 @@ def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
|
||||
above,
|
||||
below,
|
||||
)
|
||||
raise vol.Invalid("'above' is greater than 'below'")
|
||||
raise vol.Invalid("above_below")
|
||||
return config
|
||||
|
||||
|
||||
@@ -102,11 +111,16 @@ NUMERIC_STATE_SCHEMA = vol.All(
|
||||
},
|
||||
required=True,
|
||||
),
|
||||
_above_greater_than_below,
|
||||
above_greater_than_below,
|
||||
)
|
||||
|
||||
|
||||
def _no_overlapping(configs: list[dict]) -> list[dict]:
|
||||
def no_overlapping(configs: list[dict]) -> list[dict]:
|
||||
"""Validate that intervals are not overlapping.
|
||||
|
||||
For a list of observations ensure that there are no overlapping intervals
|
||||
for NUMERIC_STATE observations for the same entity.
|
||||
"""
|
||||
numeric_configs = [
|
||||
config for config in configs if config[CONF_PLATFORM] == CONF_NUMERIC_STATE
|
||||
]
|
||||
@@ -129,11 +143,16 @@ def _no_overlapping(configs: list[dict]) -> list[dict]:
|
||||
|
||||
for i, tup in enumerate(intervals):
|
||||
if len(intervals) > i + 1 and tup.below > intervals[i + 1].above:
|
||||
_LOGGER.error(
|
||||
"Ranges for bayesian numeric state entities must not overlap, but %s has overlapping ranges, above:%s, below:%s overlaps with above:%s, below:%s",
|
||||
ent_id,
|
||||
tup.above,
|
||||
tup.below,
|
||||
intervals[i + 1].above,
|
||||
intervals[i + 1].below,
|
||||
)
|
||||
raise vol.Invalid(
|
||||
"Ranges for bayesian numeric state entities must not overlap, "
|
||||
f"but {ent_id} has overlapping ranges, above:{tup.above}, "
|
||||
f"below:{tup.below} overlaps with above:{intervals[i + 1].above}, "
|
||||
f"below:{intervals[i + 1].below}."
|
||||
"overlapping_ranges",
|
||||
)
|
||||
return configs
|
||||
|
||||
@@ -168,7 +187,7 @@ PLATFORM_SCHEMA = BINARY_SENSOR_PLATFORM_SCHEMA.extend(
|
||||
vol.All(
|
||||
cv.ensure_list,
|
||||
[vol.Any(TEMPLATE_SCHEMA, STATE_SCHEMA, NUMERIC_STATE_SCHEMA)],
|
||||
_no_overlapping,
|
||||
no_overlapping,
|
||||
)
|
||||
),
|
||||
vol.Required(CONF_PRIOR): vol.Coerce(float),
|
||||
@@ -194,9 +213,13 @@ async def async_setup_platform(
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up the Bayesian Binary sensor."""
|
||||
"""Set up the Bayesian Binary sensor from a yaml config."""
|
||||
_LOGGER.debug(
|
||||
"Setting up config entry for Bayesian sensor: '%s' with %s observations",
|
||||
config[CONF_NAME],
|
||||
len(config.get(CONF_OBSERVATIONS, [])),
|
||||
)
|
||||
await async_setup_reload_service(hass, DOMAIN, PLATFORMS)
|
||||
|
||||
name: str = config[CONF_NAME]
|
||||
unique_id: str | None = config.get(CONF_UNIQUE_ID)
|
||||
observations: list[ConfigType] = config[CONF_OBSERVATIONS]
|
||||
@@ -231,6 +254,42 @@ async def async_setup_platform(
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up the Bayesian Binary sensor from a config entry."""
|
||||
_LOGGER.debug(
|
||||
"Setting up config entry for Bayesian sensor: '%s' with %s observations",
|
||||
config_entry.options[CONF_NAME],
|
||||
len(config_entry.subentries),
|
||||
)
|
||||
config = config_entry.options
|
||||
name: str = config[CONF_NAME]
|
||||
unique_id: str | None = config.get(CONF_UNIQUE_ID, config_entry.entry_id)
|
||||
observations: list[ConfigType] = [
|
||||
dict(subentry.data) for subentry in config_entry.subentries.values()
|
||||
]
|
||||
prior: float = config[CONF_PRIOR]
|
||||
probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
|
||||
device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS)
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
BayesianBinarySensor(
|
||||
name,
|
||||
unique_id,
|
||||
prior,
|
||||
observations,
|
||||
probability_threshold,
|
||||
device_class,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class BayesianBinarySensor(BinarySensorEntity):
|
||||
"""Representation of a Bayesian sensor."""
|
||||
|
||||
@@ -248,6 +307,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||
"""Initialize the Bayesian sensor."""
|
||||
self._attr_name = name
|
||||
self._attr_unique_id = unique_id and f"bayesian-{unique_id}"
|
||||
|
||||
self._observations = [
|
||||
Observation(
|
||||
entity_id=observation.get(CONF_ENTITY_ID),
|
||||
@@ -432,7 +492,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||
1 - observation.prob_given_false,
|
||||
)
|
||||
continue
|
||||
# observation.observed is None
|
||||
# Entity exists but observation.observed is None
|
||||
if observation.entity_id is not None:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
@@ -495,7 +555,10 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||
for observation in self._observations:
|
||||
if observation.value_template is None:
|
||||
continue
|
||||
|
||||
if isinstance(observation.value_template, str):
|
||||
observation.value_template = Template(
|
||||
observation.value_template, hass=self.hass
|
||||
)
|
||||
template = observation.value_template
|
||||
observations_by_template.setdefault(template, []).append(observation)
|
||||
|
||||
|
Reference in New Issue
Block a user