Include template listener info in template preview (#99669)

This commit is contained in:
Erik Montnemery 2023-09-06 09:49:42 +02:00 committed by Bram Kragten
parent 9d87e8d02b
commit 0cbcacbbf5
4 changed files with 57 additions and 13 deletions

View File

@ -349,6 +349,7 @@ def ws_start_preview(
def async_preview_updated( def async_preview_updated(
state: str | None, state: str | None,
attributes: Mapping[str, Any] | None, attributes: Mapping[str, Any] | None,
listeners: dict[str, bool | set[str]] | None,
error: str | None, error: str | None,
) -> None: ) -> None:
"""Forward config entry state events to websocket.""" """Forward config entry state events to websocket."""
@ -363,7 +364,7 @@ def ws_start_preview(
connection.send_message( connection.send_message(
websocket_api.event_message( websocket_api.event_message(
msg["id"], msg["id"],
{"attributes": attributes, "state": state}, {"attributes": attributes, "listeners": listeners, "state": state},
) )
) )

View File

@ -34,6 +34,7 @@ from homeassistant.helpers.event import (
EventStateChangedData, EventStateChangedData,
TrackTemplate, TrackTemplate,
TrackTemplateResult, TrackTemplateResult,
TrackTemplateResultInfo,
async_track_template_result, async_track_template_result,
) )
from homeassistant.helpers.script import Script, _VarsType from homeassistant.helpers.script import Script, _VarsType
@ -260,12 +261,18 @@ class TemplateEntity(Entity):
) -> None: ) -> None:
"""Template Entity.""" """Template Entity."""
self._template_attrs: dict[Template, list[_TemplateAttribute]] = {} self._template_attrs: dict[Template, list[_TemplateAttribute]] = {}
self._async_update: Callable[[], None] | None = None self._template_result_info: TrackTemplateResultInfo | None = None
self._attr_extra_state_attributes = {} self._attr_extra_state_attributes = {}
self._self_ref_update_count = 0 self._self_ref_update_count = 0
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
self._preview_callback: Callable[ self._preview_callback: Callable[
[str | None, dict[str, Any] | None, str | None], None [
str | None,
dict[str, Any] | None,
dict[str, bool | set[str]] | None,
str | None,
],
None,
] | None = None ] | None = None
if config is None: if config is None:
self._attribute_templates = attribute_templates self._attribute_templates = attribute_templates
@ -427,9 +434,12 @@ class TemplateEntity(Entity):
state, attrs = self._async_generate_attributes() state, attrs = self._async_generate_attributes()
validate_state(state) validate_state(state)
except Exception as err: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
self._preview_callback(None, None, str(err)) self._preview_callback(None, None, None, str(err))
else: else:
self._preview_callback(state, attrs, None) assert self._template_result_info
self._preview_callback(
state, attrs, self._template_result_info.listeners, None
)
@callback @callback
def _async_template_startup(self, *_: Any) -> None: def _async_template_startup(self, *_: Any) -> None:
@ -460,7 +470,7 @@ class TemplateEntity(Entity):
has_super_template=has_availability_template, has_super_template=has_availability_template,
) )
self.async_on_remove(result_info.async_remove) self.async_on_remove(result_info.async_remove)
self._async_update = result_info.async_refresh self._template_result_info = result_info
result_info.async_refresh() result_info.async_refresh()
@callback @callback
@ -494,7 +504,13 @@ class TemplateEntity(Entity):
def async_start_preview( def async_start_preview(
self, self,
preview_callback: Callable[ preview_callback: Callable[
[str | None, Mapping[str, Any] | None, str | None], None [
str | None,
Mapping[str, Any] | None,
dict[str, bool | set[str]] | None,
str | None,
],
None,
], ],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Render a preview.""" """Render a preview."""
@ -504,7 +520,7 @@ class TemplateEntity(Entity):
try: try:
self._async_template_startup() self._async_template_startup()
except Exception as err: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
preview_callback(None, None, str(err)) preview_callback(None, None, None, str(err))
return self._call_on_remove_callbacks return self._call_on_remove_callbacks
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
@ -521,8 +537,8 @@ class TemplateEntity(Entity):
async def async_update(self) -> None: async def async_update(self) -> None:
"""Call for forced update.""" """Call for forced update."""
assert self._async_update assert self._template_result_info
self._async_update() self._template_result_info.async_refresh()
async def async_run_script( async def async_run_script(
self, self,

View File

@ -77,8 +77,8 @@ class TriggerBaseEntity(Entity):
"""Template Base entity based on trigger data.""" """Template Base entity based on trigger data."""
domain: str domain: str
extra_template_keys: tuple | None = None extra_template_keys: tuple[str, ...] | None = None
extra_template_keys_complex: tuple | None = None extra_template_keys_complex: tuple[str, ...] | None = None
_unique_id: str | None _unique_id: str | None
def __init__( def __init__(
@ -94,7 +94,7 @@ class TriggerBaseEntity(Entity):
self._config = config self._config = config
self._static_rendered = {} self._static_rendered = {}
self._to_render_simple = [] self._to_render_simple: list[str] = []
self._to_render_complex: list[str] = [] self._to_render_complex: list[str] = []
for itm in ( for itm in (

View File

@ -3,6 +3,7 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from pytest_unordered import unordered
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.template import DOMAIN, async_setup_entry from homeassistant.components.template import DOMAIN, async_setup_entry
@ -257,6 +258,7 @@ async def test_options(
"input_states", "input_states",
"template_states", "template_states",
"extra_attributes", "extra_attributes",
"listeners",
), ),
( (
( (
@ -266,6 +268,7 @@ async def test_options(
{"one": "on", "two": "off"}, {"one": "on", "two": "off"},
["off", "on"], ["off", "on"],
[{}, {}], [{}, {}],
[["one", "two"], ["one"]],
), ),
( (
"sensor", "sensor",
@ -274,6 +277,7 @@ async def test_options(
{"one": "30.0", "two": "20.0"}, {"one": "30.0", "two": "20.0"},
["unavailable", "50.0"], ["unavailable", "50.0"],
[{}, {}], [{}, {}],
[["one"], ["one", "two"]],
), ),
), ),
) )
@ -286,6 +290,7 @@ async def test_config_flow_preview(
input_states: list[str], input_states: list[str],
template_states: str, template_states: str,
extra_attributes: list[dict[str, Any]], extra_attributes: list[dict[str, Any]],
listeners: list[list[str]],
) -> None: ) -> None:
"""Test the config flow preview.""" """Test the config flow preview."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -323,6 +328,12 @@ async def test_config_flow_preview(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"] == { assert msg["event"] == {
"attributes": {"friendly_name": "My template"} | extra_attributes[0], "attributes": {"friendly_name": "My template"} | extra_attributes[0],
"listeners": {
"all": False,
"domains": [],
"entities": unordered([f"{template_type}.{_id}" for _id in listeners[0]]),
"time": False,
},
"state": template_states[0], "state": template_states[0],
} }
@ -336,6 +347,12 @@ async def test_config_flow_preview(
"attributes": {"friendly_name": "My template"} "attributes": {"friendly_name": "My template"}
| extra_attributes[0] | extra_attributes[0]
| extra_attributes[1], | extra_attributes[1],
"listeners": {
"all": False,
"domains": [],
"entities": unordered([f"{template_type}.{_id}" for _id in listeners[1]]),
"time": False,
},
"state": template_states[1], "state": template_states[1],
} }
assert len(hass.states.async_all()) == 2 assert len(hass.states.async_all()) == 2
@ -526,6 +543,7 @@ async def test_config_flow_preview_bad_state(
"input_states", "input_states",
"template_state", "template_state",
"extra_attributes", "extra_attributes",
"listeners",
), ),
[ [
( (
@ -537,6 +555,7 @@ async def test_config_flow_preview_bad_state(
{"one": "on", "two": "off"}, {"one": "on", "two": "off"},
"off", "off",
{}, {},
["one", "two"],
), ),
( (
"sensor", "sensor",
@ -547,6 +566,7 @@ async def test_config_flow_preview_bad_state(
{"one": "30.0", "two": "20.0"}, {"one": "30.0", "two": "20.0"},
"10.0", "10.0",
{}, {},
["one", "two"],
), ),
], ],
) )
@ -561,6 +581,7 @@ async def test_option_flow_preview(
input_states: list[str], input_states: list[str],
template_state: str, template_state: str,
extra_attributes: dict[str, Any], extra_attributes: dict[str, Any],
listeners: list[str],
) -> None: ) -> None:
"""Test the option flow preview.""" """Test the option flow preview."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -608,6 +629,12 @@ async def test_option_flow_preview(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"] == { assert msg["event"] == {
"attributes": {"friendly_name": "My template"} | extra_attributes, "attributes": {"friendly_name": "My template"} | extra_attributes,
"listeners": {
"all": False,
"domains": [],
"entities": unordered([f"{template_type}.{_id}" for _id in listeners]),
"time": False,
},
"state": template_state, "state": template_state,
} }
assert len(hass.states.async_all()) == 3 assert len(hass.states.async_all()) == 3