Fix switch_as_x entity_id tracking (#146386)

This commit is contained in:
Erik Montnemery 2025-06-09 13:24:40 +02:00 committed by GitHub
parent b1a2af9fd3
commit 46dcc91510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 26 deletions

View File

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.components.homeassistant import exposed_entities
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ENTITY_ID
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.core import Event, HomeAssistant, callback, valid_entity_id
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.event import async_track_entity_registry_updated_event
@ -44,10 +44,12 @@ def async_add_to_device(
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
registry = er.async_get(hass)
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
try:
entity_id = er.async_validate_entity_id(registry, entry.options[CONF_ENTITY_ID])
entity_id = er.async_validate_entity_id(
entity_registry, entry.options[CONF_ENTITY_ID]
)
except vol.Invalid:
# The entity is identified by an unknown entity registry ID
_LOGGER.error(
@ -68,14 +70,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return
if "entity_id" in data["changes"]:
# Entity_id changed, reload the config entry
await hass.config_entries.async_reload(entry.entry_id)
# Entity_id changed, update or reload the config entry
if valid_entity_id(entry.options[CONF_ENTITY_ID]):
# If the entity is pointed to by an entity ID, update the entry
hass.config_entries.async_update_entry(
entry,
options={**entry.options, CONF_ENTITY_ID: data["entity_id"]},
)
else:
await hass.config_entries.async_reload(entry.entry_id)
if device_id and "device_id" in data["changes"]:
# If the tracked switch is no longer in the device, remove our config entry
# from the device
if (
not (entity_entry := registry.async_get(data[CONF_ENTITY_ID]))
not (entity_entry := entity_registry.async_get(data[CONF_ENTITY_ID]))
or not device_registry.async_get(device_id)
or entity_entry.device_id == device_id
):

View File

@ -39,6 +39,44 @@ EXPOSE_SETTINGS = {
}
@pytest.fixture
def switch_entity_registry_entry(
entity_registry: er.EntityRegistry,
) -> er.RegistryEntry:
"""Fixture to create a switch entity entry."""
return entity_registry.async_get_or_create(
"switch", "test", "unique", original_name="ABC"
)
@pytest.fixture
def switch_as_x_config_entry(
hass: HomeAssistant,
switch_entity_registry_entry: er.RegistryEntry,
target_domain: str,
use_entity_registry_id: bool,
) -> MockConfigEntry:
"""Fixture to create a switch_as_x config entry."""
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_ENTITY_ID: switch_entity_registry_entry.id
if use_entity_registry_id
else switch_entity_registry_entry.entity_id,
CONF_INVERT: False,
CONF_TARGET_DOMAIN: target_domain,
},
title="ABC",
version=SwitchAsXConfigFlowHandler.VERSION,
minor_version=SwitchAsXConfigFlowHandler.MINOR_VERSION,
)
config_entry.add_to_hass(hass)
return config_entry
@pytest.mark.parametrize("target_domain", PLATFORMS_TO_TEST)
async def test_config_entry_unregistered_uuid(
hass: HomeAssistant, target_domain: str
@ -67,6 +105,7 @@ async def test_config_entry_unregistered_uuid(
assert len(hass.states.async_all()) == 0
@pytest.mark.parametrize("use_entity_registry_id", [True, False])
@pytest.mark.parametrize(
("target_domain", "state_on", "state_off"),
[
@ -81,33 +120,17 @@ async def test_config_entry_unregistered_uuid(
async def test_entity_registry_events(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
switch_entity_registry_entry: er.RegistryEntry,
switch_as_x_config_entry: MockConfigEntry,
target_domain: str,
state_on: str,
state_off: str,
) -> None:
"""Test entity registry events are tracked."""
registry_entry = entity_registry.async_get_or_create(
"switch", "test", "unique", original_name="ABC"
)
switch_entity_id = registry_entry.entity_id
switch_entity_id = switch_entity_registry_entry.entity_id
hass.states.async_set(switch_entity_id, STATE_ON)
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_ENTITY_ID: registry_entry.id,
CONF_INVERT: False,
CONF_TARGET_DOMAIN: target_domain,
},
title="ABC",
version=SwitchAsXConfigFlowHandler.VERSION,
minor_version=SwitchAsXConfigFlowHandler.MINOR_VERSION,
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
assert await hass.config_entries.async_setup(switch_as_x_config_entry.entry_id)
await hass.async_block_till_done()
assert hass.states.get(f"{target_domain}.abc").state == state_on