diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index efb1ee0b1f1..1dba926a9af 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -11,6 +11,7 @@ import voluptuous as vol from homeassistant.const import CONF_MODE, CONF_UNIT_OF_MEASUREMENT from homeassistant.core import split_entity_id, valid_entity_id +from homeassistant.generated.countries import COUNTRIES from homeassistant.util import decorator from homeassistant.util.yaml import dumper @@ -564,6 +565,40 @@ class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]): return agent +class CountrySelectorConfig(TypedDict, total=False): + """Class to represent a country selector config.""" + + countries: list[str] + no_sort: bool + + +@SELECTORS.register("country") +class CountrySelector(Selector[CountrySelectorConfig]): + """Selector for a single-choice country select.""" + + selector_type = "country" + + CONFIG_SCHEMA = vol.Schema( + { + vol.Optional("countries"): [str], + vol.Optional("no_sort", default=False): cv.boolean, + } + ) + + def __init__(self, config: CountrySelectorConfig | None = None) -> None: + """Instantiate a selector.""" + super().__init__(config) + + def __call__(self, data: Any) -> Any: + """Validate the passed selection.""" + country: str = vol.Schema(str)(data) + if "countries" in self.config and ( + country not in self.config["countries"] or country not in COUNTRIES + ): + raise vol.Invalid(f"Value {country} is not a valid option") + return country + + class DateSelectorConfig(TypedDict): """Class to represent a date selector config.""" diff --git a/tests/helpers/test_selector.py b/tests/helpers/test_selector.py index 590526cdb2b..ee4749be346 100644 --- a/tests/helpers/test_selector.py +++ b/tests/helpers/test_selector.py @@ -479,6 +479,26 @@ def test_config_entry_selector_schema( _test_selector("config_entry", schema, valid_selections, invalid_selections) +@pytest.mark.parametrize( + ("schema", "valid_selections", "invalid_selections"), + ( + ( + {}, + ("NL", "DE"), + (None, True, 1), + ), + ( + {"countries": ["NL", "DE"]}, + ("NL", "DE"), + (None, True, 1, "sv", "en"), + ), + ), +) +def test_country_selector_schema(schema, valid_selections, invalid_selections) -> None: + """Test country selector.""" + _test_selector("country", schema, valid_selections, invalid_selections) + + @pytest.mark.parametrize( ("schema", "valid_selections", "invalid_selections"), (({}, ("00:00:00",), ("blah", None)),),