diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 861b2143cb9..3359deb9b09 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable -from datetime import time as time_sys +from datetime import time as time_sys, timedelta from typing import Any, cast import voluptuous as vol @@ -142,10 +142,11 @@ class DeviceSelector(Selector): def __call__(self, data: Any) -> str | list[str]: """Validate the passed selection.""" if not self.config["multiple"]: - return cv.string(data) + device_id: str = vol.Schema(str)(data) + return device_id if not isinstance(data, list): raise vol.Invalid("Value should be a list") - return [cv.string(val) for val in data] + return [vol.Schema(str)(val) for val in data] @SELECTORS.register("area") @@ -165,10 +166,22 @@ class AreaSelector(Selector): def __call__(self, data: Any) -> str | list[str]: """Validate the passed selection.""" if not self.config["multiple"]: - return cv.string(data) + area_id: str = vol.Schema(str)(data) + return area_id if not isinstance(data, list): raise vol.Invalid("Value should be a list") - return [cv.string(val) for val in data] + return [vol.Schema(str)(val) for val in data] + + +def has_min_max_if_slider(data: Any) -> Any: + """Validate configuration.""" + if data["mode"] == "box": + return data + + if "min" not in data or "max" not in data: + raise vol.Invalid("min and max are required in slider mode") + + return data @SELECTORS.register("number") @@ -177,24 +190,32 @@ class NumberSelector(Selector): selector_type = "number" - CONFIG_SCHEMA = vol.Schema( - { - vol.Required("min"): vol.Coerce(float), - vol.Required("max"): vol.Coerce(float), - vol.Optional("step", default=1): vol.All( - vol.Coerce(float), vol.Range(min=1e-3) - ), - vol.Optional(CONF_UNIT_OF_MEASUREMENT): str, - vol.Optional(CONF_MODE, default="slider"): vol.In(["box", "slider"]), - } + CONFIG_SCHEMA = vol.All( + vol.Schema( + { + vol.Optional("min"): vol.Coerce(float), + vol.Optional("max"): vol.Coerce(float), + # Controls slider steps, and up/down keyboard binding for the box + # user input is not rounded + vol.Optional("step", default=1): vol.All( + vol.Coerce(float), vol.Range(min=1e-3) + ), + vol.Optional(CONF_UNIT_OF_MEASUREMENT): str, + vol.Optional(CONF_MODE, default="slider"): vol.In(["box", "slider"]), + } + ), + has_min_max_if_slider, ) def __call__(self, data: Any) -> float: """Validate the passed selection.""" value: float = vol.Coerce(float)(data) - if not self.config["min"] <= value <= self.config["max"]: - raise vol.Invalid(f"Value {value} is too small or too large") + if "min" in self.config and value < self.config["min"]: + raise vol.Invalid(f"Value {value} is too small") + + if "max" in self.config and value > self.config["max"]: + raise vol.Invalid(f"Value {value} is too large") return value @@ -205,11 +226,17 @@ class AddonSelector(Selector): selector_type = "addon" - CONFIG_SCHEMA = vol.Schema({}) + CONFIG_SCHEMA = vol.Schema( + { + vol.Optional("name"): str, + vol.Optional("slug"): str, + } + ) def __call__(self, data: Any) -> str: """Validate the passed selection.""" - return cv.string(data) + addon: str = vol.Schema(str)(data) + return addon @SELECTORS.register("boolean") @@ -250,7 +277,7 @@ class TargetSelector(Selector): CONFIG_SCHEMA = vol.Schema( { - vol.Optional("entity"): EntitySelector.CONFIG_SCHEMA, + vol.Optional("entity"): SINGLE_ENTITY_SELECTOR_CONFIG_SCHEMA, vol.Optional("device"): DeviceSelector.CONFIG_SCHEMA, } ) @@ -295,14 +322,48 @@ class StringSelector(Selector): selector_type = "text" - CONFIG_SCHEMA = vol.Schema({vol.Optional("multiline", default=False): bool}) + STRING_TYPES = [ + "number", + "text", + "search", + "tel", + "url", + "email", + "password", + "date", + "month", + "week", + "time", + "datetime-local", + "color", + ] + CONFIG_SCHEMA = vol.Schema( + { + vol.Optional("multiline", default=False): bool, + vol.Optional("suffix"): str, + # The "type" controls the input field in the browser, the resulting + # data can be any string so we don't validate it. + vol.Optional("type"): vol.In(STRING_TYPES), + } + ) def __call__(self, data: Any) -> str: """Validate the passed selection.""" - text = cv.string(data) + text: str = vol.Schema(str)(data) return text +select_option = vol.All( + dict, + vol.Schema( + { + vol.Required("value"): str, + vol.Required("label"): str, + } + ), +) + + @SELECTORS.register("select") class SelectSelector(Selector): """Selector for an single-choice input select.""" @@ -310,10 +371,124 @@ class SelectSelector(Selector): selector_type = "select" CONFIG_SCHEMA = vol.Schema( - {vol.Required("options"): vol.All([str], vol.Length(min=1))} + { + vol.Required("options"): vol.All( + vol.Any([str], [select_option]), vol.Length(min=1) + ) + } ) def __call__(self, data: Any) -> Any: """Validate the passed selection.""" - selected_option = vol.In(self.config["options"])(cv.string(data)) - return selected_option + if isinstance(self.config["options"][0], str): + options = self.config["options"] + else: + options = [option["value"] for option in self.config["options"]] + return vol.In(options)(vol.Schema(str)(data)) + + +@SELECTORS.register("attribute") +class AttributeSelector(Selector): + """Selector for an entity attribute.""" + + selector_type = "attribute" + + CONFIG_SCHEMA = vol.Schema({vol.Required("entity_id"): cv.entity_id}) + + def __call__(self, data: Any) -> str: + """Validate the passed selection.""" + attribute: str = vol.Schema(str)(data) + return attribute + + +@SELECTORS.register("duration") +class DurationSelector(Selector): + """Selector for a duration.""" + + selector_type = "duration" + + CONFIG_SCHEMA = vol.Schema({}) + + def __call__(self, data: Any) -> timedelta: + """Validate the passed selection.""" + duration: timedelta = cv.time_period_dict(data) + return duration + + +@SELECTORS.register("icon") +class IconSelector(Selector): + """Selector for an icon.""" + + selector_type = "icon" + + CONFIG_SCHEMA = vol.Schema( + {vol.Optional("placeholder"): str} + # Frontend also has a fallbackPath option, this is not used by core + ) + + def __call__(self, data: Any) -> str: + """Validate the passed selection.""" + icon: str = vol.Schema(str)(data) + return icon + + +@SELECTORS.register("theme") +class ThemeSelector(Selector): + """Selector for an theme.""" + + selector_type = "theme" + + CONFIG_SCHEMA = vol.Schema({}) + + def __call__(self, data: Any) -> str: + """Validate the passed selection.""" + theme: str = vol.Schema(str)(data) + return theme + + +@SELECTORS.register("media") +class MediaSelector(Selector): + """Selector for media.""" + + selector_type = "media" + + CONFIG_SCHEMA = vol.Schema({}) + DATA_SCHEMA = vol.Schema( + { + # Although marked as optional in frontend, this field is required + vol.Required("entity_id"): cv.entity_id_or_uuid, + # Although marked as optional in frontend, this field is required + vol.Required("media_content_id"): str, + # Although marked as optional in frontend, this field is required + vol.Required("media_content_type"): str, + vol.Remove("metadata"): dict, + } + ) + + def __call__(self, data: Any) -> dict[str, float]: + """Validate the passed selection.""" + media: dict[str, float] = self.DATA_SCHEMA(data) + return media + + +@SELECTORS.register("location") +class LocationSelector(Selector): + """Selector for a location.""" + + selector_type = "location" + + CONFIG_SCHEMA = vol.Schema( + {vol.Optional("radius"): bool, vol.Optional("icon"): str} + ) + DATA_SCHEMA = vol.Schema( + { + vol.Required("latitude"): float, + vol.Required("longitude"): float, + vol.Optional("radius"): float, + } + ) + + def __call__(self, data: Any) -> dict[str, float]: + """Validate the passed selection.""" + location: dict[str, float] = self.DATA_SCHEMA(data) + return location diff --git a/tests/components/blueprint/test_importer.py b/tests/components/blueprint/test_importer.py index 623d1e9ebbf..0e1e66405e6 100644 --- a/tests/components/blueprint/test_importer.py +++ b/tests/components/blueprint/test_importer.py @@ -32,7 +32,7 @@ COMMUNITY_POST_INPUTS = { "light": { "name": "Light(s)", "description": "The light(s) to control", - "selector": {"target": {"entity": {"domain": "light", "multiple": False}}}, + "selector": {"target": {"entity": {"domain": "light"}}}, }, "force_brightness": { "name": "Force turn on brightness", diff --git a/tests/helpers/test_selector.py b/tests/helpers/test_selector.py index 0af8a050ce8..7031f16d249 100644 --- a/tests/helpers/test_selector.py +++ b/tests/helpers/test_selector.py @@ -1,4 +1,6 @@ """Test selectors.""" +from datetime import timedelta + import pytest import voluptuous as vol @@ -206,6 +208,7 @@ def test_area_selector_schema(schema, valid_selections, invalid_selections): (), ), ({"min": 10, "max": 1000, "mode": "slider", "step": 0.5}, (), ()), + ({"mode": "box"}, (10,), ()), ), ) def test_number_selector_schema(schema, valid_selections, invalid_selections): @@ -213,6 +216,19 @@ def test_number_selector_schema(schema, valid_selections, invalid_selections): _test_selector("number", schema, valid_selections, invalid_selections) +@pytest.mark.parametrize( + "schema", + ( + {}, # Must have mandatory fields + {"mode": "slider"}, # Must have min+max in slider mode + ), +) +def test_number_selector_schema_error(schema): + """Test select selector.""" + with pytest.raises(vol.Invalid): + selector.validate_selector({"number": schema}) + + @pytest.mark.parametrize( "schema,valid_selections,invalid_selections", (({}, ("abc123",), (None,)),), @@ -315,6 +331,16 @@ def test_text_selector_schema(schema, valid_selections, invalid_selections): ("red", "green", "blue"), ("cat", 0, None), ), + ( + { + "options": [ + {"value": "red", "label": "Ruby Red"}, + {"value": "green", "label": "Emerald Green"}, + ] + }, + ("red", "green"), + ("cat", 0, None), + ), ), ) def test_select_selector_schema(schema, valid_selections, invalid_selections): @@ -325,12 +351,148 @@ def test_select_selector_schema(schema, valid_selections, invalid_selections): @pytest.mark.parametrize( "schema", ( - {}, - {"options": {"hello": "World"}}, - {"options": []}, + {}, # Must have options + {"options": {"hello": "World"}}, # Options must be a list + {"options": []}, # Must have at least option + # Options must be strings or value / label pairs + {"options": [{"hello": "World"}]}, + # Options must all be of the same type + {"options": ["red", {"value": "green", "label": "Emerald Green"}]}, ), ) def test_select_selector_schema_error(schema): """Test select selector.""" with pytest.raises(vol.Invalid): selector.validate_selector({"select": schema}) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {"entity_id": "sensor.abc"}, + ("friendly_name", "device_class"), + (None,), + ), + ), +) +def test_attribute_selector_schema(schema, valid_selections, invalid_selections): + """Test attribute selector.""" + _test_selector("attribute", schema, valid_selections, invalid_selections) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ({"seconds": 10},), + (None, {}), + ), + ), +) +def test_duration_selector_schema(schema, valid_selections, invalid_selections): + """Test duration selector.""" + _test_selector( + "duration", + schema, + valid_selections, + invalid_selections, + lambda x: timedelta(**x), + ) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ("mdi:abc",), + (None,), + ), + ), +) +def test_icon_selector_schema(schema, valid_selections, invalid_selections): + """Test icon selector.""" + _test_selector("icon", schema, valid_selections, invalid_selections) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ("abc",), + (None,), + ), + ), +) +def test_theme_selector_schema(schema, valid_selections, invalid_selections): + """Test theme selector.""" + _test_selector("theme", schema, valid_selections, invalid_selections) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ( + { + "entity_id": "sensor.abc", + "media_content_id": "abc", + "media_content_type": "def", + }, + { + "entity_id": "sensor.abc", + "media_content_id": "abc", + "media_content_type": "def", + "metadata": {}, + }, + ), + (None, "abc", {}), + ), + ), +) +def test_media_selector_schema(schema, valid_selections, invalid_selections): + """Test media selector.""" + + def drop_metadata(data): + """Drop metadata key from the input.""" + data.pop("metadata", None) + return data + + _test_selector("media", schema, valid_selections, invalid_selections, drop_metadata) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ( + { + "latitude": 1.0, + "longitude": 2.0, + }, + { + "latitude": 1.0, + "longitude": 2.0, + "radius": 3.0, + }, + ), + ( + None, + "abc", + {}, + {"latitude": 1.0}, + {"longitude": 1.0}, + {"latitude": 1.0, "longitude": "1.0"}, + ), + ), + ), +) +def test_location_selector_schema(schema, valid_selections, invalid_selections): + """Test location selector.""" + + _test_selector("location", schema, valid_selections, invalid_selections)