diff --git a/homeassistant/components/switch/__init__.py b/homeassistant/components/switch/__init__.py index 157a5cd40c7..b7e0ffac59c 100644 --- a/homeassistant/components/switch/__init__.py +++ b/homeassistant/components/switch/__init__.py @@ -11,12 +11,15 @@ import voluptuous as vol from homeassistant.backports.enum import StrEnum from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( + CONF_ENTITY_ID, SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON, STATE_ON, + Platform, ) from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.config_validation import ( # noqa: F401 PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, @@ -26,7 +29,8 @@ from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass -DOMAIN = "switch" +from .const import DOMAIN + SCAN_INTERVAL = timedelta(seconds=30) ENTITY_ID_FORMAT = DOMAIN + ".{}" @@ -59,6 +63,8 @@ DEVICE_CLASSES = [cls.value for cls in SwitchDeviceClass] DEVICE_CLASS_OUTLET = SwitchDeviceClass.OUTLET.value DEVICE_CLASS_SWITCH = SwitchDeviceClass.SWITCH.value +PLATFORMS: list[Platform] = [Platform.LIGHT] + @bind_hass def is_on(hass: HomeAssistant, entity_id: str) -> bool: @@ -85,6 +91,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" + if entry.domain == DOMAIN: + registry = er.async_get(hass) + try: + er.async_validate_entity_id(registry, entry.options[CONF_ENTITY_ID]) + except vol.Invalid: + # The entity is identified by an unknown entity registry ID + _LOGGER.error( + "Failed to setup light switch for unknown entity %s", + entry.options[CONF_ENTITY_ID], + ) + return False + + hass.config_entries.async_setup_platforms(entry, PLATFORMS) + return True + component: EntityComponent = hass.data[DOMAIN] return await component.async_setup_entry(entry) diff --git a/homeassistant/components/switch/config_flow.py b/homeassistant/components/switch/config_flow.py new file mode 100644 index 00000000000..1adc4ec0aee --- /dev/null +++ b/homeassistant/components/switch/config_flow.py @@ -0,0 +1,45 @@ +"""Config flow for Switch integration.""" +from __future__ import annotations + +from typing import Any + +import voluptuous as vol + +from homeassistant.core import split_entity_id +from homeassistant.helpers import ( + entity_registry as er, + helper_config_entry_flow, + selector, +) + +from .const import DOMAIN + +STEPS = { + "init": vol.Schema( + { + vol.Required("entity_id"): selector.selector( + {"entity": {"domain": "switch"}} + ), + } + ) +} + + +class SwitchLightConfigFlowHandler( + helper_config_entry_flow.HelperConfigFlowHandler, domain=DOMAIN +): + """Handle a config or options flow for Switch Light.""" + + steps = STEPS + + def async_config_entry_title(self, user_input: dict[str, Any]) -> str: + """Return config entry title.""" + registry = er.async_get(self.hass) + object_id = split_entity_id(user_input["entity_id"])[1] + entry = registry.async_get(user_input["entity_id"]) + if entry: + return entry.name or entry.original_name or object_id + state = self.hass.states.get(user_input["entity_id"]) + if state: + return state.name or object_id + return object_id diff --git a/homeassistant/components/switch/const.py b/homeassistant/components/switch/const.py new file mode 100644 index 00000000000..aaff452c5ce --- /dev/null +++ b/homeassistant/components/switch/const.py @@ -0,0 +1,3 @@ +"""Constants for the Switch integration.""" + +DOMAIN = "switch" diff --git a/homeassistant/components/switch/light.py b/homeassistant/components/switch/light.py index 32c0aff74fa..7c732d7750d 100644 --- a/homeassistant/components/switch/light.py +++ b/homeassistant/components/switch/light.py @@ -11,6 +11,7 @@ from homeassistant.components.light import ( PLATFORM_SCHEMA, LightEntity, ) +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( ATTR_ENTITY_ID, CONF_ENTITY_ID, @@ -59,6 +60,29 @@ async def async_setup_platform( ) +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Initialize Light Switch config entry.""" + + registry = er.async_get(hass) + entity_id = er.async_validate_entity_id( + registry, config_entry.options[CONF_ENTITY_ID] + ) + + async_add_entities( + [ + LightSwitch( + config_entry.title, + entity_id, + config_entry.entry_id, + ) + ] + ) + + class LightSwitch(LightEntity): """Represents a Switch as a Light.""" diff --git a/homeassistant/components/switch/manifest.json b/homeassistant/components/switch/manifest.json index 4c52e596648..f087ace1bce 100644 --- a/homeassistant/components/switch/manifest.json +++ b/homeassistant/components/switch/manifest.json @@ -2,6 +2,9 @@ "domain": "switch", "name": "Switch", "documentation": "https://www.home-assistant.io/integrations/switch", - "codeowners": ["@home-assistant/core"], - "quality_scale": "internal" + "codeowners": [ + "@home-assistant/core" + ], + "quality_scale": "internal", + "config_flow": true } diff --git a/homeassistant/components/switch/strings.json b/homeassistant/components/switch/strings.json index 7ea84e649ef..5cdd0c35936 100644 --- a/homeassistant/components/switch/strings.json +++ b/homeassistant/components/switch/strings.json @@ -1,5 +1,15 @@ { "title": "Switch", + "config": { + "step": { + "init": { + "description": "Select the switch for the light switch.", + "data": { + "entity_id": "Switch entity" + } + } + } + }, "device_automation": { "action_type": { "toggle": "Toggle {entity_name}", diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index 7a36cfebc7b..2ee6e235e91 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -319,6 +319,7 @@ FLOWS = [ "stookalert", "subaru", "surepetcare", + "switch", "switchbot", "switcher_kis", "syncthing", diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 5cce811b4b3..7c86bfaa501 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -923,6 +923,20 @@ async def async_migrate_entries( ent_reg.async_update_entity(entry.entity_id, **updates) +@callback +def async_validate_entity_id(registry: EntityRegistry, entity_id_or_uuid: str) -> str: + """Validate and resolve an entity id or UUID to an entity id. + + Raises vol.Invalid if the entity or UUID is invalid, or if the UUID is not + associated with an entity registry item. + """ + if valid_entity_id(entity_id_or_uuid): + return entity_id_or_uuid + if (entry := registry.entities.get_entry(entity_id_or_uuid)) is None: + raise vol.Invalid(f"Unknown entity registry entry {entity_id_or_uuid}") + return entry.entity_id + + @callback def async_validate_entity_ids( registry: EntityRegistry, entity_ids_or_uuids: list[str] @@ -934,21 +948,4 @@ def async_validate_entity_ids( an entity registry item. """ - def async_validate_entity_id(entity_id_or_uuid: str) -> str | None: - """Resolve an entity id or UUID to an entity id. - - Raises vol.Invalid if the entity or UUID is invalid, or if the UUID is not - associated with an entity registry item. - """ - if valid_entity_id(entity_id_or_uuid): - return entity_id_or_uuid - if (entry := registry.entities.get_entry(entity_id_or_uuid)) is None: - raise vol.Invalid(f"Unknown entity registry entry {entity_id_or_uuid}") - return entry.entity_id - - tmp = [ - resolved_item - for item in entity_ids_or_uuids - if (resolved_item := async_validate_entity_id(item)) is not None - ] - return tmp + return [async_validate_entity_id(registry, item) for item in entity_ids_or_uuids] diff --git a/homeassistant/helpers/helper_config_entry_flow.py b/homeassistant/helpers/helper_config_entry_flow.py new file mode 100644 index 00000000000..82d10868d01 --- /dev/null +++ b/homeassistant/helpers/helper_config_entry_flow.py @@ -0,0 +1,105 @@ +"""Helpers for data entry flows for helper config entries.""" +from __future__ import annotations + +from abc import abstractmethod +from typing import Any + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import RESULT_TYPE_CREATE_ENTRY, FlowResult + + +class HelperCommonFlowHandler: + """Handle a config or options flow for helper.""" + + def __init__( + self, + handler: HelperConfigFlowHandler, + config_entry: config_entries.ConfigEntry | None, + ) -> None: + """Initialize a common handler.""" + self._handler = handler + self._options = dict(config_entry.options) if config_entry is not None else {} + + async def async_step(self, _user_input: dict[str, Any] | None = None) -> FlowResult: + """Handle a step.""" + errors = None + step_id = ( + self._handler.cur_step["step_id"] if self._handler.cur_step else "init" + ) + if _user_input is not None: + errors = {} + try: + user_input = await self._handler.async_validate_input( + self._handler.hass, step_id, _user_input + ) + except vol.Invalid as exc: + errors["base"] = str(exc) + else: + if ( + next_step_id := self._handler.async_next_step(step_id, user_input) + ) is None: + title = self._handler.async_config_entry_title(user_input) + return self._handler.async_create_entry( + title=title, data=user_input + ) + return self._handler.async_show_form( + step_id=next_step_id, data_schema=self._handler.steps[next_step_id] + ) + + return self._handler.async_show_form( + step_id=step_id, data_schema=self._handler.steps[step_id], errors=errors + ) + + +class HelperConfigFlowHandler(config_entries.ConfigFlow): + """Handle a config flow for helper integrations.""" + + steps: dict[str, vol.Schema] + + VERSION = 1 + + # pylint: disable-next=arguments-differ + def __init_subclass__(cls, **kwargs: Any) -> None: + """Initialize a subclass, register if possible.""" + super().__init_subclass__(**kwargs) + + for step in cls.steps: + setattr(cls, f"async_step_{step}", cls.async_step) + + def __init__(self) -> None: + """Initialize config flow.""" + self._common_handler = HelperCommonFlowHandler(self, None) + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle the initial step.""" + return await self.async_step() + + async def async_step(self, user_input: dict[str, Any] | None = None) -> FlowResult: + """Handle a step.""" + result = await self._common_handler.async_step(user_input) + if result["type"] == RESULT_TYPE_CREATE_ENTRY: + result["options"] = result["data"] + result["data"] = {} + return result + + # pylint: disable-next=no-self-use + @abstractmethod + def async_config_entry_title(self, user_input: dict[str, Any]) -> str: + """Return config entry title.""" + + # pylint: disable-next=no-self-use + def async_next_step(self, step_id: str, user_input: dict[str, Any]) -> str | None: + """Return next step_id, or None to finish the flow.""" + return None + + # pylint: disable-next=no-self-use + async def async_validate_input( + self, hass: HomeAssistant, step_id: str, user_input: dict[str, Any] + ) -> dict[str, Any]: + """Validate user input.""" + return user_input diff --git a/tests/components/switch/test_config_flow.py b/tests/components/switch/test_config_flow.py new file mode 100644 index 00000000000..ca838b7b972 --- /dev/null +++ b/tests/components/switch/test_config_flow.py @@ -0,0 +1,154 @@ +"""Test the switch light config flow.""" +from unittest.mock import patch + +import pytest + +from homeassistant import config_entries, data_entry_flow +from homeassistant.components.switch import async_setup_entry +from homeassistant.components.switch.const import DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import RESULT_TYPE_CREATE_ENTRY, RESULT_TYPE_FORM +from homeassistant.helpers import entity_registry as er + + +async def test_config_flow(hass: HomeAssistant) -> None: + """Test the config flow.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == RESULT_TYPE_FORM + assert result["errors"] is None + + with patch( + "homeassistant.components.switch.async_setup_entry", + wraps=async_setup_entry, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "entity_id": "switch.ceiling", + }, + ) + await hass.async_block_till_done() + + assert result["type"] == RESULT_TYPE_CREATE_ENTRY + assert result["title"] == "ceiling" + assert result["data"] == {} + assert result["options"] == {"entity_id": "switch.ceiling"} + assert len(mock_setup_entry.mock_calls) == 1 + + config_entry = hass.config_entries.async_entries(DOMAIN)[0] + assert config_entry.data == {} + assert config_entry.options == {"entity_id": "switch.ceiling"} + + assert hass.states.get("light.ceiling") + + +async def test_name(hass: HomeAssistant) -> None: + """Test the config flow name is copied from registry entry, with fallback to state.""" + registry = er.async_get(hass) + + # No entry or state, use Object ID + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"entity_id": "switch.ceiling"}, + ) + assert result["title"] == "ceiling" + + # State set, use name from state + hass.states.async_set("switch.ceiling", "on", {"friendly_name": "State Name"}) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"entity_id": "switch.ceiling"}, + ) + assert result["title"] == "State Name" + + # Entity registered, use original name from registry entry + hass.states.async_remove("switch.ceiling") + entry = registry.async_get_or_create( + "switch", + "test", + "unique", + suggested_object_id="ceiling", + original_name="Original Name", + ) + assert entry.entity_id == "switch.ceiling" + hass.states.async_set("switch.ceiling", "on", {"friendly_name": "State Name"}) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"entity_id": "switch.ceiling"}, + ) + assert result["title"] == "Original Name" + + # Entity has customized name + registry.async_update_entity("switch.ceiling", name="Custom Name") + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"entity_id": "switch.ceiling"}, + ) + assert result["title"] == "Custom Name" + + +def get_suggested(schema, key): + """Get suggested value for key in voluptuous schema.""" + for k in schema.keys(): + if k == key: + if k.description is None or "suggested_value" not in k.description: + return None + return k.description["suggested_value"] + + +async def test_options(hass: HomeAssistant) -> None: + """Test reconfiguring.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == RESULT_TYPE_FORM + assert result["errors"] is None + assert get_suggested(result["data_schema"].schema, "entity_id") is None + assert get_suggested(result["data_schema"].schema, "name") is None + + with patch( + "homeassistant.components.switch.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "entity_id": "switch.ceiling", + }, + ) + await hass.async_block_till_done() + + assert result["type"] == RESULT_TYPE_CREATE_ENTRY + assert result["title"] == "ceiling" + assert result["data"] == {} + assert result["options"] == {"entity_id": "switch.ceiling"} + assert len(mock_setup_entry.mock_calls) == 1 + + config_entry = hass.config_entries.async_entries(DOMAIN)[0] + assert config_entry.data == {} + assert config_entry.options == {"entity_id": "switch.ceiling"} + + # Switch light has no options flow + with pytest.raises(data_entry_flow.UnknownHandler): + await hass.config_entries.options.async_init(config_entry.entry_id) diff --git a/tests/components/switch/test_light.py b/tests/components/switch/test_light.py index 62fef242e9f..518fd8db20b 100644 --- a/tests/components/switch/test_light.py +++ b/tests/components/switch/test_light.py @@ -5,8 +5,12 @@ from homeassistant.components.light import ( ATTR_SUPPORTED_COLOR_MODES, COLOR_MODE_ONOFF, ) +from homeassistant.components.switch.const import DOMAIN as SWITCH_DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er from homeassistant.setup import async_setup_component +from tests.common import MockConfigEntry from tests.components.light import common from tests.components.switch import common as switch_common @@ -96,3 +100,69 @@ async def test_switch_service_calls(hass): assert hass.states.get("switch.decorative_lights").state == "on" assert hass.states.get("light.light_switch").state == "on" + + +async def test_config_entry(hass: HomeAssistant): + """Test light switch setup from config entry.""" + config_entry = MockConfigEntry( + data={}, + domain=SWITCH_DOMAIN, + options={"entity_id": "switch.abc"}, + title="ABC", + ) + + config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + assert SWITCH_DOMAIN in hass.config.components + + state = hass.states.get("light.abc") + assert state.state == "unavailable" + # Name copied from config entry title + assert state.name == "ABC" + + # Check the light is added to the entity registry + registry = er.async_get(hass) + entity_entry = registry.async_get("light.abc") + assert entity_entry.unique_id == config_entry.entry_id + + +async def test_config_entry_uuid(hass: HomeAssistant): + """Test light switch setup from config entry with entity registry id.""" + registry = er.async_get(hass) + registry_entry = registry.async_get_or_create("switch", "test", "unique") + + config_entry = MockConfigEntry( + data={}, + domain=SWITCH_DOMAIN, + options={"entity_id": registry_entry.id}, + title="ABC", + ) + + config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + assert hass.states.get("light.abc") + + +async def test_config_entry_unregistered_uuid(hass: HomeAssistant): + """Test light switch setup from config entry with unknown entity registry id.""" + fake_uuid = "a266a680b608c32770e6c45bfe6b8411" + + config_entry = MockConfigEntry( + data={}, + domain=SWITCH_DOMAIN, + options={"entity_id": fake_uuid}, + title="ABC", + ) + + config_entry.add_to_hass(hass) + + assert not await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + assert len(hass.states.async_all()) == 0