Track ESPHome entities by (device_id, key) to support sub-devices with overlaping names (#148297)

This commit is contained in:
J. Nick Koston 2025-07-07 23:41:20 -05:00 committed by GitHub
parent ccc80c78a0
commit dcf8d7f74d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 278 additions and 25 deletions

View File

@ -33,7 +33,12 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
# Import config flow so that it's added to the registry
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData, build_device_unique_id
from .entry_data import (
DeviceEntityKey,
ESPHomeConfigEntry,
RuntimeEntryData,
build_device_unique_id,
)
from .enum_mapper import EsphomeEnumMapper
_LOGGER = logging.getLogger(__name__)
@ -59,17 +64,32 @@ def async_static_info_updated(
device_info = entry_data.device_info
if TYPE_CHECKING:
assert device_info is not None
new_infos: dict[int, EntityInfo] = {}
new_infos: dict[DeviceEntityKey, EntityInfo] = {}
add_entities: list[_EntityT] = []
ent_reg = er.async_get(hass)
dev_reg = dr.async_get(hass)
# Track info by (info.device_id, info.key) to properly handle entities
# moving between devices and support sub-devices with overlapping keys
for info in infos:
new_infos[info.key] = info
info_key = (info.device_id, info.key)
new_infos[info_key] = info
# Try to find existing entity - first with current device_id
old_info = current_infos.pop(info_key, None)
# If not found, search for entity with same key but different device_id
# This handles the case where entity moved between devices
if not old_info:
for existing_device_id, existing_key in list(current_infos):
if existing_key == info.key:
# Found entity with same key but different device_id
old_info = current_infos.pop((existing_device_id, existing_key))
break
# Create new entity if it doesn't exist
if not (old_info := current_infos.pop(info.key, None)):
if not old_info:
entity = entity_type(entry_data, platform.domain, info, state_type)
add_entities.append(entity)
continue
@ -78,7 +98,7 @@ def async_static_info_updated(
if old_info.device_id == info.device_id:
continue
# Entity has switched devices, need to migrate unique_id
# Entity has switched devices, need to migrate unique_id and handle state subscriptions
old_unique_id = build_device_unique_id(device_info.mac_address, old_info)
entity_id = ent_reg.async_get_entity_id(platform.domain, DOMAIN, old_unique_id)
@ -103,7 +123,7 @@ def async_static_info_updated(
if old_unique_id != new_unique_id:
updates["new_unique_id"] = new_unique_id
# Update device assignment
# Update device assignment in registry
if info.device_id:
# Entity now belongs to a sub device
new_device = dev_reg.async_get_device(
@ -118,10 +138,32 @@ def async_static_info_updated(
if new_device:
updates["device_id"] = new_device.id
# Apply all updates at once
# Apply all registry updates at once
if updates:
ent_reg.async_update_entity(entity_id, **updates)
# IMPORTANT: The entity's device assignment in Home Assistant is only read when the entity
# is first added. Updating the registry alone won't move the entity to the new device
# in the UI. Additionally, the entity's state subscription is tied to the old device_id,
# so it won't receive state updates for the new device_id.
#
# We must remove the old entity and re-add it to ensure:
# 1. The entity appears under the correct device in the UI
# 2. The entity's state subscription is updated to use the new device_id
_LOGGER.debug(
"Entity %s moving from device_id %s to %s",
info.key,
old_info.device_id,
info.device_id,
)
# Signal the existing entity to remove itself
# The entity is registered with the old device_id, so we signal with that
entry_data.async_signal_entity_removal(info_type, old_info.device_id, info.key)
# Create new entity with the new device_id
add_entities.append(entity_type(entry_data, platform.domain, info, state_type))
# Anything still in current_infos is now gone
if current_infos:
entry_data.async_remove_entities(
@ -341,7 +383,10 @@ class EsphomeEntity(EsphomeBaseEntity, Generic[_InfoT, _StateT]):
)
self.async_on_remove(
entry_data.async_subscribe_state_update(
self._state_type, self._key, self._on_state_update
self._static_info.device_id,
self._state_type,
self._key,
self._on_state_update,
)
)
self.async_on_remove(
@ -349,8 +394,29 @@ class EsphomeEntity(EsphomeBaseEntity, Generic[_InfoT, _StateT]):
self._static_info, self._on_static_info_update
)
)
# Register to be notified when this entity should remove itself
# This happens when the entity moves to a different device
self.async_on_remove(
entry_data.async_register_entity_removal_callback(
type(self._static_info),
self._static_info.device_id,
self._key,
self._on_removal_signal,
)
)
self._update_state_from_entry_data()
@callback
def _on_removal_signal(self) -> None:
"""Handle signal to remove this entity."""
_LOGGER.debug(
"Entity %s received removal signal due to device_id change",
self.entity_id,
)
# Schedule the entity to be removed
# This must be done as a task since we're in a callback
self.hass.async_create_task(self.async_remove())
@callback
def _on_static_info_update(self, static_info: EntityInfo) -> None:
"""Save the static info for this entity when it changes.

View File

@ -60,7 +60,9 @@ from .const import DOMAIN
from .dashboard import async_get_dashboard
type ESPHomeConfigEntry = ConfigEntry[RuntimeEntryData]
type EntityStateKey = tuple[type[EntityState], int, int] # (state_type, device_id, key)
type EntityInfoKey = tuple[type[EntityInfo], int, int] # (info_type, device_id, key)
type DeviceEntityKey = tuple[int, int] # (device_id, key)
INFO_TO_COMPONENT_TYPE: Final = {v: k for k, v in COMPONENT_TYPE_TO_INFO.items()}
@ -137,8 +139,10 @@ class RuntimeEntryData:
# When the disconnect callback is called, we mark all states
# as stale so we will always dispatch a state update when the
# device reconnects. This is the same format as state_subscriptions.
stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set)
info: dict[type[EntityInfo], dict[int, EntityInfo]] = field(default_factory=dict)
stale_state: set[EntityStateKey] = field(default_factory=set)
info: dict[type[EntityInfo], dict[DeviceEntityKey, EntityInfo]] = field(
default_factory=dict
)
services: dict[int, UserService] = field(default_factory=dict)
available: bool = False
expected_disconnect: bool = False # Last disconnect was expected (e.g. deep sleep)
@ -147,7 +151,7 @@ class RuntimeEntryData:
api_version: APIVersion = field(default_factory=APIVersion)
cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
disconnect_callbacks: set[CALLBACK_TYPE] = field(default_factory=set)
state_subscriptions: dict[tuple[type[EntityState], int], CALLBACK_TYPE] = field(
state_subscriptions: dict[EntityStateKey, CALLBACK_TYPE] = field(
default_factory=dict
)
device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set)
@ -164,7 +168,7 @@ class RuntimeEntryData:
type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
] = field(default_factory=dict)
entity_info_key_updated_callbacks: dict[
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
EntityInfoKey, list[Callable[[EntityInfo], None]]
] = field(default_factory=dict)
original_options: dict[str, Any] = field(default_factory=dict)
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
@ -177,6 +181,9 @@ class RuntimeEntryData:
default_factory=list
)
device_id_to_name: dict[int, str] = field(default_factory=dict)
entity_removal_callbacks: dict[EntityInfoKey, list[CALLBACK_TYPE]] = field(
default_factory=dict
)
@property
def name(self) -> str:
@ -210,7 +217,7 @@ class RuntimeEntryData:
callback_: Callable[[EntityInfo], None],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when static info is updated for a specific key."""
callback_key = (type(static_info), static_info.key)
callback_key = (type(static_info), static_info.device_id, static_info.key)
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
return partial(callbacks.remove, callback_)
@ -250,7 +257,9 @@ class RuntimeEntryData:
"""Call static info updated callbacks."""
callbacks = self.entity_info_key_updated_callbacks
for static_info in static_infos:
for callback_ in callbacks.get((type(static_info), static_info.key), ()):
for callback_ in callbacks.get(
(type(static_info), static_info.device_id, static_info.key), ()
):
callback_(static_info)
async def _ensure_platforms_loaded(
@ -342,12 +351,13 @@ class RuntimeEntryData:
@callback
def async_subscribe_state_update(
self,
device_id: int,
state_type: type[EntityState],
state_key: int,
entity_callback: CALLBACK_TYPE,
) -> CALLBACK_TYPE:
"""Subscribe to state updates."""
subscription_key = (state_type, state_key)
subscription_key = (state_type, device_id, state_key)
self.state_subscriptions[subscription_key] = entity_callback
return partial(delitem, self.state_subscriptions, subscription_key)
@ -359,7 +369,7 @@ class RuntimeEntryData:
stale_state = self.stale_state
current_state_by_type = self.state[state_type]
current_state = current_state_by_type.get(key, _SENTINEL)
subscription_key = (state_type, key)
subscription_key = (state_type, state.device_id, key)
if (
current_state == state
and subscription_key not in stale_state
@ -367,7 +377,7 @@ class RuntimeEntryData:
and not (
state_type is SensorState
and (platform_info := self.info.get(SensorInfo))
and (entity_info := platform_info.get(state.key))
and (entity_info := platform_info.get((state.device_id, state.key)))
and (cast(SensorInfo, entity_info)).force_update
)
):
@ -520,3 +530,26 @@ class RuntimeEntryData:
"""Notify listeners that the Assist satellite wake word has been set."""
for callback_ in self.assist_satellite_set_wake_word_callbacks.copy():
callback_(wake_word_id)
@callback
def async_register_entity_removal_callback(
self,
info_type: type[EntityInfo],
device_id: int,
key: int,
callback_: CALLBACK_TYPE,
) -> CALLBACK_TYPE:
"""Register to receive a callback when the entity should remove itself."""
callback_key = (info_type, device_id, key)
callbacks = self.entity_removal_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
return partial(callbacks.remove, callback_)
@callback
def async_signal_entity_removal(
self, info_type: type[EntityInfo], device_id: int, key: int
) -> None:
"""Signal that an entity should remove itself."""
callback_key = (info_type, device_id, key)
for callback_ in self.entity_removal_callbacks.get(callback_key, []).copy():
callback_()

View File

@ -588,7 +588,7 @@ class ESPHomeManager:
# Mark state as stale so that we will always dispatch
# the next state update of that type when the device reconnects
entry_data.stale_state = {
(type(entity_state), key)
(type(entity_state), entity_state.device_id, key)
for state_dict in entry_data.state.values()
for key, entity_state in state_dict.items()
}

View File

@ -1,6 +1,6 @@
"""Test ESPHome binary sensors."""
from aioesphomeapi import APIClient, BinarySensorInfo, BinarySensorState
from aioesphomeapi import APIClient, BinarySensorInfo, BinarySensorState, SubDeviceInfo
import pytest
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNKNOWN
@ -127,3 +127,157 @@ async def test_binary_sensor_has_state_false(
state = hass.states.get("binary_sensor.test_my_binary_sensor")
assert state is not None
assert state.state == STATE_ON
async def test_binary_sensors_same_key_different_device_id(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: MockESPHomeDeviceType,
) -> None:
"""Test binary sensors with same key but different device_id."""
# Create sub-devices
sub_devices = [
SubDeviceInfo(device_id=11111111, name="Sub Device 1", area_id=0),
SubDeviceInfo(device_id=22222222, name="Sub Device 2", area_id=0),
]
device_info = {
"name": "test",
"devices": sub_devices,
}
# Both sub-devices have a binary sensor with key=1
entity_info = [
BinarySensorInfo(
object_id="sensor",
key=1,
name="Motion",
unique_id="motion_1",
device_id=11111111,
),
BinarySensorInfo(
object_id="sensor",
key=1,
name="Motion",
unique_id="motion_2",
device_id=22222222,
),
]
# States for both sensors with same key but different device_id
states = [
BinarySensorState(key=1, state=True, missing_state=False, device_id=11111111),
BinarySensorState(key=1, state=False, missing_state=False, device_id=22222222),
]
mock_device = await mock_esphome_device(
mock_client=mock_client,
device_info=device_info,
entity_info=entity_info,
states=states,
)
# Verify both entities exist and have correct states
state1 = hass.states.get("binary_sensor.sub_device_1_motion")
assert state1 is not None
assert state1.state == STATE_ON
state2 = hass.states.get("binary_sensor.sub_device_2_motion")
assert state2 is not None
assert state2.state == STATE_OFF
# Update states to verify they update independently
mock_device.set_state(
BinarySensorState(key=1, state=False, missing_state=False, device_id=11111111)
)
await hass.async_block_till_done()
state1 = hass.states.get("binary_sensor.sub_device_1_motion")
assert state1.state == STATE_OFF
# Sub device 2 should remain unchanged
state2 = hass.states.get("binary_sensor.sub_device_2_motion")
assert state2.state == STATE_OFF
# Update sub device 2
mock_device.set_state(
BinarySensorState(key=1, state=True, missing_state=False, device_id=22222222)
)
await hass.async_block_till_done()
state2 = hass.states.get("binary_sensor.sub_device_2_motion")
assert state2.state == STATE_ON
# Sub device 1 should remain unchanged
state1 = hass.states.get("binary_sensor.sub_device_1_motion")
assert state1.state == STATE_OFF
async def test_binary_sensor_main_and_sub_device_same_key(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: MockESPHomeDeviceType,
) -> None:
"""Test binary sensor on main device and sub-device with same key."""
# Create sub-device
sub_devices = [
SubDeviceInfo(device_id=11111111, name="Sub Device", area_id=0),
]
device_info = {
"name": "test",
"devices": sub_devices,
}
# Main device and sub-device both have a binary sensor with key=1
entity_info = [
BinarySensorInfo(
object_id="main_sensor",
key=1,
name="Main Sensor",
unique_id="main_1",
device_id=0, # Main device
),
BinarySensorInfo(
object_id="sub_sensor",
key=1,
name="Sub Sensor",
unique_id="sub_1",
device_id=11111111,
),
]
# States for both sensors
states = [
BinarySensorState(key=1, state=True, missing_state=False, device_id=0),
BinarySensorState(key=1, state=False, missing_state=False, device_id=11111111),
]
mock_device = await mock_esphome_device(
mock_client=mock_client,
device_info=device_info,
entity_info=entity_info,
states=states,
)
# Verify both entities exist
main_state = hass.states.get("binary_sensor.test_main_sensor")
assert main_state is not None
assert main_state.state == STATE_ON
sub_state = hass.states.get("binary_sensor.sub_device_sub_sensor")
assert sub_state is not None
assert sub_state.state == STATE_OFF
# Update main device sensor
mock_device.set_state(
BinarySensorState(key=1, state=False, missing_state=False, device_id=0)
)
await hass.async_block_till_done()
main_state = hass.states.get("binary_sensor.test_main_sensor")
assert main_state.state == STATE_OFF
# Sub device sensor should remain unchanged
sub_state = hass.states.get("binary_sensor.sub_device_sub_sensor")
assert sub_state.state == STATE_OFF

View File

@ -754,9 +754,9 @@ async def test_entity_assignment_to_sub_device(
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
BinarySensorState(key=2, state=False, missing_state=False),
BinarySensorState(key=3, state=True, missing_state=False),
BinarySensorState(key=1, state=True, missing_state=False, device_id=0),
BinarySensorState(key=2, state=False, missing_state=False, device_id=11111111),
BinarySensorState(key=3, state=True, missing_state=False, device_id=22222222),
]
device = await mock_esphome_device(
@ -938,7 +938,7 @@ async def test_entity_switches_between_devices(
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
BinarySensorState(key=1, state=True, missing_state=False, device_id=0),
]
device = await mock_esphome_device(
@ -1507,7 +1507,7 @@ async def test_entity_device_id_rename_in_yaml(
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
BinarySensorState(key=1, state=True, missing_state=False, device_id=11111111),
]
device = await mock_esphome_device(