diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 200d678719a..7377120af40 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -33,7 +33,7 @@ from homeassistant.const import ( ) from homeassistant.core import State, callback, split_entity_id, valid_entity_id from homeassistant.exceptions import TemplateError -from homeassistant.helpers import location as loc_helper +from homeassistant.helpers import entity_registry, location as loc_helper from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType from homeassistant.loader import bind_hass from homeassistant.util import convert, dt as dt_util, location as loc_util @@ -48,6 +48,7 @@ DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S" _RENDER_INFO = "template.render_info" _ENVIRONMENT = "template.environment" +_ENVIRONMENT_LIMITED = "template.environment_limited" _RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{|\{#") # Match "simple" ints and floats. -1.0, 1, +5, 5.0 @@ -300,11 +301,12 @@ class Template: @property def _env(self) -> TemplateEnvironment: - if self.hass is None or self._limited: + if self.hass is None: return _NO_HASS_ENV - ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT) + wanted_env = _ENVIRONMENT_LIMITED if self._limited else _ENVIRONMENT + ret: Optional[TemplateEnvironment] = self.hass.data.get(wanted_env) if ret is None: - ret = self.hass.data[_ENVIRONMENT] = TemplateEnvironment(self.hass) # type: ignore[no-untyped-call] + ret = self.hass.data[wanted_env] = TemplateEnvironment(self.hass, self._limited) # type: ignore[no-untyped-call] return ret def ensure_valid(self) -> None: @@ -867,6 +869,13 @@ def expand(hass: HomeAssistantType, *args: Any) -> Iterable[State]: return sorted(found.values(), key=lambda a: a.entity_id) +def device_entities(hass: HomeAssistantType, device_id: str) -> Iterable[str]: + """Get entity ids for entities tied to a device.""" + entity_reg = entity_registry.async_get(hass) + entries = entity_registry.async_entries_for_device(entity_reg, device_id) + return [entry.entity_id for entry in entries] + + def closest(hass, *args): """Find closest entity. @@ -1311,7 +1320,7 @@ def urlencode(value): class TemplateEnvironment(ImmutableSandboxedEnvironment): """The Home Assistant template environment.""" - def __init__(self, hass): + def __init__(self, hass, limited=False): """Initialise template environment.""" super().__init__() self.hass = hass @@ -1368,7 +1377,27 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): self.globals["strptime"] = strptime self.globals["urlencode"] = urlencode if hass is None: + return + # We mark these as a context functions to ensure they get + # evaluated fresh with every execution, rather than executed + # at compile time and the value stored. The context itself + # can be discarded, we only need to get at the hass object. + def hassfunction(func): + """Wrap function that depend on hass.""" + + @wraps(func) + def wrapper(*args, **kwargs): + return func(hass, *args[1:], **kwargs) + + return contextfunction(wrapper) + + self.globals["device_entities"] = hassfunction(device_entities) + self.filters["device_entities"] = contextfilter(self.globals["device_entities"]) + + if limited: + # Only device_entities is available to limited templates, mark other + # functions and filters as unsupported. def unsupported(name): def warn_unsupported(*args, **kwargs): raise TemplateError( @@ -1395,19 +1424,6 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): self.filters[filt] = unsupported(filt) return - # We mark these as a context functions to ensure they get - # evaluated fresh with every execution, rather than executed - # at compile time and the value stored. The context itself - # can be discarded, we only need to get at the hass object. - def hassfunction(func): - """Wrap function that depend on hass.""" - - @wraps(func) - def wrapper(*args, **kwargs): - return func(hass, *args[1:], **kwargs) - - return contextfunction(wrapper) - self.globals["expand"] = hassfunction(expand) self.filters["expand"] = contextfilter(self.globals["expand"]) self.globals["closest"] = hassfunction(closest) diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index 174d61ea470..4259e7302ed 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -24,6 +24,8 @@ from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util from homeassistant.util.unit_system import UnitSystem +from tests.common import MockConfigEntry, mock_device_registry, mock_registry + def _set_up_units(hass): """Set up the tests.""" @@ -1470,6 +1472,79 @@ async def test_expand(hass): assert info.rate_limit is None +async def test_device_entities(hass): + """Test expand function.""" + config_entry = MockConfigEntry(domain="light") + device_registry = mock_device_registry(hass) + entity_registry = mock_registry(hass) + + # Test non existing device ids + info = render_to_info(hass, "{{ device_entities('abc123') }}") + assert_result_info(info, []) + assert info.rate_limit is None + + info = render_to_info(hass, "{{ device_entities(56) }}") + assert_result_info(info, []) + assert info.rate_limit is None + + # Test device without entities + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "12:34:56:AB:CD:EF")}, + ) + info = render_to_info(hass, f"{{{{ device_entities('{device_entry.id}') }}}}") + assert_result_info(info, []) + assert info.rate_limit is None + + # Test device with single entity, which has no state + entity_registry.async_get_or_create( + "light", + "hue", + "5678", + config_entry=config_entry, + device_id=device_entry.id, + ) + info = render_to_info(hass, f"{{{{ device_entities('{device_entry.id}') }}}}") + assert_result_info(info, ["light.hue_5678"], []) + assert info.rate_limit is None + info = render_to_info( + hass, + f"{{{{ device_entities('{device_entry.id}') | expand | map(attribute='entity_id') | join(', ') }}}}", + ) + assert_result_info(info, "", ["light.hue_5678"]) + assert info.rate_limit is None + + # Test device with single entity, with state + hass.states.async_set("light.hue_5678", "happy") + info = render_to_info( + hass, + f"{{{{ device_entities('{device_entry.id}') | expand | map(attribute='entity_id') | join(', ') }}}}", + ) + assert_result_info(info, "light.hue_5678", ["light.hue_5678"]) + assert info.rate_limit is None + + # Test device with multiple entities, which have a state + entity_registry.async_get_or_create( + "light", + "hue", + "ABCD", + config_entry=config_entry, + device_id=device_entry.id, + ) + hass.states.async_set("light.hue_abcd", "camper") + info = render_to_info(hass, f"{{{{ device_entities('{device_entry.id}') }}}}") + assert_result_info(info, ["light.hue_5678", "light.hue_abcd"], []) + assert info.rate_limit is None + info = render_to_info( + hass, + f"{{{{ device_entities('{device_entry.id}') | expand | map(attribute='entity_id') | join(', ') }}}}", + ) + assert_result_info( + info, "light.hue_5678, light.hue_abcd", ["light.hue_5678", "light.hue_abcd"] + ) + assert info.rate_limit is None + + def test_closest_function_to_coord(hass): """Test closest function to coord.""" hass.states.async_set(