From 4efe217d9b5eccf06206c5d5e72eacc1a41ca62d Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 26 Jun 2023 22:29:14 +0200 Subject: [PATCH] Use entity registry id in select device actions (#95274) --- .../components/device_automation/helpers.py | 1 + .../components/select/device_action.py | 30 ++- tests/components/select/test_device_action.py | 192 ++++++++++++++++-- tests/components/zha/test_device_action.py | 21 +- 4 files changed, 213 insertions(+), 31 deletions(-) diff --git a/homeassistant/components/device_automation/helpers.py b/homeassistant/components/device_automation/helpers.py index e228b64bed8..8a7fcd95f48 100644 --- a/homeassistant/components/device_automation/helpers.py +++ b/homeassistant/components/device_automation/helpers.py @@ -32,6 +32,7 @@ ENTITY_PLATFORMS = { Platform.HUMIDIFIER.value, Platform.LIGHT.value, Platform.REMOTE.value, + Platform.SELECT.value, Platform.SWITCH.value, } diff --git a/homeassistant/components/select/device_action.py b/homeassistant/components/select/device_action.py index d553cdf3043..a7d47d8c833 100644 --- a/homeassistant/components/select/device_action.py +++ b/homeassistant/components/select/device_action.py @@ -5,6 +5,10 @@ from contextlib import suppress import voluptuous as vol +from homeassistant.components.device_automation import ( + async_get_entity_registry_entry_or_raise, + async_validate_entity_schema, +) from homeassistant.const import ( ATTR_ENTITY_ID, CONF_DEVICE_ID, @@ -33,43 +37,50 @@ from .const import ( SERVICE_SELECT_PREVIOUS, ) -ACTION_SCHEMA = vol.Any( +_ACTION_SCHEMA = vol.Any( cv.DEVICE_ACTION_BASE_SCHEMA.extend( { vol.Required(CONF_TYPE): SERVICE_SELECT_FIRST, - vol.Required(CONF_ENTITY_ID): cv.entity_domain(DOMAIN), + vol.Required(CONF_ENTITY_ID): cv.entity_id_or_uuid, } ), cv.DEVICE_ACTION_BASE_SCHEMA.extend( { vol.Required(CONF_TYPE): SERVICE_SELECT_LAST, - vol.Required(CONF_ENTITY_ID): cv.entity_domain(DOMAIN), + vol.Required(CONF_ENTITY_ID): cv.entity_id_or_uuid, } ), cv.DEVICE_ACTION_BASE_SCHEMA.extend( { vol.Required(CONF_TYPE): SERVICE_SELECT_NEXT, - vol.Required(CONF_ENTITY_ID): cv.entity_domain(DOMAIN), + vol.Required(CONF_ENTITY_ID): cv.entity_id_or_uuid, vol.Optional(CONF_CYCLE, default=True): cv.boolean, } ), cv.DEVICE_ACTION_BASE_SCHEMA.extend( { vol.Required(CONF_TYPE): SERVICE_SELECT_PREVIOUS, - vol.Required(CONF_ENTITY_ID): cv.entity_domain(DOMAIN), + vol.Required(CONF_ENTITY_ID): cv.entity_id_or_uuid, vol.Optional(CONF_CYCLE, default=True): cv.boolean, } ), cv.DEVICE_ACTION_BASE_SCHEMA.extend( { vol.Required(CONF_TYPE): SERVICE_SELECT_OPTION, - vol.Required(CONF_ENTITY_ID): cv.entity_domain(DOMAIN), + vol.Required(CONF_ENTITY_ID): cv.entity_id_or_uuid, vol.Required(CONF_OPTION): cv.string, } ), ) +async def async_validate_action_config( + hass: HomeAssistant, config: ConfigType +) -> ConfigType: + """Validate config.""" + return async_validate_entity_schema(hass, config, _ACTION_SCHEMA) + + async def async_get_actions( hass: HomeAssistant, device_id: str ) -> list[dict[str, str]]: @@ -79,7 +90,7 @@ async def async_get_actions( { CONF_DEVICE_ID: device_id, CONF_DOMAIN: DOMAIN, - CONF_ENTITY_ID: entry.entity_id, + CONF_ENTITY_ID: entry.id, CONF_TYPE: service_conf_type, } for service_conf_type in ( @@ -130,7 +141,10 @@ async def async_get_action_capabilities( if config[CONF_TYPE] == SERVICE_SELECT_OPTION: options: list[str] = [] with suppress(HomeAssistantError): - options = get_capability(hass, config[CONF_ENTITY_ID], ATTR_OPTIONS) or [] + entry = async_get_entity_registry_entry_or_raise( + hass, config[CONF_ENTITY_ID] + ) + options = get_capability(hass, entry.entity_id, ATTR_OPTIONS) or [] return { "extra_fields": vol.Schema({vol.Required(CONF_OPTION): vol.In(options)}) } diff --git a/tests/components/select/test_device_action.py b/tests/components/select/test_device_action.py index a517d16ad9e..ce5d48bb358 100644 --- a/tests/components/select/test_device_action.py +++ b/tests/components/select/test_device_action.py @@ -35,7 +35,7 @@ async def test_get_actions( config_entry_id=config_entry.entry_id, connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_registry.async_get_or_create( + entity_entry = entity_registry.async_get_or_create( DOMAIN, "test", "5678", device_id=device_entry.id ) expected_actions = [ @@ -43,7 +43,7 @@ async def test_get_actions( "domain": DOMAIN, "type": action, "device_id": device_entry.id, - "entity_id": "select.test_5678", + "entity_id": entity_entry.id, "metadata": {"secondary": False}, } for action in [ @@ -83,7 +83,7 @@ async def test_get_actions_hidden_auxiliary( config_entry_id=config_entry.entry_id, connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_registry.async_get_or_create( + entity_entry = entity_registry.async_get_or_create( DOMAIN, "test", "5678", @@ -97,7 +97,7 @@ async def test_get_actions_hidden_auxiliary( "domain": DOMAIN, "type": action, "device_id": device_entry.id, - "entity_id": f"{DOMAIN}.test_5678", + "entity_id": entity_entry.id, "metadata": {"secondary": True}, } for action in [ @@ -115,8 +115,12 @@ async def test_get_actions_hidden_auxiliary( @pytest.mark.parametrize("action_type", ("select_first", "select_last")) -async def test_action_select_first_last(hass: HomeAssistant, action_type: str) -> None: +async def test_action_select_first_last( + hass: HomeAssistant, entity_registry: er.EntityRegistry, action_type: str +) -> None: """Test for select_first and select_last actions.""" + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") + assert await async_setup_component( hass, automation.DOMAIN, @@ -130,7 +134,7 @@ async def test_action_select_first_last(hass: HomeAssistant, action_type: str) - "action": { "domain": DOMAIN, "device_id": "abcdefgh", - "entity_id": "select.entity", + "entity_id": entry.id, "type": action_type, }, }, @@ -145,11 +149,16 @@ async def test_action_select_first_last(hass: HomeAssistant, action_type: str) - assert len(select_calls) == 1 assert select_calls[0].domain == DOMAIN assert select_calls[0].service == action_type - assert select_calls[0].data == {"entity_id": "select.entity"} + assert select_calls[0].data == {"entity_id": entry.entity_id} -async def test_action_select_option(hass: HomeAssistant) -> None: - """Test for select_option action.""" +@pytest.mark.parametrize("action_type", ("select_first", "select_last")) +async def test_action_select_first_last_legacy( + hass: HomeAssistant, entity_registry: er.EntityRegistry, action_type: str +) -> None: + """Test for select_first and select_last actions.""" + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") + assert await async_setup_component( hass, automation.DOMAIN, @@ -163,7 +172,44 @@ async def test_action_select_option(hass: HomeAssistant) -> None: "action": { "domain": DOMAIN, "device_id": "abcdefgh", - "entity_id": "select.entity", + "entity_id": entry.entity_id, + "type": action_type, + }, + }, + ] + }, + ) + + select_calls = async_mock_service(hass, DOMAIN, action_type) + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(select_calls) == 1 + assert select_calls[0].domain == DOMAIN + assert select_calls[0].service == action_type + assert select_calls[0].data == {"entity_id": entry.entity_id} + + +async def test_action_select_option( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test for select_option action.""" + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "event", + "event_type": "test_event", + }, + "action": { + "domain": DOMAIN, + "device_id": "abcdefgh", + "entity_id": entry.id, "type": "select_option", "option": "option1", }, @@ -179,14 +225,16 @@ async def test_action_select_option(hass: HomeAssistant) -> None: assert len(select_calls) == 1 assert select_calls[0].domain == DOMAIN assert select_calls[0].service == "select_option" - assert select_calls[0].data == {"entity_id": "select.entity", "option": "option1"} + assert select_calls[0].data == {"entity_id": entry.entity_id, "option": "option1"} @pytest.mark.parametrize("action_type", ["select_next", "select_previous"]) async def test_action_select_next_previous( - hass: HomeAssistant, action_type: str + hass: HomeAssistant, entity_registry: er.EntityRegistry, action_type: str ) -> None: """Test for select_next and select_previous actions.""" + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") + assert await async_setup_component( hass, automation.DOMAIN, @@ -200,7 +248,7 @@ async def test_action_select_next_previous( "action": { "domain": DOMAIN, "device_id": "abcdefgh", - "entity_id": "select.entity", + "entity_id": entry.id, "type": action_type, "cycle": False, }, @@ -216,16 +264,20 @@ async def test_action_select_next_previous( assert len(select_calls) == 1 assert select_calls[0].domain == DOMAIN assert select_calls[0].service == action_type - assert select_calls[0].data == {"entity_id": "select.entity", "cycle": False} + assert select_calls[0].data == {"entity_id": entry.entity_id, "cycle": False} -async def test_get_action_capabilities(hass: HomeAssistant) -> None: +async def test_get_action_capabilities( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: """Test we get the expected capabilities from a select action.""" + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") + config = { "platform": "device", "domain": DOMAIN, "type": "select_option", - "entity_id": "select.test", + "entity_id": entry.id, "option": "option1", } @@ -245,7 +297,9 @@ async def test_get_action_capabilities(hass: HomeAssistant) -> None: ] # Mock an entity - hass.states.async_set("select.test", "option1", {"options": ["option1", "option2"]}) + hass.states.async_set( + entry.entity_id, "option1", {"options": ["option1", "option2"]} + ) # Test if we get the right capabilities now capabilities = await async_get_action_capabilities(hass, config) @@ -267,7 +321,7 @@ async def test_get_action_capabilities(hass: HomeAssistant) -> None: "platform": "device", "domain": DOMAIN, "type": "select_next", - "entity_id": "select.test", + "entity_id": entry.id, } capabilities = await async_get_action_capabilities(hass, config) assert capabilities @@ -303,7 +357,107 @@ async def test_get_action_capabilities(hass: HomeAssistant) -> None: "platform": "device", "domain": DOMAIN, "type": "select_first", - "entity_id": "select.test", + "entity_id": entry.id, + } + capabilities = await async_get_action_capabilities(hass, config) + assert capabilities == {} + + config["type"] = "select_last" + capabilities = await async_get_action_capabilities(hass, config) + assert capabilities == {} + + +async def test_get_action_capabilities_legacy( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test we get the expected capabilities from a select action.""" + entry = entity_registry.async_get_or_create(DOMAIN, "test", "5678") + + config = { + "platform": "device", + "domain": DOMAIN, + "type": "select_option", + "entity_id": entry.entity_id, + "option": "option1", + } + + # Test when entity doesn't exists + capabilities = await async_get_action_capabilities(hass, config) + assert capabilities + assert "extra_fields" in capabilities + assert voluptuous_serialize.convert( + capabilities["extra_fields"], custom_serializer=cv.custom_serializer + ) == [ + { + "name": "option", + "required": True, + "type": "select", + "options": [], + }, + ] + + # Mock an entity + hass.states.async_set( + entry.entity_id, "option1", {"options": ["option1", "option2"]} + ) + + # Test if we get the right capabilities now + capabilities = await async_get_action_capabilities(hass, config) + assert capabilities + assert "extra_fields" in capabilities + assert voluptuous_serialize.convert( + capabilities["extra_fields"], custom_serializer=cv.custom_serializer + ) == [ + { + "name": "option", + "required": True, + "type": "select", + "options": [("option1", "option1"), ("option2", "option2")], + }, + ] + + # Test next/previous actions + config = { + "platform": "device", + "domain": DOMAIN, + "type": "select_next", + "entity_id": entry.entity_id, + } + capabilities = await async_get_action_capabilities(hass, config) + assert capabilities + assert "extra_fields" in capabilities + assert voluptuous_serialize.convert( + capabilities["extra_fields"], custom_serializer=cv.custom_serializer + ) == [ + { + "name": "cycle", + "optional": True, + "type": "boolean", + "default": True, + }, + ] + + config["type"] = "select_previous" + capabilities = await async_get_action_capabilities(hass, config) + assert capabilities + assert "extra_fields" in capabilities + assert voluptuous_serialize.convert( + capabilities["extra_fields"], custom_serializer=cv.custom_serializer + ) == [ + { + "name": "cycle", + "optional": True, + "type": "boolean", + "default": True, + }, + ] + + # Test action types without extra fields + config = { + "platform": "device", + "domain": DOMAIN, + "type": "select_first", + "entity_id": entry.entity_id, } capabilities = await async_get_action_capabilities(hass, config) assert capabilities == {} diff --git a/tests/components/zha/test_device_action.py b/tests/components/zha/test_device_action.py index beb085408e0..d938512981f 100644 --- a/tests/components/zha/test_device_action.py +++ b/tests/components/zha/test_device_action.py @@ -114,6 +114,19 @@ async def test_get_actions(hass: HomeAssistant, device_ias) -> None: ha_device_registry = dr.async_get(hass) reg_device = ha_device_registry.async_get_device({(DOMAIN, ieee_address)}) + ha_entity_registry = er.async_get(hass) + siren_level_select = ha_entity_registry.async_get( + "select.fakemanufacturer_fakemodel_default_siren_level" + ) + siren_tone_select = ha_entity_registry.async_get( + "select.fakemanufacturer_fakemodel_default_siren_tone" + ) + strobe_level_select = ha_entity_registry.async_get( + "select.fakemanufacturer_fakemodel_default_strobe_level" + ) + strobe_select = ha_entity_registry.async_get( + "select.fakemanufacturer_fakemodel_default_strobe" + ) actions = await async_get_device_automations( hass, DeviceAutomationType.ACTION, reg_device.id @@ -145,10 +158,10 @@ async def test_get_actions(hass: HomeAssistant, device_ias) -> None: "select_previous", ] for entity_id in [ - "select.fakemanufacturer_fakemodel_default_siren_level", - "select.fakemanufacturer_fakemodel_default_siren_tone", - "select.fakemanufacturer_fakemodel_default_strobe_level", - "select.fakemanufacturer_fakemodel_default_strobe", + siren_level_select.id, + siren_tone_select.id, + strobe_level_select.id, + strobe_select.id, ] ] )