diff --git a/homeassistant/components/switch_as_x/__init__.py b/homeassistant/components/switch_as_x/__init__.py index 71cb9e9c225..b07bf0fdaec 100644 --- a/homeassistant/components/switch_as_x/__init__.py +++ b/homeassistant/components/switch_as_x/__init__.py @@ -9,7 +9,7 @@ import voluptuous as vol from homeassistant.components.homeassistant import exposed_entities from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_ENTITY_ID -from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, callback, valid_entity_id from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.event import async_track_entity_registry_updated_event @@ -44,10 +44,12 @@ def async_add_to_device( async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" - registry = er.async_get(hass) + entity_registry = er.async_get(hass) device_registry = dr.async_get(hass) try: - entity_id = er.async_validate_entity_id(registry, entry.options[CONF_ENTITY_ID]) + entity_id = er.async_validate_entity_id( + entity_registry, entry.options[CONF_ENTITY_ID] + ) except vol.Invalid: # The entity is identified by an unknown entity registry ID _LOGGER.error( @@ -68,14 +70,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return if "entity_id" in data["changes"]: - # Entity_id changed, reload the config entry - await hass.config_entries.async_reload(entry.entry_id) + # Entity_id changed, update or reload the config entry + if valid_entity_id(entry.options[CONF_ENTITY_ID]): + # If the entity is pointed to by an entity ID, update the entry + hass.config_entries.async_update_entry( + entry, + options={**entry.options, CONF_ENTITY_ID: data["entity_id"]}, + ) + else: + await hass.config_entries.async_reload(entry.entry_id) if device_id and "device_id" in data["changes"]: # If the tracked switch is no longer in the device, remove our config entry # from the device if ( - not (entity_entry := registry.async_get(data[CONF_ENTITY_ID])) + not (entity_entry := entity_registry.async_get(data[CONF_ENTITY_ID])) or not device_registry.async_get(device_id) or entity_entry.device_id == device_id ): diff --git a/tests/components/switch_as_x/test_init.py b/tests/components/switch_as_x/test_init.py index cd80fab69bc..0b965fc2ad1 100644 --- a/tests/components/switch_as_x/test_init.py +++ b/tests/components/switch_as_x/test_init.py @@ -39,6 +39,44 @@ EXPOSE_SETTINGS = { } +@pytest.fixture +def switch_entity_registry_entry( + entity_registry: er.EntityRegistry, +) -> er.RegistryEntry: + """Fixture to create a switch entity entry.""" + return entity_registry.async_get_or_create( + "switch", "test", "unique", original_name="ABC" + ) + + +@pytest.fixture +def switch_as_x_config_entry( + hass: HomeAssistant, + switch_entity_registry_entry: er.RegistryEntry, + target_domain: str, + use_entity_registry_id: bool, +) -> MockConfigEntry: + """Fixture to create a switch_as_x config entry.""" + config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + CONF_ENTITY_ID: switch_entity_registry_entry.id + if use_entity_registry_id + else switch_entity_registry_entry.entity_id, + CONF_INVERT: False, + CONF_TARGET_DOMAIN: target_domain, + }, + title="ABC", + version=SwitchAsXConfigFlowHandler.VERSION, + minor_version=SwitchAsXConfigFlowHandler.MINOR_VERSION, + ) + + config_entry.add_to_hass(hass) + + return config_entry + + @pytest.mark.parametrize("target_domain", PLATFORMS_TO_TEST) async def test_config_entry_unregistered_uuid( hass: HomeAssistant, target_domain: str @@ -67,6 +105,7 @@ async def test_config_entry_unregistered_uuid( assert len(hass.states.async_all()) == 0 +@pytest.mark.parametrize("use_entity_registry_id", [True, False]) @pytest.mark.parametrize( ("target_domain", "state_on", "state_off"), [ @@ -81,33 +120,17 @@ async def test_config_entry_unregistered_uuid( async def test_entity_registry_events( hass: HomeAssistant, entity_registry: er.EntityRegistry, + switch_entity_registry_entry: er.RegistryEntry, + switch_as_x_config_entry: MockConfigEntry, target_domain: str, state_on: str, state_off: str, ) -> None: """Test entity registry events are tracked.""" - registry_entry = entity_registry.async_get_or_create( - "switch", "test", "unique", original_name="ABC" - ) - switch_entity_id = registry_entry.entity_id + switch_entity_id = switch_entity_registry_entry.entity_id hass.states.async_set(switch_entity_id, STATE_ON) - config_entry = MockConfigEntry( - data={}, - domain=DOMAIN, - options={ - CONF_ENTITY_ID: registry_entry.id, - CONF_INVERT: False, - CONF_TARGET_DOMAIN: target_domain, - }, - title="ABC", - version=SwitchAsXConfigFlowHandler.VERSION, - minor_version=SwitchAsXConfigFlowHandler.MINOR_VERSION, - ) - - config_entry.add_to_hass(hass) - - assert await hass.config_entries.async_setup(config_entry.entry_id) + assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id) await hass.async_block_till_done() assert hass.states.get(f"{target_domain}.abc").state == state_on