From 8299d0a7c3172a6913a751c02b4838b3adcd29ca Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Wed, 4 Aug 2021 11:12:42 +0200 Subject: [PATCH] Validate Select option before calling entity method (#52352) Co-authored-by: Martin Hjelmare Co-authored-by: Franck Nijhof Co-authored-by: Franck Nijhof --- homeassistant/components/demo/select.py | 3 - homeassistant/components/select/__init__.py | 12 ++- tests/components/select/test_init.py | 86 ++++++++++++++++++- tests/components/wled/test_select.py | 12 +-- .../custom_components/test/select.py | 63 ++++++++++++++ 5 files changed, 164 insertions(+), 12 deletions(-) create mode 100644 tests/testing_config/custom_components/test/select.py diff --git a/homeassistant/components/demo/select.py b/homeassistant/components/demo/select.py index dcc0c12a9b4..8d499c7a258 100644 --- a/homeassistant/components/demo/select.py +++ b/homeassistant/components/demo/select.py @@ -73,8 +73,5 @@ class DemoSelect(SelectEntity): async def async_select_option(self, option: str) -> None: """Update the current selected option.""" - if option not in self.options: - raise ValueError(f"Invalid option for {self.entity_id}: {option}") - self._attr_current_option = option self.async_write_ha_state() diff --git a/homeassistant/components/select/__init__.py b/homeassistant/components/select/__init__.py index d5c70c76cd0..9a7bfa62cdf 100644 --- a/homeassistant/components/select/__init__.py +++ b/homeassistant/components/select/__init__.py @@ -9,7 +9,7 @@ from typing import Any, final import voluptuous as vol from homeassistant.config_entries import ConfigEntry -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.helpers import config_validation as cv from homeassistant.helpers.config_validation import ( # noqa: F401 PLATFORM_SCHEMA, @@ -40,12 +40,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: component.async_register_entity_service( SERVICE_SELECT_OPTION, {vol.Required(ATTR_OPTION): cv.string}, - "async_select_option", + async_select_option, ) return True +async def async_select_option(entity: SelectEntity, service_call: ServiceCall) -> None: + """Service call wrapper to set a new value.""" + option = service_call.data[ATTR_OPTION] + if option not in entity.options: + raise ValueError(f"Option {option} not valid for {entity.name}") + await entity.async_select_option(option) + + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" component: EntityComponent = hass.data[DOMAIN] diff --git a/tests/components/select/test_init.py b/tests/components/select/test_init.py index 188099164c2..21745694d38 100644 --- a/tests/components/select/test_init.py +++ b/tests/components/select/test_init.py @@ -1,6 +1,18 @@ """The tests for the Select component.""" -from homeassistant.components.select import SelectEntity +from unittest.mock import MagicMock + +import pytest + +from homeassistant.components.select import ATTR_OPTIONS, DOMAIN, SelectEntity +from homeassistant.const import ( + ATTR_ENTITY_ID, + ATTR_OPTION, + CONF_PLATFORM, + SERVICE_SELECT_OPTION, + STATE_UNKNOWN, +) from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component class MockSelectEntity(SelectEntity): @@ -26,3 +38,75 @@ async def test_select(hass: HomeAssistant) -> None: select._attr_current_option = "option_four" assert select.current_option == "option_four" assert select.state is None + + select.hass = hass + + with pytest.raises(NotImplementedError): + await select.async_select_option("option_one") + + select.select_option = MagicMock() + await select.async_select_option("option_one") + + assert select.select_option.called + assert select.select_option.call_args[0][0] == "option_one" + + assert select.capability_attributes[ATTR_OPTIONS] == [ + "option_one", + "option_two", + "option_three", + ] + + +async def test_custom_integration_and_validation(hass, enable_custom_integrations): + """Test we can only select valid options.""" + platform = getattr(hass.components, f"test.{DOMAIN}") + platform.init() + + assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) + await hass.async_block_till_done() + + assert hass.states.get("select.select_1").state == "option 1" + + await hass.services.async_call( + DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_OPTION: "option 2", ATTR_ENTITY_ID: "select.select_1"}, + blocking=True, + ) + + hass.states.async_set("select.select_1", "option 2") + await hass.async_block_till_done() + assert hass.states.get("select.select_1").state == "option 2" + + # test ValueError trigger + with pytest.raises(ValueError): + await hass.services.async_call( + DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_OPTION: "option invalid", ATTR_ENTITY_ID: "select.select_1"}, + blocking=True, + ) + await hass.async_block_till_done() + assert hass.states.get("select.select_1").state == "option 2" + + assert hass.states.get("select.select_2").state == STATE_UNKNOWN + + with pytest.raises(ValueError): + await hass.services.async_call( + DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_OPTION: "option invalid", ATTR_ENTITY_ID: "select.select_2"}, + blocking=True, + ) + await hass.async_block_till_done() + assert hass.states.get("select.select_2").state == STATE_UNKNOWN + + await hass.services.async_call( + DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_OPTION: "option 3", ATTR_ENTITY_ID: "select.select_2"}, + blocking=True, + ) + await hass.async_block_till_done() + + assert hass.states.get("select.select_2").state == "option 3" diff --git a/tests/components/wled/test_select.py b/tests/components/wled/test_select.py index dbc1bf7c970..1d68879d510 100644 --- a/tests/components/wled/test_select.py +++ b/tests/components/wled/test_select.py @@ -126,7 +126,7 @@ async def test_color_palette_segment_change_state( SERVICE_SELECT_OPTION, { ATTR_ENTITY_ID: "select.wled_rgb_light_segment_1_color_palette", - ATTR_OPTION: "Some Other Palette", + ATTR_OPTION: "Icefire", }, blocking=True, ) @@ -134,7 +134,7 @@ async def test_color_palette_segment_change_state( assert mock_wled.segment.call_count == 1 mock_wled.segment.assert_called_with( segment_id=1, - palette="Some Other Palette", + palette="Icefire", ) @@ -195,7 +195,7 @@ async def test_color_palette_select_error( SERVICE_SELECT_OPTION, { ATTR_ENTITY_ID: "select.wled_rgb_light_segment_1_color_palette", - ATTR_OPTION: "Whatever", + ATTR_OPTION: "Icefire", }, blocking=True, ) @@ -206,7 +206,7 @@ async def test_color_palette_select_error( assert state.state == "Random Cycle" assert "Invalid response from API" in caplog.text assert mock_wled.segment.call_count == 1 - mock_wled.segment.assert_called_with(segment_id=1, palette="Whatever") + mock_wled.segment.assert_called_with(segment_id=1, palette="Icefire") async def test_color_palette_select_connection_error( @@ -224,7 +224,7 @@ async def test_color_palette_select_connection_error( SERVICE_SELECT_OPTION, { ATTR_ENTITY_ID: "select.wled_rgb_light_segment_1_color_palette", - ATTR_OPTION: "Whatever", + ATTR_OPTION: "Icefire", }, blocking=True, ) @@ -235,7 +235,7 @@ async def test_color_palette_select_connection_error( assert state.state == STATE_UNAVAILABLE assert "Error communicating with API" in caplog.text assert mock_wled.segment.call_count == 1 - mock_wled.segment.assert_called_with(segment_id=1, palette="Whatever") + mock_wled.segment.assert_called_with(segment_id=1, palette="Icefire") async def test_preset_unavailable_without_presets( diff --git a/tests/testing_config/custom_components/test/select.py b/tests/testing_config/custom_components/test/select.py new file mode 100644 index 00000000000..375191983b5 --- /dev/null +++ b/tests/testing_config/custom_components/test/select.py @@ -0,0 +1,63 @@ +""" +Provide a mock select platform. + +Call init before using it in your tests to ensure clean test data. +""" +from homeassistant.components.select import SelectEntity + +from tests.common import MockEntity + +UNIQUE_SELECT_1 = "unique_select_1" +UNIQUE_SELECT_2 = "unique_select_2" + +ENTITIES = [] + + +class MockSelectEntity(MockEntity, SelectEntity): + """Mock Select class.""" + + _attr_current_option = None + + @property + def current_option(self): + """Return the current option of this select.""" + return self._handle("current_option") + + @property + def options(self) -> list: + """Return the list of available options of this select.""" + return self._handle("options") + + def select_option(self, option: str) -> None: + """Change the selected option.""" + self._attr_current_option = option + + +def init(empty=False): + """Initialize the platform with entities.""" + global ENTITIES + + ENTITIES = ( + [] + if empty + else [ + MockSelectEntity( + name="select 1", + unique_id="unique_select_1", + options=["option 1", "option 2", "option 3"], + current_option="option 1", + ), + MockSelectEntity( + name="select 2", + unique_id="unique_select_2", + options=["option 1", "option 2", "option 3"], + ), + ] + ) + + +async def async_setup_platform( + hass, config, async_add_entities_callback, discovery_info=None +): + """Return mock entities.""" + async_add_entities_callback(ENTITIES)