diff --git a/homeassistant/components/sensor/device_trigger.py b/homeassistant/components/sensor/device_trigger.py index 0bca1e299d6..9729d629d1c 100644 --- a/homeassistant/components/sensor/device_trigger.py +++ b/homeassistant/components/sensor/device_trigger.py @@ -9,8 +9,6 @@ from homeassistant.components.homeassistant.triggers import ( numeric_state as numeric_state_trigger, ) from homeassistant.const import ( - ATTR_DEVICE_CLASS, - ATTR_UNIT_OF_MEASUREMENT, CONF_ABOVE, CONF_BELOW, CONF_ENTITY_ID, @@ -30,7 +28,9 @@ from homeassistant.const import ( DEVICE_CLASS_TEMPERATURE, DEVICE_CLASS_VOLTAGE, ) +from homeassistant.core import HomeAssistantError from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.entity import get_device_class, get_unit_of_measurement from homeassistant.helpers.entity_registry import async_entries_for_device from . import DOMAIN @@ -134,18 +134,12 @@ async def async_get_triggers(hass, device_id): ] for entry in entries: - device_class = DEVICE_CLASS_NONE - state = hass.states.get(entry.entity_id) - unit_of_measurement = ( - state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state else None - ) + device_class = get_device_class(hass, entry.entity_id) or DEVICE_CLASS_NONE + unit_of_measurement = get_unit_of_measurement(hass, entry.entity_id) - if not state or not unit_of_measurement: + if not unit_of_measurement: continue - if ATTR_DEVICE_CLASS in state.attributes: - device_class = state.attributes[ATTR_DEVICE_CLASS] - templates = ENTITY_TRIGGERS.get( device_class, ENTITY_TRIGGERS[DEVICE_CLASS_NONE] ) @@ -166,15 +160,14 @@ async def async_get_triggers(hass, device_id): async def async_get_trigger_capabilities(hass, config): """List trigger capabilities.""" - state = hass.states.get(config[CONF_ENTITY_ID]) - unit_of_measurement = ( - state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state else None - ) + try: + unit_of_measurement = get_unit_of_measurement(hass, config[CONF_ENTITY_ID]) + except HomeAssistantError: + unit_of_measurement = None - if not state or not unit_of_measurement: + if not unit_of_measurement: raise InvalidDeviceAutomationConfig( - "No state or unit of measurement found for " - f"trigger entity {config[CONF_ENTITY_ID]}" + f"No unit of measurement found for trigger entity {config[CONF_ENTITY_ID]}" ) return { diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 4afab38fabf..dce68c80871 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -93,6 +93,23 @@ def async_generate_entity_id( return test_string +def get_device_class(hass: HomeAssistant, entity_id: str) -> str | None: + """Get device class of an entity. + + First try the statemachine, then entity registry. + """ + state = hass.states.get(entity_id) + if state: + return state.attributes.get(ATTR_DEVICE_CLASS) + + entity_registry = er.async_get(hass) + entry = entity_registry.async_get(entity_id) + if not entry: + raise HomeAssistantError(f"Unknown entity {entity_id}") + + return entry.device_class + + def get_supported_features(hass: HomeAssistant, entity_id: str) -> int: """Get supported features for an entity. @@ -110,6 +127,23 @@ def get_supported_features(hass: HomeAssistant, entity_id: str) -> int: return entry.supported_features or 0 +def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None: + """Get unit of measurement class of an entity. + + First try the statemachine, then entity registry. + """ + state = hass.states.get(entity_id) + if state: + return state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + + entity_registry = er.async_get(hass) + entry = entity_registry.async_get(entity_id) + if not entry: + raise HomeAssistantError(f"Unknown entity {entity_id}") + + return entry.unit_of_measurement + + class DeviceInfo(TypedDict, total=False): """Entity device information for device registry.""" diff --git a/tests/components/sensor/test_device_trigger.py b/tests/components/sensor/test_device_trigger.py index 9da93510523..ce35e2506a9 100644 --- a/tests/components/sensor/test_device_trigger.py +++ b/tests/components/sensor/test_device_trigger.py @@ -6,7 +6,12 @@ import pytest import homeassistant.components.automation as automation from homeassistant.components.sensor import DOMAIN from homeassistant.components.sensor.device_trigger import ENTITY_TRIGGERS -from homeassistant.const import CONF_PLATFORM, PERCENTAGE, STATE_UNKNOWN +from homeassistant.const import ( + CONF_PLATFORM, + DEVICE_CLASS_BATTERY, + PERCENTAGE, + STATE_UNKNOWN, +) from homeassistant.helpers import device_registry from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -85,8 +90,22 @@ async def test_get_triggers(hass, device_reg, entity_reg, enable_custom_integrat assert triggers == expected_triggers +@pytest.mark.parametrize( + "set_state,device_class_reg,device_class_state,unit_reg,unit_state", + [ + (False, DEVICE_CLASS_BATTERY, None, PERCENTAGE, None), + (True, None, DEVICE_CLASS_BATTERY, None, PERCENTAGE), + ], +) async def test_get_trigger_capabilities( - hass, device_reg, entity_reg, enable_custom_integrations + hass, + device_reg, + entity_reg, + set_state, + device_class_reg, + device_class_state, + unit_reg, + unit_state, ): """Test we get the expected capabilities from a sensor trigger.""" platform = getattr(hass.components, f"test.{DOMAIN}") @@ -98,15 +117,20 @@ async def test_get_trigger_capabilities( config_entry_id=config_entry.entry_id, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_reg.async_get_or_create( + entity_id = entity_reg.async_get_or_create( DOMAIN, "test", platform.ENTITIES["battery"].unique_id, device_id=device_entry.id, - ) - - assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) - await hass.async_block_till_done() + device_class=device_class_reg, + unit_of_measurement=unit_reg, + ).entity_id + if set_state: + hass.states.async_set( + entity_id, + None, + {"device_class": device_class_state, "unit_of_measurement": unit_state}, + ) expected_capabilities = { "extra_fields": [