Add switch_as_x entity to wrapped switch's device (#67961)

This commit is contained in:
Erik Montnemery 2022-03-11 09:46:32 +01:00 committed by GitHub
parent 66d757115c
commit 8948bada58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 193 additions and 8 deletions

View File

@ -7,8 +7,8 @@ import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ENTITY_ID
from homeassistant.core import Event, HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.event import async_track_entity_registry_updated_event
from .light import LightSwitch
@ -20,9 +20,31 @@ DOMAIN = "switch_as_x"
_LOGGER = logging.getLogger(__name__)
@callback
def async_add_to_device(
hass: HomeAssistant, entry: ConfigEntry, entity_id: str
) -> str | None:
"""Add our config entry to the tracked entity's device."""
registry = er.async_get(hass)
device_registry = dr.async_get(hass)
device_id = None
if (
not (wrapped_switch := registry.async_get(entity_id))
or not (device_id := wrapped_switch.device_id)
or not (device_registry.async_get(device_id))
):
return device_id
device_registry.async_update_device(device_id, add_config_entry_id=entry.entry_id)
return device_id
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
registry = er.async_get(hass)
device_registry = dr.async_get(hass)
try:
entity_id = er.async_validate_entity_id(registry, entry.options[CONF_ENTITY_ID])
except vol.Invalid:
@ -39,11 +61,27 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if data["action"] == "remove":
await hass.config_entries.async_remove(entry.entry_id)
if data["action"] != "update" or "entity_id" not in data["changes"]:
if data["action"] != "update":
return
# Entity_id changed, reload the config entry
await hass.config_entries.async_reload(entry.entry_id)
if "entity_id" in data["changes"]:
# Entity_id changed, reload the config entry
await hass.config_entries.async_reload(entry.entry_id)
if device_id and "device_id" in data["changes"]:
# If the tracked switch is no longer in the device, remove our config entry
# from the device
if (
not (entity_entry := registry.async_get(data["entity_id"]))
or not device_registry.async_get(device_id)
or entity_entry.device_id == device_id
):
# No need to do any cleanup
return
device_registry.async_update_device(
device_id, remove_config_entry_id=entry.entry_id
)
entry.async_on_unload(
async_track_entity_registry_updated_event(
@ -51,6 +89,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
)
device_id = async_add_to_device(hass, entry, entity_id)
hass.config_entries.async_setup_platforms(entry, (entry.options["target_domain"],))
return True

View File

@ -31,6 +31,8 @@ async def async_setup_entry(
entity_id = er.async_validate_entity_id(
registry, config_entry.options[CONF_ENTITY_ID]
)
wrapped_switch = registry.async_get(entity_id)
device_id = wrapped_switch.device_id if wrapped_switch else None
async_add_entities(
[
@ -38,6 +40,7 @@ async def async_setup_entry(
config_entry.title,
entity_id,
config_entry.entry_id,
device_id,
)
]
)
@ -50,8 +53,15 @@ class LightSwitch(LightEntity):
_attr_should_poll = False
_attr_supported_color_modes = {COLOR_MODE_ONOFF}
def __init__(self, name: str, switch_entity_id: str, unique_id: str | None) -> None:
def __init__(
self,
name: str,
switch_entity_id: str,
unique_id: str | None,
device_id: str | None = None,
) -> None:
"""Initialize Light Switch."""
self._device_id = device_id
self._attr_name = name
self._attr_unique_id = unique_id
self._switch_entity_id = switch_entity_id
@ -100,3 +110,7 @@ class LightSwitch(LightEntity):
# Call once on adding
async_state_changed_listener()
# Add this entity to the wrapped switch's device
registry = er.async_get(self.hass)
registry.async_update_entity(self.entity_id, device_id=self._device_id)

View File

@ -5,7 +5,7 @@ import pytest
from homeassistant.components.switch_as_x import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers import device_registry as dr, entity_registry as er
from tests.common import MockConfigEntry
@ -82,3 +82,102 @@ async def test_entity_registry_events(hass: HomeAssistant, target_domain):
assert hass.states.get(f"{target_domain}.abc") is None
assert registry.async_get(f"{target_domain}.abc") is None
assert len(hass.config_entries.async_entries("switch_as_x")) == 0
@pytest.mark.parametrize("target_domain", ("light",))
async def test_device_registry_config_entry_1(hass: HomeAssistant, target_domain):
"""Test we add our config entry to the tracked switch's device."""
device_registry = dr.async_get(hass)
entity_registry = er.async_get(hass)
switch_config_entry = MockConfigEntry()
device_entry = device_registry.async_get_or_create(
config_entry_id=switch_config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
switch_entity_entry = entity_registry.async_get_or_create(
"switch",
"test",
"unique",
config_entry=switch_config_entry,
device_id=device_entry.id,
)
# Add another config entry to the same device
device_registry.async_update_device(
device_entry.id, add_config_entry_id=MockConfigEntry().entry_id
)
switch_as_x_config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={"entity_id": switch_entity_entry.id, "target_domain": target_domain},
title="ABC",
)
switch_as_x_config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id)
await hass.async_block_till_done()
entity_entry = entity_registry.async_get(f"{target_domain}.abc")
assert entity_entry.device_id == switch_entity_entry.device_id
device_entry = device_registry.async_get(device_entry.id)
assert switch_as_x_config_entry.entry_id in device_entry.config_entries
# Remove the wrapped switch's config entry from the device
device_registry.async_update_device(
device_entry.id, remove_config_entry_id=switch_config_entry.entry_id
)
await hass.async_block_till_done()
await hass.async_block_till_done()
# Check that the switch_as_x config entry is removed from the device
device_entry = device_registry.async_get(device_entry.id)
assert switch_as_x_config_entry.entry_id not in device_entry.config_entries
@pytest.mark.parametrize("target_domain", ("light",))
async def test_device_registry_config_entry_2(hass: HomeAssistant, target_domain):
"""Test we add our config entry to the tracked switch's device."""
device_registry = dr.async_get(hass)
entity_registry = er.async_get(hass)
switch_config_entry = MockConfigEntry()
device_entry = device_registry.async_get_or_create(
config_entry_id=switch_config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
switch_entity_entry = entity_registry.async_get_or_create(
"switch",
"test",
"unique",
config_entry=switch_config_entry,
device_id=device_entry.id,
)
switch_as_x_config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={"entity_id": switch_entity_entry.id, "target_domain": target_domain},
title="ABC",
)
switch_as_x_config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id)
await hass.async_block_till_done()
entity_entry = entity_registry.async_get(f"{target_domain}.abc")
assert entity_entry.device_id == switch_entity_entry.device_id
device_entry = device_registry.async_get(device_entry.id)
assert switch_as_x_config_entry.entry_id in device_entry.config_entries
# Remove the wrapped switch from the device
entity_registry.async_update_entity(switch_entity_entry.entity_id, device_id=None)
await hass.async_block_till_done()
# Check that the switch_as_x config entry is removed from the device
device_entry = device_registry.async_get(device_entry.id)
assert switch_as_x_config_entry.entry_id not in device_entry.config_entries

View File

@ -8,7 +8,7 @@ from homeassistant.components.light import (
)
from homeassistant.components.switch_as_x import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@ -153,3 +153,35 @@ async def test_config_entry_uuid(hass: HomeAssistant, target_domain):
await hass.async_block_till_done()
assert hass.states.get(f"{target_domain}.abc")
@pytest.mark.parametrize("target_domain", ("light",))
async def test_device(hass: HomeAssistant, target_domain):
"""Test the entity is added to the wrapped entity's device."""
device_registry = dr.async_get(hass)
entity_registry = er.async_get(hass)
test_config_entry = MockConfigEntry()
device_entry = device_registry.async_get_or_create(
config_entry_id=test_config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
switch_entity_entry = entity_registry.async_get_or_create(
"switch", "test", "unique", device_id=device_entry.id
)
switch_as_x_config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={"entity_id": switch_entity_entry.id, "target_domain": target_domain},
title="ABC",
)
switch_as_x_config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id)
await hass.async_block_till_done()
entity_entry = entity_registry.async_get(f"{target_domain}.abc")
assert entity_entry.device_id == switch_entity_entry.device_id