From b71e0302d6408bdc535fcf87c64872079ffb13c9 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 26 Jun 2023 21:20:40 +0200 Subject: [PATCH] Use entity registry id in sensor device conditions (#95260) --- .../components/sensor/device_condition.py | 11 +- .../sensor/test_device_condition.py | 244 +++++++++++++----- 2 files changed, 194 insertions(+), 61 deletions(-) diff --git a/homeassistant/components/sensor/device_condition.py b/homeassistant/components/sensor/device_condition.py index c52e076e51e..7d6c57de296 100644 --- a/homeassistant/components/sensor/device_condition.py +++ b/homeassistant/components/sensor/device_condition.py @@ -3,6 +3,9 @@ from __future__ import annotations import voluptuous as vol +from homeassistant.components.device_automation import ( + async_get_entity_registry_entry_or_raise, +) from homeassistant.components.device_automation.exceptions import ( InvalidDeviceAutomationConfig, ) @@ -136,7 +139,7 @@ ENTITY_CONDITIONS = { CONDITION_SCHEMA = vol.All( cv.DEVICE_CONDITION_BASE_SCHEMA.extend( { - vol.Required(CONF_ENTITY_ID): cv.entity_id, + vol.Required(CONF_ENTITY_ID): cv.entity_id_or_uuid, vol.Required(CONF_TYPE): vol.In( [ CONF_IS_APPARENT_POWER, @@ -223,7 +226,7 @@ async def async_get_conditions( **template, "condition": "device", "device_id": device_id, - "entity_id": entry.entity_id, + "entity_id": entry.id, "domain": DOMAIN, } for template in templates @@ -257,8 +260,10 @@ async def async_get_condition_capabilities( hass: HomeAssistant, config: ConfigType ) -> dict[str, vol.Schema]: """List condition capabilities.""" + try: - unit_of_measurement = get_unit_of_measurement(hass, config[CONF_ENTITY_ID]) + entry = async_get_entity_registry_entry_or_raise(hass, config[CONF_ENTITY_ID]) + unit_of_measurement = get_unit_of_measurement(hass, entry.entity_id) except HomeAssistantError: unit_of_measurement = None diff --git a/tests/components/sensor/test_device_condition.py b/tests/components/sensor/test_device_condition.py index 1989f95c789..301baf0fc49 100644 --- a/tests/components/sensor/test_device_condition.py +++ b/tests/components/sensor/test_device_condition.py @@ -91,6 +91,7 @@ async def test_get_conditions( platform.init() assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) await hass.async_block_till_done() + sensor_entries = {} config_entry = MockConfigEntry(domain="test", data={}) config_entry.add_to_hass(hass) @@ -99,7 +100,7 @@ async def test_get_conditions( connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) for device_class in SensorDeviceClass: - entity_registry.async_get_or_create( + sensor_entries[device_class] = entity_registry.async_get_or_create( DOMAIN, "test", platform.ENTITIES[device_class].unique_id, @@ -112,7 +113,7 @@ async def test_get_conditions( "domain": DOMAIN, "type": condition["type"], "device_id": device_entry.id, - "entity_id": platform.ENTITIES[device_class].entity_id, + "entity_id": sensor_entries[device_class].id, "metadata": {"secondary": False}, } for device_class in SensorDeviceClass @@ -150,7 +151,7 @@ async def test_get_conditions_hidden_auxiliary( config_entry_id=config_entry.entry_id, connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_registry.async_get_or_create( + entity_entry = entity_registry.async_get_or_create( DOMAIN, "test", "5678", @@ -165,7 +166,7 @@ async def test_get_conditions_hidden_auxiliary( "domain": DOMAIN, "type": condition, "device_id": device_entry.id, - "entity_id": f"{DOMAIN}.test_5678", + "entity_id": entity_entry.id, "metadata": {"secondary": True}, } for condition in ["is_value"] @@ -188,16 +189,16 @@ async def test_get_conditions_no_state( config_entry_id=config_entry.entry_id, connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_ids = {} + sensor_entries = {} for device_class in SensorDeviceClass: - entity_ids[device_class] = entity_registry.async_get_or_create( + sensor_entries[device_class] = entity_registry.async_get_or_create( DOMAIN, "test", f"5678_{device_class}", device_id=device_entry.id, original_device_class=device_class, unit_of_measurement=UNITS_OF_MEASUREMENT.get(device_class), - ).entity_id + ) await hass.async_block_till_done() @@ -207,7 +208,7 @@ async def test_get_conditions_no_state( "domain": DOMAIN, "type": condition["type"], "device_id": device_entry.id, - "entity_id": entity_ids[device_class], + "entity_id": sensor_entries[device_class].id, "metadata": {"secondary": False}, } for device_class in SensorDeviceClass @@ -246,7 +247,7 @@ async def test_get_conditions_no_unit_or_stateclass( config_entry_id=config_entry.entry_id, connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_registry.async_get_or_create( + entity_entry = entity_registry.async_get_or_create( DOMAIN, "test", "5678", @@ -260,7 +261,7 @@ async def test_get_conditions_no_unit_or_stateclass( "domain": DOMAIN, "type": condition, "device_id": device_entry.id, - "entity_id": f"{DOMAIN}.test_5678", + "entity_id": entity_entry.id, "metadata": {"secondary": False}, } for condition in condition_types @@ -340,8 +341,22 @@ async def test_get_condition_capabilities( assert capabilities == expected_capabilities -async def test_get_condition_capabilities_none( - hass: HomeAssistant, enable_custom_integrations: None +@pytest.mark.parametrize( + ("set_state", "device_class_reg", "device_class_state", "unit_reg", "unit_state"), + [ + (False, SensorDeviceClass.BATTERY, None, PERCENTAGE, None), + (True, None, SensorDeviceClass.BATTERY, None, PERCENTAGE), + ], +) +async def test_get_condition_capabilities_legacy( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, + set_state, + device_class_reg, + device_class_state, + unit_reg, + unit_state, ) -> None: """Test we get the expected capabilities from a sensor condition.""" platform = getattr(hass.components, f"test.{DOMAIN}") @@ -349,6 +364,72 @@ async def test_get_condition_capabilities_none( config_entry = MockConfigEntry(domain="test", data={}) config_entry.add_to_hass(hass) + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_id = entity_registry.async_get_or_create( + DOMAIN, + "test", + platform.ENTITIES["battery"].unique_id, + device_id=device_entry.id, + original_device_class=device_class_reg, + unit_of_measurement=unit_reg, + ).entity_id + if set_state: + hass.states.async_set( + entity_id, + None, + {"device_class": device_class_state, "unit_of_measurement": unit_state}, + ) + + expected_capabilities = { + "extra_fields": [ + { + "description": {"suffix": PERCENTAGE}, + "name": "above", + "optional": True, + "type": "float", + }, + { + "description": {"suffix": PERCENTAGE}, + "name": "below", + "optional": True, + "type": "float", + }, + ] + } + conditions = await async_get_device_automations( + hass, DeviceAutomationType.CONDITION, device_entry.id + ) + assert len(conditions) == 1 + for condition in conditions: + condition["entity_id"] = entity_registry.async_get( + condition["entity_id"] + ).entity_id + capabilities = await async_get_device_automation_capabilities( + hass, DeviceAutomationType.CONDITION, condition + ) + assert capabilities == expected_capabilities + + +async def test_get_condition_capabilities_none( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + enable_custom_integrations: None, +) -> None: + """Test we get the expected capabilities from a sensor condition.""" + platform = getattr(hass.components, f"test.{DOMAIN}") + platform.init() + + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + + entry_none = entity_registry.async_get_or_create( + DOMAIN, + "test", + platform.ENTITIES["none"].unique_id, + ) assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) await hass.async_block_till_done() @@ -358,14 +439,14 @@ async def test_get_condition_capabilities_none( "condition": "device", "device_id": "8770c43885354d5fa27604db6817f63f", "domain": "sensor", - "entity_id": "sensor.beer", + "entity_id": "01234567890123456789012345678901", "type": "is_battery_level", }, { "condition": "device", "device_id": "8770c43885354d5fa27604db6817f63f", "domain": "sensor", - "entity_id": platform.ENTITIES["none"].entity_id, + "entity_id": entry_none.id, "type": "is_battery_level", }, ] @@ -380,18 +461,13 @@ async def test_get_condition_capabilities_none( async def test_if_state_not_above_below( hass: HomeAssistant, + entity_registry: er.EntityRegistry, calls, caplog: pytest.LogCaptureFixture, enable_custom_integrations: None, ) -> None: """Test for bad value conditions.""" - platform = getattr(hass.components, f"test.{DOMAIN}") - - platform.init() - assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) - await hass.async_block_till_done() - - sensor1 = platform.ENTITIES["battery"] + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") assert await async_setup_component( hass, @@ -405,7 +481,7 @@ async def test_if_state_not_above_below( "condition": "device", "domain": DOMAIN, "device_id": "", - "entity_id": sensor1.entity_id, + "entity_id": entry.id, "type": "is_battery_level", } ], @@ -418,16 +494,15 @@ async def test_if_state_not_above_below( async def test_if_state_above( - hass: HomeAssistant, calls, enable_custom_integrations: None + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + calls, + enable_custom_integrations: None, ) -> None: """Test for value conditions.""" - platform = getattr(hass.components, f"test.{DOMAIN}") + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") - platform.init() - assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) - await hass.async_block_till_done() - - sensor1 = platform.ENTITIES["battery"] + hass.states.async_set(entry.entity_id, STATE_UNKNOWN, {"device_class": "battery"}) assert await async_setup_component( hass, @@ -441,7 +516,7 @@ async def test_if_state_above( "condition": "device", "domain": DOMAIN, "device_id": "", - "entity_id": sensor1.entity_id, + "entity_id": entry.id, "type": "is_battery_level", "above": 10, } @@ -458,36 +533,34 @@ async def test_if_state_above( }, ) await hass.async_block_till_done() - assert hass.states.get(sensor1.entity_id).state == STATE_UNKNOWN assert len(calls) == 0 hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 0 - hass.states.async_set(sensor1.entity_id, 9) + hass.states.async_set(entry.entity_id, 9) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 0 - hass.states.async_set(sensor1.entity_id, 11) + hass.states.async_set(entry.entity_id, 11) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 1 assert calls[0].data["some"] == "event - test_event1" -async def test_if_state_below( - hass: HomeAssistant, calls, enable_custom_integrations: None +async def test_if_state_above_legacy( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + calls, + enable_custom_integrations: None, ) -> None: """Test for value conditions.""" - platform = getattr(hass.components, f"test.{DOMAIN}") + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") - platform.init() - assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) - await hass.async_block_till_done() - - sensor1 = platform.ENTITIES["battery"] + hass.states.async_set(entry.entity_id, STATE_UNKNOWN, {"device_class": "battery"}) assert await async_setup_component( hass, @@ -501,7 +574,65 @@ async def test_if_state_below( "condition": "device", "domain": DOMAIN, "device_id": "", - "entity_id": sensor1.entity_id, + "entity_id": entry.entity_id, + "type": "is_battery_level", + "above": 10, + } + ], + "action": { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.%s }}" + % "}} - {{ trigger.".join(("platform", "event.event_type")) + }, + }, + } + ] + }, + ) + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire("test_event1") + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.states.async_set(entry.entity_id, 9) + hass.bus.async_fire("test_event1") + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.states.async_set(entry.entity_id, 11) + hass.bus.async_fire("test_event1") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["some"] == "event - test_event1" + + +async def test_if_state_below( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + calls, + enable_custom_integrations: None, +) -> None: + """Test for value conditions.""" + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") + + hass.states.async_set(entry.entity_id, STATE_UNKNOWN, {"device_class": "battery"}) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": {"platform": "event", "event_type": "test_event1"}, + "condition": [ + { + "condition": "device", + "domain": DOMAIN, + "device_id": "", + "entity_id": entry.id, "type": "is_battery_level", "below": 10, } @@ -518,19 +649,18 @@ async def test_if_state_below( }, ) await hass.async_block_till_done() - assert hass.states.get(sensor1.entity_id).state == STATE_UNKNOWN assert len(calls) == 0 hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 0 - hass.states.async_set(sensor1.entity_id, 11) + hass.states.async_set(entry.entity_id, 11) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 0 - hass.states.async_set(sensor1.entity_id, 9) + hass.states.async_set(entry.entity_id, 9) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 1 @@ -538,16 +668,15 @@ async def test_if_state_below( async def test_if_state_between( - hass: HomeAssistant, calls, enable_custom_integrations: None + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + calls, + enable_custom_integrations: None, ) -> None: """Test for value conditions.""" - platform = getattr(hass.components, f"test.{DOMAIN}") + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") - platform.init() - assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) - await hass.async_block_till_done() - - sensor1 = platform.ENTITIES["battery"] + hass.states.async_set(entry.entity_id, STATE_UNKNOWN, {"device_class": "battery"}) assert await async_setup_component( hass, @@ -561,7 +690,7 @@ async def test_if_state_between( "condition": "device", "domain": DOMAIN, "device_id": "", - "entity_id": sensor1.entity_id, + "entity_id": entry.id, "type": "is_battery_level", "above": 10, "below": 20, @@ -579,30 +708,29 @@ async def test_if_state_between( }, ) await hass.async_block_till_done() - assert hass.states.get(sensor1.entity_id).state == STATE_UNKNOWN assert len(calls) == 0 hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 0 - hass.states.async_set(sensor1.entity_id, 9) + hass.states.async_set(entry.entity_id, 9) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 0 - hass.states.async_set(sensor1.entity_id, 11) + hass.states.async_set(entry.entity_id, 11) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 1 assert calls[0].data["some"] == "event - test_event1" - hass.states.async_set(sensor1.entity_id, 21) + hass.states.async_set(entry.entity_id, 21) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 1 - hass.states.async_set(sensor1.entity_id, 19) + hass.states.async_set(entry.entity_id, 19) hass.bus.async_fire("test_event1") await hass.async_block_till_done() assert len(calls) == 2