Baysesian Config Flow (#122552)

Co-authored-by: G Johansson <goran.johansson@shiftit.se>
Co-authored-by: Norbert Rittel <norbert@rittel.de>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
HarvsG
2025-08-26 20:15:57 +03:00
committed by GitHub
parent 87f0703be1
commit ecb51ce185
11 changed files with 2235 additions and 31 deletions

View File

@@ -16,6 +16,7 @@ from homeassistant.components.binary_sensor import (
BinarySensorDeviceClass,
BinarySensorEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
CONF_ABOVE,
CONF_BELOW,
@@ -32,7 +33,10 @@ from homeassistant.const import (
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
from homeassistant.exceptions import ConditionError, TemplateError
from homeassistant.helpers import condition, config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.entity_platform import (
AddConfigEntryEntitiesCallback,
AddEntitiesCallback,
)
from homeassistant.helpers.event import (
TrackTemplate,
TrackTemplateResult,
@@ -44,7 +48,6 @@ from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.template import Template, result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DOMAIN, PLATFORMS
from .const import (
ATTR_OBSERVATIONS,
ATTR_OCCURRED_OBSERVATION_ENTITIES,
@@ -60,6 +63,8 @@ from .const import (
CONF_TO_STATE,
DEFAULT_NAME,
DEFAULT_PROBABILITY_THRESHOLD,
DOMAIN,
PLATFORMS,
)
from .helpers import Observation
from .issues import raise_mirrored_entries, raise_no_prob_given_false
@@ -67,7 +72,13 @@ from .issues import raise_mirrored_entries, raise_no_prob_given_false
_LOGGER = logging.getLogger(__name__)
def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
def above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
"""Validate above and below options.
If the observation is of type/platform NUMERIC_STATE, then ensure that the
value given for 'above' is not greater than that for 'below'. Also check
that at least one of the two is specified.
"""
if config[CONF_PLATFORM] == CONF_NUMERIC_STATE:
above = config.get(CONF_ABOVE)
below = config.get(CONF_BELOW)
@@ -76,9 +87,7 @@ def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
"For bayesian numeric state for entity: %s at least one of 'above' or 'below' must be specified",
config[CONF_ENTITY_ID],
)
raise vol.Invalid(
"For bayesian numeric state at least one of 'above' or 'below' must be specified."
)
raise vol.Invalid("above_or_below")
if above is not None and below is not None:
if above > below:
_LOGGER.error(
@@ -86,7 +95,7 @@ def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
above,
below,
)
raise vol.Invalid("'above' is greater than 'below'")
raise vol.Invalid("above_below")
return config
@@ -102,11 +111,16 @@ NUMERIC_STATE_SCHEMA = vol.All(
},
required=True,
),
_above_greater_than_below,
above_greater_than_below,
)
def _no_overlapping(configs: list[dict]) -> list[dict]:
def no_overlapping(configs: list[dict]) -> list[dict]:
"""Validate that intervals are not overlapping.
For a list of observations ensure that there are no overlapping intervals
for NUMERIC_STATE observations for the same entity.
"""
numeric_configs = [
config for config in configs if config[CONF_PLATFORM] == CONF_NUMERIC_STATE
]
@@ -129,11 +143,16 @@ def _no_overlapping(configs: list[dict]) -> list[dict]:
for i, tup in enumerate(intervals):
if len(intervals) > i + 1 and tup.below > intervals[i + 1].above:
_LOGGER.error(
"Ranges for bayesian numeric state entities must not overlap, but %s has overlapping ranges, above:%s, below:%s overlaps with above:%s, below:%s",
ent_id,
tup.above,
tup.below,
intervals[i + 1].above,
intervals[i + 1].below,
)
raise vol.Invalid(
"Ranges for bayesian numeric state entities must not overlap, "
f"but {ent_id} has overlapping ranges, above:{tup.above}, "
f"below:{tup.below} overlaps with above:{intervals[i + 1].above}, "
f"below:{intervals[i + 1].below}."
"overlapping_ranges",
)
return configs
@@ -168,7 +187,7 @@ PLATFORM_SCHEMA = BINARY_SENSOR_PLATFORM_SCHEMA.extend(
vol.All(
cv.ensure_list,
[vol.Any(TEMPLATE_SCHEMA, STATE_SCHEMA, NUMERIC_STATE_SCHEMA)],
_no_overlapping,
no_overlapping,
)
),
vol.Required(CONF_PRIOR): vol.Coerce(float),
@@ -194,9 +213,13 @@ async def async_setup_platform(
async_add_entities: AddEntitiesCallback,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up the Bayesian Binary sensor."""
"""Set up the Bayesian Binary sensor from a yaml config."""
_LOGGER.debug(
"Setting up config entry for Bayesian sensor: '%s' with %s observations",
config[CONF_NAME],
len(config.get(CONF_OBSERVATIONS, [])),
)
await async_setup_reload_service(hass, DOMAIN, PLATFORMS)
name: str = config[CONF_NAME]
unique_id: str | None = config.get(CONF_UNIQUE_ID)
observations: list[ConfigType] = config[CONF_OBSERVATIONS]
@@ -231,6 +254,42 @@ async def async_setup_platform(
)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up the Bayesian Binary sensor from a config entry."""
_LOGGER.debug(
"Setting up config entry for Bayesian sensor: '%s' with %s observations",
config_entry.options[CONF_NAME],
len(config_entry.subentries),
)
config = config_entry.options
name: str = config[CONF_NAME]
unique_id: str | None = config.get(CONF_UNIQUE_ID, config_entry.entry_id)
observations: list[ConfigType] = [
dict(subentry.data) for subentry in config_entry.subentries.values()
]
prior: float = config[CONF_PRIOR]
probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS)
async_add_entities(
[
BayesianBinarySensor(
name,
unique_id,
prior,
observations,
probability_threshold,
device_class,
)
]
)
class BayesianBinarySensor(BinarySensorEntity):
"""Representation of a Bayesian sensor."""
@@ -248,6 +307,7 @@ class BayesianBinarySensor(BinarySensorEntity):
"""Initialize the Bayesian sensor."""
self._attr_name = name
self._attr_unique_id = unique_id and f"bayesian-{unique_id}"
self._observations = [
Observation(
entity_id=observation.get(CONF_ENTITY_ID),
@@ -432,7 +492,7 @@ class BayesianBinarySensor(BinarySensorEntity):
1 - observation.prob_given_false,
)
continue
# observation.observed is None
# Entity exists but observation.observed is None
if observation.entity_id is not None:
_LOGGER.debug(
(
@@ -495,7 +555,10 @@ class BayesianBinarySensor(BinarySensorEntity):
for observation in self._observations:
if observation.value_template is None:
continue
if isinstance(observation.value_template, str):
observation.value_template = Template(
observation.value_template, hass=self.hass
)
template = observation.value_template
observations_by_template.setdefault(template, []).append(observation)