From 06fc21e287046875934038e0d818c16b44fa27e4 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 14 Jun 2021 15:22:31 +0200 Subject: [PATCH] Improve editing of device conditions referencing non-added sensor (#51835) --- .../components/sensor/device_condition.py | 30 +++----- .../sensor/test_device_condition.py | 76 +++++++++++++++++-- 2 files changed, 80 insertions(+), 26 deletions(-) diff --git a/homeassistant/components/sensor/device_condition.py b/homeassistant/components/sensor/device_condition.py index 4d3d8a4b477..a77ed2d2cd7 100644 --- a/homeassistant/components/sensor/device_condition.py +++ b/homeassistant/components/sensor/device_condition.py @@ -7,8 +7,6 @@ from homeassistant.components.device_automation.exceptions import ( InvalidDeviceAutomationConfig, ) from homeassistant.const import ( - ATTR_DEVICE_CLASS, - ATTR_UNIT_OF_MEASUREMENT, CONF_ABOVE, CONF_BELOW, CONF_ENTITY_ID, @@ -27,8 +25,9 @@ from homeassistant.const import ( DEVICE_CLASS_TEMPERATURE, DEVICE_CLASS_VOLTAGE, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant, HomeAssistantError, callback from homeassistant.helpers import condition, config_validation as cv +from homeassistant.helpers.entity import get_device_class, get_unit_of_measurement from homeassistant.helpers.entity_registry import ( async_entries_for_device, async_get_registry, @@ -116,18 +115,12 @@ async def async_get_conditions( ] for entry in entries: - device_class = DEVICE_CLASS_NONE - state = hass.states.get(entry.entity_id) - unit_of_measurement = ( - state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state else None - ) + device_class = get_device_class(hass, entry.entity_id) or DEVICE_CLASS_NONE + unit_of_measurement = get_unit_of_measurement(hass, entry.entity_id) - if not state or not unit_of_measurement: + if not unit_of_measurement: continue - if ATTR_DEVICE_CLASS in state.attributes: - device_class = state.attributes[ATTR_DEVICE_CLASS] - templates = ENTITY_CONDITIONS.get( device_class, ENTITY_CONDITIONS[DEVICE_CLASS_NONE] ) @@ -167,15 +160,14 @@ def async_condition_from_config( async def async_get_condition_capabilities(hass, config): """List condition capabilities.""" - state = hass.states.get(config[CONF_ENTITY_ID]) - unit_of_measurement = ( - state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state else None - ) + try: + unit_of_measurement = get_unit_of_measurement(hass, config[CONF_ENTITY_ID]) + except HomeAssistantError: + unit_of_measurement = None - if not state or not unit_of_measurement: + if not unit_of_measurement: raise InvalidDeviceAutomationConfig( - "No state or unit of measurement found for " - f"condition entity {config[CONF_ENTITY_ID]}" + "No unit of measurement found for condition entity {config[CONF_ENTITY_ID]}" ) return { diff --git a/tests/components/sensor/test_device_condition.py b/tests/components/sensor/test_device_condition.py index 6cad21c5bde..daf452cf715 100644 --- a/tests/components/sensor/test_device_condition.py +++ b/tests/components/sensor/test_device_condition.py @@ -4,7 +4,12 @@ import pytest import homeassistant.components.automation as automation from homeassistant.components.sensor import DOMAIN from homeassistant.components.sensor.device_condition import ENTITY_CONDITIONS -from homeassistant.const import CONF_PLATFORM, PERCENTAGE, STATE_UNKNOWN +from homeassistant.const import ( + CONF_PLATFORM, + DEVICE_CLASS_BATTERY, + PERCENTAGE, + STATE_UNKNOWN, +) from homeassistant.helpers import device_registry from homeassistant.setup import async_setup_component @@ -80,8 +85,60 @@ async def test_get_conditions(hass, device_reg, entity_reg, enable_custom_integr assert conditions == expected_conditions +async def test_get_conditions_no_state(hass, device_reg, entity_reg): + """Test we get the expected conditions from a sensor.""" + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_ids = {} + for device_class in DEVICE_CLASSES: + entity_ids[device_class] = entity_reg.async_get_or_create( + DOMAIN, + "test", + f"5678_{device_class}", + device_id=device_entry.id, + device_class=device_class, + unit_of_measurement=UNITS_OF_MEASUREMENT.get(device_class), + ).entity_id + + await hass.async_block_till_done() + + expected_conditions = [ + { + "condition": "device", + "domain": DOMAIN, + "type": condition["type"], + "device_id": device_entry.id, + "entity_id": entity_ids[device_class], + } + for device_class in DEVICE_CLASSES + if device_class in UNITS_OF_MEASUREMENT + for condition in ENTITY_CONDITIONS[device_class] + if device_class != "none" + ] + conditions = await async_get_device_automations(hass, "condition", device_entry.id) + assert conditions == expected_conditions + + +@pytest.mark.parametrize( + "set_state,device_class_reg,device_class_state,unit_reg,unit_state", + [ + (False, DEVICE_CLASS_BATTERY, None, PERCENTAGE, None), + (True, None, DEVICE_CLASS_BATTERY, None, PERCENTAGE), + ], +) async def test_get_condition_capabilities( - hass, device_reg, entity_reg, enable_custom_integrations + hass, + device_reg, + entity_reg, + set_state, + device_class_reg, + device_class_state, + unit_reg, + unit_state, ): """Test we get the expected capabilities from a sensor condition.""" platform = getattr(hass.components, f"test.{DOMAIN}") @@ -93,15 +150,20 @@ async def test_get_condition_capabilities( config_entry_id=config_entry.entry_id, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_reg.async_get_or_create( + entity_id = entity_reg.async_get_or_create( DOMAIN, "test", platform.ENTITIES["battery"].unique_id, device_id=device_entry.id, - ) - - assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) - await hass.async_block_till_done() + device_class=device_class_reg, + unit_of_measurement=unit_reg, + ).entity_id + if set_state: + hass.states.async_set( + entity_id, + None, + {"device_class": device_class_state, "unit_of_measurement": unit_state}, + ) expected_capabilities = { "extra_fields": [