From b916247e8e03db9eac2339dfd759514506ae543b Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 21 Jun 2021 14:49:51 +0200 Subject: [PATCH] Improve editing of device automation referring non added select entity (#52047) * Improve editing of device automation referring non added select entity * Update tests --- .../alarm_control_panel/device_action.py | 2 ++ .../components/select/device_action.py | 20 ++++++--------- .../components/select/device_condition.py | 14 +++++------ .../components/select/device_trigger.py | 14 ++++++----- homeassistant/helpers/entity.py | 2 +- tests/components/select/test_device_action.py | 15 +++++++++-- .../select/test_device_condition.py | 18 ++++++++++++- .../components/select/test_device_trigger.py | 25 ++++++++++++++++++- 8 files changed, 79 insertions(+), 31 deletions(-) diff --git a/homeassistant/components/alarm_control_panel/device_action.py b/homeassistant/components/alarm_control_panel/device_action.py index fc218a2c9c3..d92f9615c9a 100644 --- a/homeassistant/components/alarm_control_panel/device_action.py +++ b/homeassistant/components/alarm_control_panel/device_action.py @@ -112,6 +112,8 @@ async def async_get_action_capabilities( hass: HomeAssistant, config: ConfigType ) -> dict[str, vol.Schema]: """List action capabilities.""" + # We need to refer to the state directly because ATTR_CODE_ARM_REQUIRED is not a + # capability attribute state = hass.states.get(config[CONF_ENTITY_ID]) code_required = state.attributes.get(ATTR_CODE_ARM_REQUIRED) if state else False diff --git a/homeassistant/components/select/device_action.py b/homeassistant/components/select/device_action.py index 656fca32970..ece3c981690 100644 --- a/homeassistant/components/select/device_action.py +++ b/homeassistant/components/select/device_action.py @@ -12,9 +12,10 @@ from homeassistant.const import ( CONF_ENTITY_ID, CONF_TYPE, ) -from homeassistant.core import Context, HomeAssistant +from homeassistant.core import Context, HomeAssistant, HomeAssistantError from homeassistant.helpers import entity_registry import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.entity import get_capability from homeassistant.helpers.typing import ConfigType from .const import ATTR_OPTION, ATTR_OPTIONS, CONF_OPTION, DOMAIN, SERVICE_SELECT_OPTION @@ -65,16 +66,9 @@ async def async_get_action_capabilities( hass: HomeAssistant, config: ConfigType ) -> dict[str, Any]: """List action capabilities.""" - state = hass.states.get(config[CONF_ENTITY_ID]) - if state is None: - return {} + try: + options = get_capability(hass, config[CONF_ENTITY_ID], ATTR_OPTIONS) or [] + except HomeAssistantError: + options = [] - return { - "extra_fields": vol.Schema( - { - vol.Required(CONF_OPTION): vol.In( - state.attributes.get(ATTR_OPTIONS, []) - ), - } - ) - } + return {"extra_fields": vol.Schema({vol.Required(CONF_OPTION): vol.In(options)})} diff --git a/homeassistant/components/select/device_condition.py b/homeassistant/components/select/device_condition.py index 444105075b1..ad82c432ce2 100644 --- a/homeassistant/components/select/device_condition.py +++ b/homeassistant/components/select/device_condition.py @@ -13,9 +13,10 @@ from homeassistant.const import ( CONF_FOR, CONF_TYPE, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant, HomeAssistantError, callback from homeassistant.helpers import condition, config_validation as cv, entity_registry from homeassistant.helpers.config_validation import DEVICE_CONDITION_BASE_SCHEMA +from homeassistant.helpers.entity import get_capability from homeassistant.helpers.typing import ConfigType, TemplateVarsType from .const import ATTR_OPTIONS, CONF_OPTION, DOMAIN @@ -72,16 +73,15 @@ async def async_get_condition_capabilities( hass: HomeAssistant, config: ConfigType ) -> dict[str, Any]: """List condition capabilities.""" - state = hass.states.get(config[CONF_ENTITY_ID]) - if state is None: - return {} + try: + options = get_capability(hass, config[CONF_ENTITY_ID], ATTR_OPTIONS) or [] + except HomeAssistantError: + options = [] return { "extra_fields": vol.Schema( { - vol.Required(CONF_OPTION): vol.In( - state.attributes.get(ATTR_OPTIONS, []) - ), + vol.Required(CONF_OPTION): vol.In(options), vol.Optional(CONF_FOR): cv.positive_time_period_dict, } ) diff --git a/homeassistant/components/select/device_trigger.py b/homeassistant/components/select/device_trigger.py index 164f420c122..84f61dfaec9 100644 --- a/homeassistant/components/select/device_trigger.py +++ b/homeassistant/components/select/device_trigger.py @@ -22,8 +22,9 @@ from homeassistant.const import ( CONF_PLATFORM, CONF_TYPE, ) -from homeassistant.core import CALLBACK_TYPE, HomeAssistant +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, HomeAssistantError from homeassistant.helpers import config_validation as cv, entity_registry +from homeassistant.helpers.entity import get_capability from homeassistant.helpers.typing import ConfigType from . import DOMAIN @@ -88,15 +89,16 @@ async def async_get_trigger_capabilities( hass: HomeAssistant, config: ConfigType ) -> dict[str, Any]: """List trigger capabilities.""" - state = hass.states.get(config[CONF_ENTITY_ID]) - if state is None: - return {} + try: + options = get_capability(hass, config[CONF_ENTITY_ID], ATTR_OPTIONS) or [] + except HomeAssistantError: + options = [] return { "extra_fields": vol.Schema( { - vol.Optional(CONF_FROM): vol.In(state.attributes.get(ATTR_OPTIONS, [])), - vol.Optional(CONF_TO): vol.In(state.attributes.get(ATTR_OPTIONS, [])), + vol.Optional(CONF_FROM): vol.In(options), + vol.Optional(CONF_TO): vol.In(options), vol.Optional(CONF_FOR): cv.positive_time_period_dict, } ) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index e9d1e1d2e07..187d53ea00b 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -93,7 +93,7 @@ def async_generate_entity_id( return test_string -def get_capability(hass: HomeAssistant, entity_id: str, capability: str) -> str | None: +def get_capability(hass: HomeAssistant, entity_id: str, capability: str) -> Any | None: """Get a capability attribute of an entity. First try the statemachine, then entity registry. diff --git a/tests/components/select/test_device_action.py b/tests/components/select/test_device_action.py index 1cdffe9ae00..5c2486a4e26 100644 --- a/tests/components/select/test_device_action.py +++ b/tests/components/select/test_device_action.py @@ -94,7 +94,7 @@ async def test_action(hass: HomeAssistant) -> None: assert select_calls[0].data == {"entity_id": "select.entity", "option": "option1"} -async def test_get_trigger_capabilities(hass: HomeAssistant) -> None: +async def test_get_action_capabilities(hass: HomeAssistant) -> None: """Test we get the expected capabilities from a select action.""" config = { "platform": "device", @@ -106,7 +106,18 @@ async def test_get_trigger_capabilities(hass: HomeAssistant) -> None: # Test when entity doesn't exists capabilities = await async_get_action_capabilities(hass, config) - assert capabilities == {} + assert capabilities + assert "extra_fields" in capabilities + assert voluptuous_serialize.convert( + capabilities["extra_fields"], custom_serializer=cv.custom_serializer + ) == [ + { + "name": "option", + "required": True, + "type": "select", + "options": [], + }, + ] # Mock an entity hass.states.async_set("select.test", "option1", {"options": ["option1", "option2"]}) diff --git a/tests/components/select/test_device_condition.py b/tests/components/select/test_device_condition.py index fa2d2736e0f..d5ee88156cf 100644 --- a/tests/components/select/test_device_condition.py +++ b/tests/components/select/test_device_condition.py @@ -159,7 +159,23 @@ async def test_get_condition_capabilities(hass: HomeAssistant) -> None: # Test when entity doesn't exists capabilities = await async_get_condition_capabilities(hass, config) - assert capabilities == {} + assert capabilities + assert "extra_fields" in capabilities + assert voluptuous_serialize.convert( + capabilities["extra_fields"], custom_serializer=cv.custom_serializer + ) == [ + { + "name": "option", + "required": True, + "type": "select", + "options": [], + }, + { + "name": "for", + "optional": True, + "type": "positive_time_period_dict", + }, + ] # Mock an entity hass.states.async_set("select.test", "option1", {"options": ["option1", "option2"]}) diff --git a/tests/components/select/test_device_trigger.py b/tests/components/select/test_device_trigger.py index df81a67b847..b0066e9ac22 100644 --- a/tests/components/select/test_device_trigger.py +++ b/tests/components/select/test_device_trigger.py @@ -184,7 +184,30 @@ async def test_get_trigger_capabilities(hass: HomeAssistant) -> None: # Test when entity doesn't exists capabilities = await async_get_trigger_capabilities(hass, config) - assert capabilities == {} + assert capabilities + assert "extra_fields" in capabilities + assert voluptuous_serialize.convert( + capabilities["extra_fields"], custom_serializer=cv.custom_serializer + ) == [ + { + "name": "from", + "optional": True, + "type": "select", + "options": [], + }, + { + "name": "to", + "optional": True, + "type": "select", + "options": [], + }, + { + "name": "for", + "optional": True, + "type": "positive_time_period_dict", + "optional": True, + }, + ] # Mock an entity hass.states.async_set("select.test", "option1", {"options": ["option1", "option2"]})