mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 01:37:08 +00:00
Add switch_as_x entity to wrapped switch's device (#67961)
This commit is contained in:
parent
66d757115c
commit
8948bada58
@ -7,8 +7,8 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_ENTITY_ID
|
||||
from homeassistant.core import Event, HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers.event import async_track_entity_registry_updated_event
|
||||
|
||||
from .light import LightSwitch
|
||||
@ -20,9 +20,31 @@ DOMAIN = "switch_as_x"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_add_to_device(
|
||||
hass: HomeAssistant, entry: ConfigEntry, entity_id: str
|
||||
) -> str | None:
|
||||
"""Add our config entry to the tracked entity's device."""
|
||||
registry = er.async_get(hass)
|
||||
device_registry = dr.async_get(hass)
|
||||
device_id = None
|
||||
|
||||
if (
|
||||
not (wrapped_switch := registry.async_get(entity_id))
|
||||
or not (device_id := wrapped_switch.device_id)
|
||||
or not (device_registry.async_get(device_id))
|
||||
):
|
||||
return device_id
|
||||
|
||||
device_registry.async_update_device(device_id, add_config_entry_id=entry.entry_id)
|
||||
|
||||
return device_id
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
registry = er.async_get(hass)
|
||||
device_registry = dr.async_get(hass)
|
||||
try:
|
||||
entity_id = er.async_validate_entity_id(registry, entry.options[CONF_ENTITY_ID])
|
||||
except vol.Invalid:
|
||||
@ -39,11 +61,27 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
if data["action"] == "remove":
|
||||
await hass.config_entries.async_remove(entry.entry_id)
|
||||
|
||||
if data["action"] != "update" or "entity_id" not in data["changes"]:
|
||||
if data["action"] != "update":
|
||||
return
|
||||
|
||||
# Entity_id changed, reload the config entry
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
if "entity_id" in data["changes"]:
|
||||
# Entity_id changed, reload the config entry
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
||||
if device_id and "device_id" in data["changes"]:
|
||||
# If the tracked switch is no longer in the device, remove our config entry
|
||||
# from the device
|
||||
if (
|
||||
not (entity_entry := registry.async_get(data["entity_id"]))
|
||||
or not device_registry.async_get(device_id)
|
||||
or entity_entry.device_id == device_id
|
||||
):
|
||||
# No need to do any cleanup
|
||||
return
|
||||
|
||||
device_registry.async_update_device(
|
||||
device_id, remove_config_entry_id=entry.entry_id
|
||||
)
|
||||
|
||||
entry.async_on_unload(
|
||||
async_track_entity_registry_updated_event(
|
||||
@ -51,6 +89,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
)
|
||||
)
|
||||
|
||||
device_id = async_add_to_device(hass, entry, entity_id)
|
||||
|
||||
hass.config_entries.async_setup_platforms(entry, (entry.options["target_domain"],))
|
||||
return True
|
||||
|
||||
|
@ -31,6 +31,8 @@ async def async_setup_entry(
|
||||
entity_id = er.async_validate_entity_id(
|
||||
registry, config_entry.options[CONF_ENTITY_ID]
|
||||
)
|
||||
wrapped_switch = registry.async_get(entity_id)
|
||||
device_id = wrapped_switch.device_id if wrapped_switch else None
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
@ -38,6 +40,7 @@ async def async_setup_entry(
|
||||
config_entry.title,
|
||||
entity_id,
|
||||
config_entry.entry_id,
|
||||
device_id,
|
||||
)
|
||||
]
|
||||
)
|
||||
@ -50,8 +53,15 @@ class LightSwitch(LightEntity):
|
||||
_attr_should_poll = False
|
||||
_attr_supported_color_modes = {COLOR_MODE_ONOFF}
|
||||
|
||||
def __init__(self, name: str, switch_entity_id: str, unique_id: str | None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
switch_entity_id: str,
|
||||
unique_id: str | None,
|
||||
device_id: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize Light Switch."""
|
||||
self._device_id = device_id
|
||||
self._attr_name = name
|
||||
self._attr_unique_id = unique_id
|
||||
self._switch_entity_id = switch_entity_id
|
||||
@ -100,3 +110,7 @@ class LightSwitch(LightEntity):
|
||||
|
||||
# Call once on adding
|
||||
async_state_changed_listener()
|
||||
|
||||
# Add this entity to the wrapped switch's device
|
||||
registry = er.async_get(self.hass)
|
||||
registry.async_update_entity(self.entity_id, device_id=self._device_id)
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from homeassistant.components.switch_as_x import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@ -82,3 +82,102 @@ async def test_entity_registry_events(hass: HomeAssistant, target_domain):
|
||||
assert hass.states.get(f"{target_domain}.abc") is None
|
||||
assert registry.async_get(f"{target_domain}.abc") is None
|
||||
assert len(hass.config_entries.async_entries("switch_as_x")) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("target_domain", ("light",))
|
||||
async def test_device_registry_config_entry_1(hass: HomeAssistant, target_domain):
|
||||
"""Test we add our config entry to the tracked switch's device."""
|
||||
device_registry = dr.async_get(hass)
|
||||
entity_registry = er.async_get(hass)
|
||||
|
||||
switch_config_entry = MockConfigEntry()
|
||||
|
||||
device_entry = device_registry.async_get_or_create(
|
||||
config_entry_id=switch_config_entry.entry_id,
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
switch_entity_entry = entity_registry.async_get_or_create(
|
||||
"switch",
|
||||
"test",
|
||||
"unique",
|
||||
config_entry=switch_config_entry,
|
||||
device_id=device_entry.id,
|
||||
)
|
||||
# Add another config entry to the same device
|
||||
device_registry.async_update_device(
|
||||
device_entry.id, add_config_entry_id=MockConfigEntry().entry_id
|
||||
)
|
||||
|
||||
switch_as_x_config_entry = MockConfigEntry(
|
||||
data={},
|
||||
domain=DOMAIN,
|
||||
options={"entity_id": switch_entity_entry.id, "target_domain": target_domain},
|
||||
title="ABC",
|
||||
)
|
||||
|
||||
switch_as_x_config_entry.add_to_hass(hass)
|
||||
|
||||
assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
entity_entry = entity_registry.async_get(f"{target_domain}.abc")
|
||||
assert entity_entry.device_id == switch_entity_entry.device_id
|
||||
|
||||
device_entry = device_registry.async_get(device_entry.id)
|
||||
assert switch_as_x_config_entry.entry_id in device_entry.config_entries
|
||||
|
||||
# Remove the wrapped switch's config entry from the device
|
||||
device_registry.async_update_device(
|
||||
device_entry.id, remove_config_entry_id=switch_config_entry.entry_id
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
# Check that the switch_as_x config entry is removed from the device
|
||||
device_entry = device_registry.async_get(device_entry.id)
|
||||
assert switch_as_x_config_entry.entry_id not in device_entry.config_entries
|
||||
|
||||
|
||||
@pytest.mark.parametrize("target_domain", ("light",))
|
||||
async def test_device_registry_config_entry_2(hass: HomeAssistant, target_domain):
|
||||
"""Test we add our config entry to the tracked switch's device."""
|
||||
device_registry = dr.async_get(hass)
|
||||
entity_registry = er.async_get(hass)
|
||||
|
||||
switch_config_entry = MockConfigEntry()
|
||||
|
||||
device_entry = device_registry.async_get_or_create(
|
||||
config_entry_id=switch_config_entry.entry_id,
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
switch_entity_entry = entity_registry.async_get_or_create(
|
||||
"switch",
|
||||
"test",
|
||||
"unique",
|
||||
config_entry=switch_config_entry,
|
||||
device_id=device_entry.id,
|
||||
)
|
||||
|
||||
switch_as_x_config_entry = MockConfigEntry(
|
||||
data={},
|
||||
domain=DOMAIN,
|
||||
options={"entity_id": switch_entity_entry.id, "target_domain": target_domain},
|
||||
title="ABC",
|
||||
)
|
||||
|
||||
switch_as_x_config_entry.add_to_hass(hass)
|
||||
|
||||
assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
entity_entry = entity_registry.async_get(f"{target_domain}.abc")
|
||||
assert entity_entry.device_id == switch_entity_entry.device_id
|
||||
|
||||
device_entry = device_registry.async_get(device_entry.id)
|
||||
assert switch_as_x_config_entry.entry_id in device_entry.config_entries
|
||||
|
||||
# Remove the wrapped switch from the device
|
||||
entity_registry.async_update_entity(switch_entity_entry.entity_id, device_id=None)
|
||||
await hass.async_block_till_done()
|
||||
# Check that the switch_as_x config entry is removed from the device
|
||||
device_entry = device_registry.async_get(device_entry.id)
|
||||
assert switch_as_x_config_entry.entry_id not in device_entry.config_entries
|
||||
|
@ -8,7 +8,7 @@ from homeassistant.components.light import (
|
||||
)
|
||||
from homeassistant.components.switch_as_x import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
@ -153,3 +153,35 @@ async def test_config_entry_uuid(hass: HomeAssistant, target_domain):
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get(f"{target_domain}.abc")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("target_domain", ("light",))
|
||||
async def test_device(hass: HomeAssistant, target_domain):
|
||||
"""Test the entity is added to the wrapped entity's device."""
|
||||
device_registry = dr.async_get(hass)
|
||||
entity_registry = er.async_get(hass)
|
||||
|
||||
test_config_entry = MockConfigEntry()
|
||||
|
||||
device_entry = device_registry.async_get_or_create(
|
||||
config_entry_id=test_config_entry.entry_id,
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
switch_entity_entry = entity_registry.async_get_or_create(
|
||||
"switch", "test", "unique", device_id=device_entry.id
|
||||
)
|
||||
|
||||
switch_as_x_config_entry = MockConfigEntry(
|
||||
data={},
|
||||
domain=DOMAIN,
|
||||
options={"entity_id": switch_entity_entry.id, "target_domain": target_domain},
|
||||
title="ABC",
|
||||
)
|
||||
|
||||
switch_as_x_config_entry.add_to_hass(hass)
|
||||
|
||||
assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
entity_entry = entity_registry.async_get(f"{target_domain}.abc")
|
||||
assert entity_entry.device_id == switch_entity_entry.device_id
|
||||
|
Loading…
x
Reference in New Issue
Block a user