From 73bc0267e9b40ef170080011e2a31b454f888f5d Mon Sep 17 00:00:00 2001 From: Raman Gupta <7243222+raman325@users.noreply.github.com> Date: Tue, 27 Jul 2021 18:55:55 -0400 Subject: [PATCH] Add DeviceRegistry template functions (#53131) --- homeassistant/helpers/template.py | 57 +++++++++++- tests/helpers/test_template.py | 145 ++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+), 4 deletions(-) diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index d991a0b58f2..66354aa7aa6 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -43,7 +43,11 @@ from homeassistant.core import ( valid_entity_id, ) from homeassistant.exceptions import TemplateError -from homeassistant.helpers import entity_registry, location as loc_helper +from homeassistant.helpers import ( + device_registry, + entity_registry, + location as loc_helper, +) from homeassistant.helpers.typing import TemplateVarsType from homeassistant.loader import bind_hass from homeassistant.util import convert, dt as dt_util, location as loc_util @@ -902,13 +906,49 @@ def expand(hass: HomeAssistant, *args: Any) -> Iterable[State]: return sorted(found.values(), key=lambda a: a.entity_id) -def device_entities(hass: HomeAssistant, device_id: str) -> Iterable[str]: +def device_entities(hass: HomeAssistant, _device_id: str) -> Iterable[str]: """Get entity ids for entities tied to a device.""" entity_reg = entity_registry.async_get(hass) - entries = entity_registry.async_entries_for_device(entity_reg, device_id) + entries = entity_registry.async_entries_for_device(entity_reg, _device_id) return [entry.entity_id for entry in entries] +def device_id(hass: HomeAssistant, entity_id: str) -> str | None: + """Get a device ID from an entity ID.""" + if not isinstance(entity_id, str) or "." not in entity_id: + raise TemplateError(f"Must provide an entity ID, got {entity_id}") # type: ignore + entity_reg = entity_registry.async_get(hass) + entity = entity_reg.async_get(entity_id) + if entity is None: + return None + return entity.device_id + + +def device_attr(hass: HomeAssistant, device_or_entity_id: str, attr_name: str) -> Any: + """Get the device specific attribute.""" + device_reg = device_registry.async_get(hass) + if not isinstance(device_or_entity_id, str): + raise TemplateError("Must provide a device or entity ID") + device = None + if ( + "." in device_or_entity_id + and (_device_id := device_id(hass, device_or_entity_id)) is not None + ): + device = device_reg.async_get(_device_id) + elif "." not in device_or_entity_id: + device = device_reg.async_get(device_or_entity_id) + if device is None or not hasattr(device, attr_name): + return None + return getattr(device, attr_name) + + +def is_device_attr( + hass: HomeAssistant, device_or_entity_id: str, attr_name: str, attr_value: Any +) -> bool: + """Test if a device's attribute is a specific value.""" + return bool(device_attr(hass, device_or_entity_id, attr_name) == attr_value) + + def closest(hass, *args): """Find closest entity. @@ -1486,6 +1526,12 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): self.globals["device_entities"] = hassfunction(device_entities) self.filters["device_entities"] = pass_context(self.globals["device_entities"]) + self.globals["device_attr"] = hassfunction(device_attr) + self.globals["is_device_attr"] = hassfunction(is_device_attr) + + self.globals["device_id"] = hassfunction(device_id) + self.filters["device_id"] = pass_context(self.globals["device_id"]) + if limited: # Only device_entities is available to limited templates, mark other # functions and filters as unsupported. @@ -1507,8 +1553,11 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): "states", "utcnow", "now", + "device_attr", + "is_device_attr", + "device_id", ] - hass_filters = ["closest", "expand"] + hass_filters = ["closest", "expand", "device_id"] for glob in hass_globals: self.globals[glob] = unsupported(glob) for filt in hass_filters: diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index 2547537bff9..d6fe2b6dbaf 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -1585,6 +1585,151 @@ async def test_device_entities(hass): assert info.rate_limit is None +async def test_device_id(hass): + """Test device_id function.""" + config_entry = MockConfigEntry(domain="light") + device_registry = mock_device_registry(hass) + entity_registry = mock_registry(hass) + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + model="test", + ) + entity_entry = entity_registry.async_get_or_create( + "sensor", "test", "test", suggested_object_id="test", device_id=device_entry.id + ) + entity_entry_no_device = entity_registry.async_get_or_create( + "sensor", "test", "test_no_device", suggested_object_id="test" + ) + + info = render_to_info(hass, "{{ 'sensor.fail' | device_id }}") + assert_result_info(info, None) + assert info.rate_limit is None + + with pytest.raises(TemplateError): + info = render_to_info(hass, "{{ 56 | device_id }}") + assert_result_info(info, None) + + with pytest.raises(TemplateError): + info = render_to_info(hass, "{{ 'not_a_real_entity_id' | device_id }}") + assert_result_info(info, None) + + info = render_to_info( + hass, f"{{{{ device_id('{entity_entry_no_device.entity_id}') }}}}" + ) + assert_result_info(info, None) + assert info.rate_limit is None + + info = render_to_info(hass, f"{{{{ device_id('{entity_entry.entity_id}') }}}}") + assert_result_info(info, device_entry.id) + assert info.rate_limit is None + + +async def test_device_attr(hass): + """Test device_attr and is_device_attr functions.""" + config_entry = MockConfigEntry(domain="light") + device_registry = mock_device_registry(hass) + entity_registry = mock_registry(hass) + + # Test non existing device ids (device_attr) + info = render_to_info(hass, "{{ device_attr('abc123', 'id') }}") + assert_result_info(info, None) + assert info.rate_limit is None + + with pytest.raises(TemplateError): + info = render_to_info(hass, "{{ device_attr(56, 'id') }}") + assert_result_info(info, None) + + # Test non existing device ids (is_device_attr) + info = render_to_info(hass, "{{ is_device_attr('abc123', 'id', 'test') }}") + assert_result_info(info, False) + assert info.rate_limit is None + + with pytest.raises(TemplateError): + info = render_to_info(hass, "{{ is_device_attr(56, 'id', 'test') }}") + assert_result_info(info, False) + + # Test non existing entity id (device_attr) + info = render_to_info(hass, "{{ device_attr('entity.test', 'id') }}") + assert_result_info(info, None) + assert info.rate_limit is None + + # Test non existing entity id (is_device_attr) + info = render_to_info(hass, "{{ is_device_attr('entity.test', 'id', 'test') }}") + assert_result_info(info, False) + assert info.rate_limit is None + + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + model="test", + ) + entity_entry = entity_registry.async_get_or_create( + "sensor", "test", "test", suggested_object_id="test", device_id=device_entry.id + ) + + # Test non existent device attribute (device_attr) + info = render_to_info( + hass, f"{{{{ device_attr('{device_entry.id}', 'invalid_attr') }}}}" + ) + assert_result_info(info, None) + assert info.rate_limit is None + + # Test non existent device attribute (is_device_attr) + info = render_to_info( + hass, f"{{{{ is_device_attr('{device_entry.id}', 'invalid_attr', 'test') }}}}" + ) + assert_result_info(info, False) + assert info.rate_limit is None + + # Test None device attribute (device_attr) + info = render_to_info( + hass, f"{{{{ device_attr('{device_entry.id}', 'manufacturer') }}}}" + ) + assert_result_info(info, None) + assert info.rate_limit is None + + # Test None device attribute mismatch (is_device_attr) + info = render_to_info( + hass, f"{{{{ is_device_attr('{device_entry.id}', 'manufacturer', 'test') }}}}" + ) + assert_result_info(info, False) + assert info.rate_limit is None + + # Test None device attribute match (is_device_attr) + info = render_to_info( + hass, f"{{{{ is_device_attr('{device_entry.id}', 'manufacturer', None) }}}}" + ) + assert_result_info(info, True) + assert info.rate_limit is None + + # Test valid device attribute match (device_attr) + info = render_to_info(hass, f"{{{{ device_attr('{device_entry.id}', 'model') }}}}") + assert_result_info(info, "test") + assert info.rate_limit is None + + # Test valid device attribute match (device_attr) + info = render_to_info( + hass, f"{{{{ device_attr('{entity_entry.entity_id}', 'model') }}}}" + ) + assert_result_info(info, "test") + assert info.rate_limit is None + + # Test valid device attribute mismatch (is_device_attr) + info = render_to_info( + hass, f"{{{{ is_device_attr('{device_entry.id}', 'model', 'fail') }}}}" + ) + assert_result_info(info, False) + assert info.rate_limit is None + + # Test valid device attribute match (is_device_attr) + info = render_to_info( + hass, f"{{{{ is_device_attr('{device_entry.id}', 'model', 'test') }}}}" + ) + assert_result_info(info, True) + assert info.rate_limit is None + + def test_closest_function_to_coord(hass): """Test closest function to coord.""" hass.states.async_set(