diff --git a/homeassistant/components/switch_as_x/__init__.py b/homeassistant/components/switch_as_x/__init__.py index 65b95c59c6d..2c647b1e953 100644 --- a/homeassistant/components/switch_as_x/__init__.py +++ b/homeassistant/components/switch_as_x/__init__.py @@ -7,8 +7,8 @@ import voluptuous as vol from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_ENTITY_ID -from homeassistant.core import Event, HomeAssistant -from homeassistant.helpers import entity_registry as er +from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.event import async_track_entity_registry_updated_event from .light import LightSwitch @@ -20,9 +20,31 @@ DOMAIN = "switch_as_x" _LOGGER = logging.getLogger(__name__) +@callback +def async_add_to_device( + hass: HomeAssistant, entry: ConfigEntry, entity_id: str +) -> str | None: + """Add our config entry to the tracked entity's device.""" + registry = er.async_get(hass) + device_registry = dr.async_get(hass) + device_id = None + + if ( + not (wrapped_switch := registry.async_get(entity_id)) + or not (device_id := wrapped_switch.device_id) + or not (device_registry.async_get(device_id)) + ): + return device_id + + device_registry.async_update_device(device_id, add_config_entry_id=entry.entry_id) + + return device_id + + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" registry = er.async_get(hass) + device_registry = dr.async_get(hass) try: entity_id = er.async_validate_entity_id(registry, entry.options[CONF_ENTITY_ID]) except vol.Invalid: @@ -39,11 +61,27 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if data["action"] == "remove": await hass.config_entries.async_remove(entry.entry_id) - if data["action"] != "update" or "entity_id" not in data["changes"]: + if data["action"] != "update": return - # Entity_id changed, reload the config entry - await hass.config_entries.async_reload(entry.entry_id) + if "entity_id" in data["changes"]: + # Entity_id changed, reload the config entry + await hass.config_entries.async_reload(entry.entry_id) + + if device_id and "device_id" in data["changes"]: + # If the tracked switch is no longer in the device, remove our config entry + # from the device + if ( + not (entity_entry := registry.async_get(data["entity_id"])) + or not device_registry.async_get(device_id) + or entity_entry.device_id == device_id + ): + # No need to do any cleanup + return + + device_registry.async_update_device( + device_id, remove_config_entry_id=entry.entry_id + ) entry.async_on_unload( async_track_entity_registry_updated_event( @@ -51,6 +89,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) ) + device_id = async_add_to_device(hass, entry, entity_id) + hass.config_entries.async_setup_platforms(entry, (entry.options["target_domain"],)) return True diff --git a/homeassistant/components/switch_as_x/light.py b/homeassistant/components/switch_as_x/light.py index 8dd863029da..e6fc334caef 100644 --- a/homeassistant/components/switch_as_x/light.py +++ b/homeassistant/components/switch_as_x/light.py @@ -31,6 +31,8 @@ async def async_setup_entry( entity_id = er.async_validate_entity_id( registry, config_entry.options[CONF_ENTITY_ID] ) + wrapped_switch = registry.async_get(entity_id) + device_id = wrapped_switch.device_id if wrapped_switch else None async_add_entities( [ @@ -38,6 +40,7 @@ async def async_setup_entry( config_entry.title, entity_id, config_entry.entry_id, + device_id, ) ] ) @@ -50,8 +53,15 @@ class LightSwitch(LightEntity): _attr_should_poll = False _attr_supported_color_modes = {COLOR_MODE_ONOFF} - def __init__(self, name: str, switch_entity_id: str, unique_id: str | None) -> None: + def __init__( + self, + name: str, + switch_entity_id: str, + unique_id: str | None, + device_id: str | None = None, + ) -> None: """Initialize Light Switch.""" + self._device_id = device_id self._attr_name = name self._attr_unique_id = unique_id self._switch_entity_id = switch_entity_id @@ -100,3 +110,7 @@ class LightSwitch(LightEntity): # Call once on adding async_state_changed_listener() + + # Add this entity to the wrapped switch's device + registry = er.async_get(self.hass) + registry.async_update_entity(self.entity_id, device_id=self._device_id) diff --git a/tests/components/switch_as_x/test_init.py b/tests/components/switch_as_x/test_init.py index a8875def0ad..8eafc417c04 100644 --- a/tests/components/switch_as_x/test_init.py +++ b/tests/components/switch_as_x/test_init.py @@ -5,7 +5,7 @@ import pytest from homeassistant.components.switch_as_x import DOMAIN from homeassistant.core import HomeAssistant -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from tests.common import MockConfigEntry @@ -82,3 +82,102 @@ async def test_entity_registry_events(hass: HomeAssistant, target_domain): assert hass.states.get(f"{target_domain}.abc") is None assert registry.async_get(f"{target_domain}.abc") is None assert len(hass.config_entries.async_entries("switch_as_x")) == 0 + + +@pytest.mark.parametrize("target_domain", ("light",)) +async def test_device_registry_config_entry_1(hass: HomeAssistant, target_domain): + """Test we add our config entry to the tracked switch's device.""" + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) + + switch_config_entry = MockConfigEntry() + + device_entry = device_registry.async_get_or_create( + config_entry_id=switch_config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + switch_entity_entry = entity_registry.async_get_or_create( + "switch", + "test", + "unique", + config_entry=switch_config_entry, + device_id=device_entry.id, + ) + # Add another config entry to the same device + device_registry.async_update_device( + device_entry.id, add_config_entry_id=MockConfigEntry().entry_id + ) + + switch_as_x_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={"entity_id": switch_entity_entry.id, "target_domain": target_domain}, + title="ABC", + ) + + switch_as_x_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id) + await hass.async_block_till_done() + + entity_entry = entity_registry.async_get(f"{target_domain}.abc") + assert entity_entry.device_id == switch_entity_entry.device_id + + device_entry = device_registry.async_get(device_entry.id) + assert switch_as_x_config_entry.entry_id in device_entry.config_entries + + # Remove the wrapped switch's config entry from the device + device_registry.async_update_device( + device_entry.id, remove_config_entry_id=switch_config_entry.entry_id + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + # Check that the switch_as_x config entry is removed from the device + device_entry = device_registry.async_get(device_entry.id) + assert switch_as_x_config_entry.entry_id not in device_entry.config_entries + + +@pytest.mark.parametrize("target_domain", ("light",)) +async def test_device_registry_config_entry_2(hass: HomeAssistant, target_domain): + """Test we add our config entry to the tracked switch's device.""" + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) + + switch_config_entry = MockConfigEntry() + + device_entry = device_registry.async_get_or_create( + config_entry_id=switch_config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + switch_entity_entry = entity_registry.async_get_or_create( + "switch", + "test", + "unique", + config_entry=switch_config_entry, + device_id=device_entry.id, + ) + + switch_as_x_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={"entity_id": switch_entity_entry.id, "target_domain": target_domain}, + title="ABC", + ) + + switch_as_x_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id) + await hass.async_block_till_done() + + entity_entry = entity_registry.async_get(f"{target_domain}.abc") + assert entity_entry.device_id == switch_entity_entry.device_id + + device_entry = device_registry.async_get(device_entry.id) + assert switch_as_x_config_entry.entry_id in device_entry.config_entries + + # Remove the wrapped switch from the device + entity_registry.async_update_entity(switch_entity_entry.entity_id, device_id=None) + await hass.async_block_till_done() + # Check that the switch_as_x config entry is removed from the device + device_entry = device_registry.async_get(device_entry.id) + assert switch_as_x_config_entry.entry_id not in device_entry.config_entries diff --git a/tests/components/switch_as_x/test_light.py b/tests/components/switch_as_x/test_light.py index 9a480776510..13ae058f8d9 100644 --- a/tests/components/switch_as_x/test_light.py +++ b/tests/components/switch_as_x/test_light.py @@ -8,7 +8,7 @@ from homeassistant.components.light import ( ) from homeassistant.components.switch_as_x import DOMAIN from homeassistant.core import HomeAssistant -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -153,3 +153,35 @@ async def test_config_entry_uuid(hass: HomeAssistant, target_domain): await hass.async_block_till_done() assert hass.states.get(f"{target_domain}.abc") + + +@pytest.mark.parametrize("target_domain", ("light",)) +async def test_device(hass: HomeAssistant, target_domain): + """Test the entity is added to the wrapped entity's device.""" + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) + + test_config_entry = MockConfigEntry() + + device_entry = device_registry.async_get_or_create( + config_entry_id=test_config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + switch_entity_entry = entity_registry.async_get_or_create( + "switch", "test", "unique", device_id=device_entry.id + ) + + switch_as_x_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={"entity_id": switch_entity_entry.id, "target_domain": target_domain}, + title="ABC", + ) + + switch_as_x_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id) + await hass.async_block_till_done() + + entity_entry = entity_registry.async_get(f"{target_domain}.abc") + assert entity_entry.device_id == switch_entity_entry.device_id