Add a preview to history_stats options flow (#145721)

This commit is contained in:
karwosts 2025-07-05 00:09:02 -07:00 committed by GitHub
parent 275d390a6c
commit 3cfff4de3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 536 additions and 10 deletions

View File

@ -3,11 +3,15 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from datetime import timedelta
from typing import Any, cast from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, CONF_STATE, CONF_TYPE from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, CONF_STATE, CONF_TYPE
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.schema_config_entry_flow import ( from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler, SchemaCommonFlowHandler,
SchemaConfigFlowHandler, SchemaConfigFlowHandler,
@ -26,6 +30,7 @@ from homeassistant.helpers.selector import (
TextSelector, TextSelector,
TextSelectorConfig, TextSelectorConfig,
) )
from homeassistant.helpers.template import Template
from .const import ( from .const import (
CONF_DURATION, CONF_DURATION,
@ -37,14 +42,21 @@ from .const import (
DEFAULT_NAME, DEFAULT_NAME,
DOMAIN, DOMAIN,
) )
from .coordinator import HistoryStatsUpdateCoordinator
from .data import HistoryStats
from .sensor import HistoryStatsSensor
def _validate_two_period_keys(user_input: dict[str, Any]) -> None:
if sum(param in user_input for param in CONF_PERIOD_KEYS) != 2:
raise SchemaFlowError("only_two_keys_allowed")
async def validate_options( async def validate_options(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any] handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Validate options selected.""" """Validate options selected."""
if sum(param in user_input for param in CONF_PERIOD_KEYS) != 2: _validate_two_period_keys(user_input)
raise SchemaFlowError("only_two_keys_allowed")
handler.parent_handler._async_abort_entries_match({**handler.options, **user_input}) # noqa: SLF001 handler.parent_handler._async_abort_entries_match({**handler.options, **user_input}) # noqa: SLF001
@ -97,12 +109,14 @@ CONFIG_FLOW = {
"options": SchemaFlowFormStep( "options": SchemaFlowFormStep(
schema=DATA_SCHEMA_OPTIONS, schema=DATA_SCHEMA_OPTIONS,
validate_user_input=validate_options, validate_user_input=validate_options,
preview="history_stats",
), ),
} }
OPTIONS_FLOW = { OPTIONS_FLOW = {
"init": SchemaFlowFormStep( "init": SchemaFlowFormStep(
DATA_SCHEMA_OPTIONS, DATA_SCHEMA_OPTIONS,
validate_user_input=validate_options, validate_user_input=validate_options,
preview="history_stats",
), ),
} }
@ -116,3 +130,115 @@ class HistoryStatsConfigFlowHandler(SchemaConfigFlowHandler, domain=DOMAIN):
def async_config_entry_title(self, options: Mapping[str, Any]) -> str: def async_config_entry_title(self, options: Mapping[str, Any]) -> str:
"""Return config entry title.""" """Return config entry title."""
return cast(str, options[CONF_NAME]) return cast(str, options[CONF_NAME])
@staticmethod
async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview WS API."""
websocket_api.async_register_command(hass, ws_start_preview)
@websocket_api.websocket_command(
{
vol.Required("type"): "history_stats/start_preview",
vol.Required("flow_id"): str,
vol.Required("flow_type"): vol.Any("config_flow", "options_flow"),
vol.Required("user_input"): dict,
}
)
@websocket_api.async_response
async def ws_start_preview(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Generate a preview."""
if msg["flow_type"] == "config_flow":
flow_status = hass.config_entries.flow.async_get(msg["flow_id"])
flow_sets = hass.config_entries.flow._handler_progress_index.get( # noqa: SLF001
flow_status["handler"]
)
options = {}
assert flow_sets
for active_flow in flow_sets:
options = active_flow._common_handler.options # type: ignore [attr-defined] # noqa: SLF001
config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
entity_id = options[CONF_ENTITY_ID]
name = options[CONF_NAME]
else:
flow_status = hass.config_entries.options.async_get(msg["flow_id"])
config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
if not config_entry:
raise HomeAssistantError("Config entry not found")
entity_id = config_entry.options[CONF_ENTITY_ID]
name = config_entry.options[CONF_NAME]
@callback
def async_preview_updated(
last_exception: Exception | None, state: str, attributes: Mapping[str, Any]
) -> None:
"""Forward config entry state events to websocket."""
if last_exception:
connection.send_message(
websocket_api.event_message(
msg["id"], {"error": str(last_exception) or "Unknown error"}
)
)
else:
connection.send_message(
websocket_api.event_message(
msg["id"], {"attributes": attributes, "state": state}
)
)
for param in CONF_PERIOD_KEYS:
if param in msg["user_input"] and not bool(msg["user_input"][param]):
del msg["user_input"][param] # Remove falsy values before counting keys
validated_data: Any = None
try:
validated_data = DATA_SCHEMA_OPTIONS(msg["user_input"])
except vol.Invalid as ex:
connection.send_error(msg["id"], "invalid_schema", str(ex))
return
try:
_validate_two_period_keys(validated_data)
except SchemaFlowError:
connection.send_error(
msg["id"],
"invalid_schema",
f"Exactly two of {', '.join(CONF_PERIOD_KEYS)} required",
)
return
sensor_type = validated_data.get(CONF_TYPE)
entity_states = validated_data.get(CONF_STATE)
start = validated_data.get(CONF_START)
end = validated_data.get(CONF_END)
duration = validated_data.get(CONF_DURATION)
history_stats = HistoryStats(
hass,
entity_id,
entity_states,
Template(start, hass) if start else None,
Template(end, hass) if end else None,
timedelta(**duration) if duration else None,
True,
)
coordinator = HistoryStatsUpdateCoordinator(hass, history_stats, None, name, True)
await coordinator.async_refresh()
preview_entity = HistoryStatsSensor(
hass, coordinator, sensor_type, name, None, entity_id
)
preview_entity.hass = hass
connection.send_result(msg["id"])
cancel_listener = coordinator.async_setup_state_listener()
cancel_preview = await preview_entity.async_start_preview(async_preview_updated)
def unsub() -> None:
cancel_listener()
cancel_preview()
connection.subscriptions[msg["id"]] = unsub

View File

@ -36,12 +36,14 @@ class HistoryStatsUpdateCoordinator(DataUpdateCoordinator[HistoryStatsState]):
history_stats: HistoryStats, history_stats: HistoryStats,
config_entry: ConfigEntry | None, config_entry: ConfigEntry | None,
name: str, name: str,
preview: bool = False,
) -> None: ) -> None:
"""Initialize DataUpdateCoordinator.""" """Initialize DataUpdateCoordinator."""
self._history_stats = history_stats self._history_stats = history_stats
self._subscriber_count = 0 self._subscriber_count = 0
self._at_start_listener: CALLBACK_TYPE | None = None self._at_start_listener: CALLBACK_TYPE | None = None
self._track_events_listener: CALLBACK_TYPE | None = None self._track_events_listener: CALLBACK_TYPE | None = None
self._preview = preview
super().__init__( super().__init__(
hass, hass,
_LOGGER, _LOGGER,
@ -104,3 +106,8 @@ class HistoryStatsUpdateCoordinator(DataUpdateCoordinator[HistoryStatsState]):
return await self._history_stats.async_update(None) return await self._history_stats.async_update(None)
except (TemplateError, TypeError, ValueError) as ex: except (TemplateError, TypeError, ValueError) as ex:
raise UpdateFailed(ex) from ex raise UpdateFailed(ex) from ex
async def async_refresh(self) -> None:
"""Refresh data and log errors."""
log_failures = not self._preview
await self._async_refresh(log_failures)

View File

@ -47,6 +47,7 @@ class HistoryStats:
start: Template | None, start: Template | None,
end: Template | None, end: Template | None,
duration: datetime.timedelta | None, duration: datetime.timedelta | None,
preview: bool = False,
) -> None: ) -> None:
"""Init the history stats manager.""" """Init the history stats manager."""
self.hass = hass self.hass = hass
@ -59,6 +60,7 @@ class HistoryStats:
self._duration = duration self._duration = duration
self._start = start self._start = start
self._end = end self._end = end
self._preview = preview
self._pending_events: list[Event[EventStateChangedData]] = [] self._pending_events: list[Event[EventStateChangedData]] = []
self._query_count = 0 self._query_count = 0
@ -70,7 +72,9 @@ class HistoryStats:
# Get previous values of start and end # Get previous values of start and end
previous_period_start, previous_period_end = self._period previous_period_start, previous_period_end = self._period
# Parse templates # Parse templates
self._period = async_calculate_period(self._duration, self._start, self._end) self._period = async_calculate_period(
self._duration, self._start, self._end, log_errors=not self._preview
)
# Get the current period # Get the current period
current_period_start, current_period_end = self._period current_period_start, current_period_end = self._period

View File

@ -23,6 +23,7 @@ def async_calculate_period(
duration: datetime.timedelta | None, duration: datetime.timedelta | None,
start_template: Template | None, start_template: Template | None,
end_template: Template | None, end_template: Template | None,
log_errors: bool = True,
) -> tuple[datetime.datetime, datetime.datetime]: ) -> tuple[datetime.datetime, datetime.datetime]:
"""Parse the templates and return the period.""" """Parse the templates and return the period."""
bounds: dict[str, datetime.datetime | None] = { bounds: dict[str, datetime.datetime | None] = {
@ -37,13 +38,17 @@ def async_calculate_period(
if template is None: if template is None:
continue continue
try: try:
rendered = template.async_render() rendered = template.async_render(
log_fn=None if log_errors else lambda *args, **kwargs: None
)
except (TemplateError, TypeError) as ex: except (TemplateError, TypeError) as ex:
if ex.args and not ex.args[0].startswith( if (
"UndefinedError: 'None' has no attribute" log_errors
and ex.args
and not ex.args[0].startswith("UndefinedError: 'None' has no attribute")
): ):
_LOGGER.error("Error parsing template for field %s", bound, exc_info=ex) _LOGGER.error("Error parsing template for field %s", bound, exc_info=ex)
raise raise type(ex)(f"Error parsing template for field {bound}: {ex}") from ex
if isinstance(rendered, str): if isinstance(rendered, str):
bounds[bound] = dt_util.parse_datetime(rendered) bounds[bound] = dt_util.parse_datetime(rendered)
if bounds[bound] is not None: if bounds[bound] is not None:

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Mapping
import datetime import datetime
from typing import Any from typing import Any
@ -23,7 +24,7 @@ from homeassistant.const import (
PERCENTAGE, PERCENTAGE,
UnitOfTime, UnitOfTime,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import PlatformNotReady from homeassistant.exceptions import PlatformNotReady
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.device import async_device_info_to_link_from_entity
@ -183,6 +184,9 @@ class HistoryStatsSensor(HistoryStatsSensorBase):
) -> None: ) -> None:
"""Initialize the HistoryStats sensor.""" """Initialize the HistoryStats sensor."""
super().__init__(coordinator, name) super().__init__(coordinator, name)
self._preview_callback: (
Callable[[Exception | None, str, Mapping[str, Any]], None] | None
) = None
self._attr_native_unit_of_measurement = UNITS[sensor_type] self._attr_native_unit_of_measurement = UNITS[sensor_type]
self._type = sensor_type self._type = sensor_type
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
@ -212,3 +216,29 @@ class HistoryStatsSensor(HistoryStatsSensorBase):
self._attr_native_value = pretty_ratio(state.seconds_matched, state.period) self._attr_native_value = pretty_ratio(state.seconds_matched, state.period)
elif self._type == CONF_TYPE_COUNT: elif self._type == CONF_TYPE_COUNT:
self._attr_native_value = state.match_count self._attr_native_value = state.match_count
if self._preview_callback:
calculated_state = self._async_calculate_state()
self._preview_callback(
None, calculated_state.state, calculated_state.attributes
)
async def async_start_preview(
self,
preview_callback: Callable[[Exception | None, str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
self.async_on_remove(
self.coordinator.async_add_listener(self._process_update, None)
)
self._preview_callback = preview_callback
calculated_state = self._async_calculate_state()
preview_callback(
self.coordinator.last_exception,
calculated_state.state,
calculated_state.attributes,
)
return self._call_on_remove_callbacks

View File

@ -2,22 +2,28 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock import logging
from unittest.mock import AsyncMock, patch
from freezegun import freeze_time
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.history_stats.const import ( from homeassistant.components.history_stats.const import (
CONF_DURATION, CONF_DURATION,
CONF_END, CONF_END,
CONF_START, CONF_START,
CONF_TYPE_COUNT,
DEFAULT_NAME, DEFAULT_NAME,
DOMAIN, DOMAIN,
) )
from homeassistant.components.recorder import Recorder from homeassistant.components.recorder import Recorder
from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, CONF_STATE, CONF_TYPE from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, CONF_STATE, CONF_TYPE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, State
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from homeassistant.util import dt as dt_util
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
async def test_form( async def test_form(
@ -193,3 +199,351 @@ async def test_entry_already_exist(
assert result["type"] is FlowResultType.ABORT assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured" assert result["reason"] == "already_configured"
async def test_config_flow_preview_success(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test the config flow preview."""
client = await hass_ws_client(hass)
# add state for the tests
await hass.config.async_set_time_zone("UTC")
utcnow = dt_util.utcnow()
start_time = utcnow.replace(hour=0, minute=0, second=0, microsecond=0)
t1 = start_time.replace(hour=3)
t2 = start_time.replace(hour=4)
t3 = start_time.replace(hour=5)
monitored_entity = "binary_sensor.state"
def _fake_states(*args, **kwargs):
return {
monitored_entity: [
State(
monitored_entity,
"on",
last_changed=start_time,
last_updated=start_time,
),
State(
monitored_entity,
"off",
last_changed=t1,
last_updated=t1,
),
State(
monitored_entity,
"on",
last_changed=t2,
last_updated=t2,
),
]
}
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user"
assert result["errors"] is None
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_NAME: DEFAULT_NAME,
CONF_ENTITY_ID: monitored_entity,
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
},
)
await hass.async_block_till_done()
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "options"
assert result["errors"] is None
assert result["preview"] == "history_stats"
with (
patch(
"homeassistant.components.recorder.history.state_changes_during_period",
_fake_states,
),
freeze_time(t3),
):
await client.send_json_auto_id(
{
"type": "history_stats/start_preview",
"flow_id": result["flow_id"],
"flow_type": "config_flow",
"user_input": {
CONF_ENTITY_ID: monitored_entity,
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_END: "{{now()}}",
CONF_START: "{{ today_at() }}",
},
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
msg = await client.receive_json()
assert msg["event"]["state"] == "2"
async def test_options_flow_preview(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test the options flow preview."""
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)
client = await hass_ws_client(hass)
# add state for the tests
await hass.config.async_set_time_zone("UTC")
utcnow = dt_util.utcnow()
start_time = utcnow.replace(hour=0, minute=0, second=0, microsecond=0)
t1 = start_time.replace(hour=3)
t2 = start_time.replace(hour=4)
t3 = start_time.replace(hour=5)
monitored_entity = "binary_sensor.state"
def _fake_states(*args, **kwargs):
return {
monitored_entity: [
State(
monitored_entity,
"on",
last_changed=start_time,
last_updated=start_time,
),
State(
monitored_entity,
"off",
last_changed=t1,
last_updated=t1,
),
State(
monitored_entity,
"on",
last_changed=t2,
last_updated=t2,
),
State(
monitored_entity,
"off",
last_changed=t2,
last_updated=t2,
),
]
}
# Setup the config entry
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_NAME: DEFAULT_NAME,
CONF_ENTITY_ID: monitored_entity,
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_END: "{{ now() }}",
CONF_START: "{{ today_at() }}",
},
title=DEFAULT_NAME,
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["errors"] is None
assert result["preview"] == "history_stats"
with (
patch(
"homeassistant.components.recorder.history.state_changes_during_period",
_fake_states,
),
freeze_time(t3),
):
for end, exp_count in (
("{{now()}}", "2"),
("{{today_at('2:00')}}", "1"),
("{{today_at('23:00')}}", "2"),
):
await client.send_json_auto_id(
{
"type": "history_stats/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_ENTITY_ID: monitored_entity,
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_END: end,
CONF_START: "{{ today_at() }}",
},
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
msg = await client.receive_json()
assert msg["event"]["state"] == exp_count
hass.states.async_set(monitored_entity, "on")
msg = await client.receive_json()
assert msg["event"]["state"] == "3"
async def test_options_flow_preview_errors(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test the options flow preview."""
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)
client = await hass_ws_client(hass)
# add state for the tests
monitored_entity = "binary_sensor.state"
# Setup the config entry
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_NAME: DEFAULT_NAME,
CONF_ENTITY_ID: monitored_entity,
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_END: "{{ now() }}",
CONF_START: "{{ today_at() }}",
},
title=DEFAULT_NAME,
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["errors"] is None
assert result["preview"] == "history_stats"
for schema in (
{CONF_END: "{{ now() }"}, # Missing '}' at end of template
{CONF_START: "{{ today_at( }}"}, # Missing ')' in template function
{CONF_DURATION: {"hours": 1}}, # Specified 3 period keys (1 too many)
{CONF_START: ""}, # Specified 1 period keys (1 too few)
):
await client.send_json_auto_id(
{
"type": "history_stats/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_ENTITY_ID: monitored_entity,
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_END: "{{ now() }}",
CONF_START: "{{ today_at() }}",
**schema,
},
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "invalid_schema"
for schema in (
{CONF_END: "{{ nowwww() }}"}, # Unknown jinja function
{CONF_START: "{{ today_at('abcde') }}"}, # Invalid value passed to today_at
{CONF_END: '"{{ now() }}"'}, # Invalid quotes around template
):
await client.send_json_auto_id(
{
"type": "history_stats/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_ENTITY_ID: monitored_entity,
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_END: "{{ now() }}",
CONF_START: "{{ today_at() }}",
**schema,
},
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
msg = await client.receive_json()
assert msg["event"]["error"]
async def test_options_flow_sensor_preview_config_entry_removed(
recorder_mock: Recorder, hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test the option flow preview where the config entry is removed."""
client = await hass_ws_client(hass)
# Setup the config entry
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_NAME: DEFAULT_NAME,
CONF_ENTITY_ID: "sensor.test_monitored",
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_START: "0",
CONF_END: "1",
},
title=DEFAULT_NAME,
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["errors"] is None
assert result["preview"] == "history_stats"
await hass.config_entries.async_remove(config_entry.entry_id)
await client.send_json_auto_id(
{
"type": "history_stats/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_ENTITY_ID: "sensor.test_monitored",
CONF_TYPE: CONF_TYPE_COUNT,
CONF_STATE: ["on"],
CONF_START: "0",
CONF_END: "1",
},
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Config entry not found",
}