mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Track ESPHome entities by (device_id, key) to support sub-devices with overlaping names (#148297)
This commit is contained in:
parent
ccc80c78a0
commit
dcf8d7f74d
@ -33,7 +33,12 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from .const import DOMAIN
|
||||
|
||||
# Import config flow so that it's added to the registry
|
||||
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData, build_device_unique_id
|
||||
from .entry_data import (
|
||||
DeviceEntityKey,
|
||||
ESPHomeConfigEntry,
|
||||
RuntimeEntryData,
|
||||
build_device_unique_id,
|
||||
)
|
||||
from .enum_mapper import EsphomeEnumMapper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -59,17 +64,32 @@ def async_static_info_updated(
|
||||
device_info = entry_data.device_info
|
||||
if TYPE_CHECKING:
|
||||
assert device_info is not None
|
||||
new_infos: dict[int, EntityInfo] = {}
|
||||
new_infos: dict[DeviceEntityKey, EntityInfo] = {}
|
||||
add_entities: list[_EntityT] = []
|
||||
|
||||
ent_reg = er.async_get(hass)
|
||||
dev_reg = dr.async_get(hass)
|
||||
|
||||
# Track info by (info.device_id, info.key) to properly handle entities
|
||||
# moving between devices and support sub-devices with overlapping keys
|
||||
for info in infos:
|
||||
new_infos[info.key] = info
|
||||
info_key = (info.device_id, info.key)
|
||||
new_infos[info_key] = info
|
||||
|
||||
# Try to find existing entity - first with current device_id
|
||||
old_info = current_infos.pop(info_key, None)
|
||||
|
||||
# If not found, search for entity with same key but different device_id
|
||||
# This handles the case where entity moved between devices
|
||||
if not old_info:
|
||||
for existing_device_id, existing_key in list(current_infos):
|
||||
if existing_key == info.key:
|
||||
# Found entity with same key but different device_id
|
||||
old_info = current_infos.pop((existing_device_id, existing_key))
|
||||
break
|
||||
|
||||
# Create new entity if it doesn't exist
|
||||
if not (old_info := current_infos.pop(info.key, None)):
|
||||
if not old_info:
|
||||
entity = entity_type(entry_data, platform.domain, info, state_type)
|
||||
add_entities.append(entity)
|
||||
continue
|
||||
@ -78,7 +98,7 @@ def async_static_info_updated(
|
||||
if old_info.device_id == info.device_id:
|
||||
continue
|
||||
|
||||
# Entity has switched devices, need to migrate unique_id
|
||||
# Entity has switched devices, need to migrate unique_id and handle state subscriptions
|
||||
old_unique_id = build_device_unique_id(device_info.mac_address, old_info)
|
||||
entity_id = ent_reg.async_get_entity_id(platform.domain, DOMAIN, old_unique_id)
|
||||
|
||||
@ -103,7 +123,7 @@ def async_static_info_updated(
|
||||
if old_unique_id != new_unique_id:
|
||||
updates["new_unique_id"] = new_unique_id
|
||||
|
||||
# Update device assignment
|
||||
# Update device assignment in registry
|
||||
if info.device_id:
|
||||
# Entity now belongs to a sub device
|
||||
new_device = dev_reg.async_get_device(
|
||||
@ -118,10 +138,32 @@ def async_static_info_updated(
|
||||
if new_device:
|
||||
updates["device_id"] = new_device.id
|
||||
|
||||
# Apply all updates at once
|
||||
# Apply all registry updates at once
|
||||
if updates:
|
||||
ent_reg.async_update_entity(entity_id, **updates)
|
||||
|
||||
# IMPORTANT: The entity's device assignment in Home Assistant is only read when the entity
|
||||
# is first added. Updating the registry alone won't move the entity to the new device
|
||||
# in the UI. Additionally, the entity's state subscription is tied to the old device_id,
|
||||
# so it won't receive state updates for the new device_id.
|
||||
#
|
||||
# We must remove the old entity and re-add it to ensure:
|
||||
# 1. The entity appears under the correct device in the UI
|
||||
# 2. The entity's state subscription is updated to use the new device_id
|
||||
_LOGGER.debug(
|
||||
"Entity %s moving from device_id %s to %s",
|
||||
info.key,
|
||||
old_info.device_id,
|
||||
info.device_id,
|
||||
)
|
||||
|
||||
# Signal the existing entity to remove itself
|
||||
# The entity is registered with the old device_id, so we signal with that
|
||||
entry_data.async_signal_entity_removal(info_type, old_info.device_id, info.key)
|
||||
|
||||
# Create new entity with the new device_id
|
||||
add_entities.append(entity_type(entry_data, platform.domain, info, state_type))
|
||||
|
||||
# Anything still in current_infos is now gone
|
||||
if current_infos:
|
||||
entry_data.async_remove_entities(
|
||||
@ -341,7 +383,10 @@ class EsphomeEntity(EsphomeBaseEntity, Generic[_InfoT, _StateT]):
|
||||
)
|
||||
self.async_on_remove(
|
||||
entry_data.async_subscribe_state_update(
|
||||
self._state_type, self._key, self._on_state_update
|
||||
self._static_info.device_id,
|
||||
self._state_type,
|
||||
self._key,
|
||||
self._on_state_update,
|
||||
)
|
||||
)
|
||||
self.async_on_remove(
|
||||
@ -349,8 +394,29 @@ class EsphomeEntity(EsphomeBaseEntity, Generic[_InfoT, _StateT]):
|
||||
self._static_info, self._on_static_info_update
|
||||
)
|
||||
)
|
||||
# Register to be notified when this entity should remove itself
|
||||
# This happens when the entity moves to a different device
|
||||
self.async_on_remove(
|
||||
entry_data.async_register_entity_removal_callback(
|
||||
type(self._static_info),
|
||||
self._static_info.device_id,
|
||||
self._key,
|
||||
self._on_removal_signal,
|
||||
)
|
||||
)
|
||||
self._update_state_from_entry_data()
|
||||
|
||||
@callback
|
||||
def _on_removal_signal(self) -> None:
|
||||
"""Handle signal to remove this entity."""
|
||||
_LOGGER.debug(
|
||||
"Entity %s received removal signal due to device_id change",
|
||||
self.entity_id,
|
||||
)
|
||||
# Schedule the entity to be removed
|
||||
# This must be done as a task since we're in a callback
|
||||
self.hass.async_create_task(self.async_remove())
|
||||
|
||||
@callback
|
||||
def _on_static_info_update(self, static_info: EntityInfo) -> None:
|
||||
"""Save the static info for this entity when it changes.
|
||||
|
@ -60,7 +60,9 @@ from .const import DOMAIN
|
||||
from .dashboard import async_get_dashboard
|
||||
|
||||
type ESPHomeConfigEntry = ConfigEntry[RuntimeEntryData]
|
||||
|
||||
type EntityStateKey = tuple[type[EntityState], int, int] # (state_type, device_id, key)
|
||||
type EntityInfoKey = tuple[type[EntityInfo], int, int] # (info_type, device_id, key)
|
||||
type DeviceEntityKey = tuple[int, int] # (device_id, key)
|
||||
|
||||
INFO_TO_COMPONENT_TYPE: Final = {v: k for k, v in COMPONENT_TYPE_TO_INFO.items()}
|
||||
|
||||
@ -137,8 +139,10 @@ class RuntimeEntryData:
|
||||
# When the disconnect callback is called, we mark all states
|
||||
# as stale so we will always dispatch a state update when the
|
||||
# device reconnects. This is the same format as state_subscriptions.
|
||||
stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set)
|
||||
info: dict[type[EntityInfo], dict[int, EntityInfo]] = field(default_factory=dict)
|
||||
stale_state: set[EntityStateKey] = field(default_factory=set)
|
||||
info: dict[type[EntityInfo], dict[DeviceEntityKey, EntityInfo]] = field(
|
||||
default_factory=dict
|
||||
)
|
||||
services: dict[int, UserService] = field(default_factory=dict)
|
||||
available: bool = False
|
||||
expected_disconnect: bool = False # Last disconnect was expected (e.g. deep sleep)
|
||||
@ -147,7 +151,7 @@ class RuntimeEntryData:
|
||||
api_version: APIVersion = field(default_factory=APIVersion)
|
||||
cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
|
||||
disconnect_callbacks: set[CALLBACK_TYPE] = field(default_factory=set)
|
||||
state_subscriptions: dict[tuple[type[EntityState], int], CALLBACK_TYPE] = field(
|
||||
state_subscriptions: dict[EntityStateKey, CALLBACK_TYPE] = field(
|
||||
default_factory=dict
|
||||
)
|
||||
device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set)
|
||||
@ -164,7 +168,7 @@ class RuntimeEntryData:
|
||||
type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
|
||||
] = field(default_factory=dict)
|
||||
entity_info_key_updated_callbacks: dict[
|
||||
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
|
||||
EntityInfoKey, list[Callable[[EntityInfo], None]]
|
||||
] = field(default_factory=dict)
|
||||
original_options: dict[str, Any] = field(default_factory=dict)
|
||||
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
|
||||
@ -177,6 +181,9 @@ class RuntimeEntryData:
|
||||
default_factory=list
|
||||
)
|
||||
device_id_to_name: dict[int, str] = field(default_factory=dict)
|
||||
entity_removal_callbacks: dict[EntityInfoKey, list[CALLBACK_TYPE]] = field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -210,7 +217,7 @@ class RuntimeEntryData:
|
||||
callback_: Callable[[EntityInfo], None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register to receive callbacks when static info is updated for a specific key."""
|
||||
callback_key = (type(static_info), static_info.key)
|
||||
callback_key = (type(static_info), static_info.device_id, static_info.key)
|
||||
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
|
||||
callbacks.append(callback_)
|
||||
return partial(callbacks.remove, callback_)
|
||||
@ -250,7 +257,9 @@ class RuntimeEntryData:
|
||||
"""Call static info updated callbacks."""
|
||||
callbacks = self.entity_info_key_updated_callbacks
|
||||
for static_info in static_infos:
|
||||
for callback_ in callbacks.get((type(static_info), static_info.key), ()):
|
||||
for callback_ in callbacks.get(
|
||||
(type(static_info), static_info.device_id, static_info.key), ()
|
||||
):
|
||||
callback_(static_info)
|
||||
|
||||
async def _ensure_platforms_loaded(
|
||||
@ -342,12 +351,13 @@ class RuntimeEntryData:
|
||||
@callback
|
||||
def async_subscribe_state_update(
|
||||
self,
|
||||
device_id: int,
|
||||
state_type: type[EntityState],
|
||||
state_key: int,
|
||||
entity_callback: CALLBACK_TYPE,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Subscribe to state updates."""
|
||||
subscription_key = (state_type, state_key)
|
||||
subscription_key = (state_type, device_id, state_key)
|
||||
self.state_subscriptions[subscription_key] = entity_callback
|
||||
return partial(delitem, self.state_subscriptions, subscription_key)
|
||||
|
||||
@ -359,7 +369,7 @@ class RuntimeEntryData:
|
||||
stale_state = self.stale_state
|
||||
current_state_by_type = self.state[state_type]
|
||||
current_state = current_state_by_type.get(key, _SENTINEL)
|
||||
subscription_key = (state_type, key)
|
||||
subscription_key = (state_type, state.device_id, key)
|
||||
if (
|
||||
current_state == state
|
||||
and subscription_key not in stale_state
|
||||
@ -367,7 +377,7 @@ class RuntimeEntryData:
|
||||
and not (
|
||||
state_type is SensorState
|
||||
and (platform_info := self.info.get(SensorInfo))
|
||||
and (entity_info := platform_info.get(state.key))
|
||||
and (entity_info := platform_info.get((state.device_id, state.key)))
|
||||
and (cast(SensorInfo, entity_info)).force_update
|
||||
)
|
||||
):
|
||||
@ -520,3 +530,26 @@ class RuntimeEntryData:
|
||||
"""Notify listeners that the Assist satellite wake word has been set."""
|
||||
for callback_ in self.assist_satellite_set_wake_word_callbacks.copy():
|
||||
callback_(wake_word_id)
|
||||
|
||||
@callback
|
||||
def async_register_entity_removal_callback(
|
||||
self,
|
||||
info_type: type[EntityInfo],
|
||||
device_id: int,
|
||||
key: int,
|
||||
callback_: CALLBACK_TYPE,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register to receive a callback when the entity should remove itself."""
|
||||
callback_key = (info_type, device_id, key)
|
||||
callbacks = self.entity_removal_callbacks.setdefault(callback_key, [])
|
||||
callbacks.append(callback_)
|
||||
return partial(callbacks.remove, callback_)
|
||||
|
||||
@callback
|
||||
def async_signal_entity_removal(
|
||||
self, info_type: type[EntityInfo], device_id: int, key: int
|
||||
) -> None:
|
||||
"""Signal that an entity should remove itself."""
|
||||
callback_key = (info_type, device_id, key)
|
||||
for callback_ in self.entity_removal_callbacks.get(callback_key, []).copy():
|
||||
callback_()
|
||||
|
@ -588,7 +588,7 @@ class ESPHomeManager:
|
||||
# Mark state as stale so that we will always dispatch
|
||||
# the next state update of that type when the device reconnects
|
||||
entry_data.stale_state = {
|
||||
(type(entity_state), key)
|
||||
(type(entity_state), entity_state.device_id, key)
|
||||
for state_dict in entry_data.state.values()
|
||||
for key, entity_state in state_dict.items()
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test ESPHome binary sensors."""
|
||||
|
||||
from aioesphomeapi import APIClient, BinarySensorInfo, BinarySensorState
|
||||
from aioesphomeapi import APIClient, BinarySensorInfo, BinarySensorState, SubDeviceInfo
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNKNOWN
|
||||
@ -127,3 +127,157 @@ async def test_binary_sensor_has_state_false(
|
||||
state = hass.states.get("binary_sensor.test_my_binary_sensor")
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
|
||||
|
||||
async def test_binary_sensors_same_key_different_device_id(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: MockESPHomeDeviceType,
|
||||
) -> None:
|
||||
"""Test binary sensors with same key but different device_id."""
|
||||
# Create sub-devices
|
||||
sub_devices = [
|
||||
SubDeviceInfo(device_id=11111111, name="Sub Device 1", area_id=0),
|
||||
SubDeviceInfo(device_id=22222222, name="Sub Device 2", area_id=0),
|
||||
]
|
||||
|
||||
device_info = {
|
||||
"name": "test",
|
||||
"devices": sub_devices,
|
||||
}
|
||||
|
||||
# Both sub-devices have a binary sensor with key=1
|
||||
entity_info = [
|
||||
BinarySensorInfo(
|
||||
object_id="sensor",
|
||||
key=1,
|
||||
name="Motion",
|
||||
unique_id="motion_1",
|
||||
device_id=11111111,
|
||||
),
|
||||
BinarySensorInfo(
|
||||
object_id="sensor",
|
||||
key=1,
|
||||
name="Motion",
|
||||
unique_id="motion_2",
|
||||
device_id=22222222,
|
||||
),
|
||||
]
|
||||
|
||||
# States for both sensors with same key but different device_id
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False, device_id=11111111),
|
||||
BinarySensorState(key=1, state=False, missing_state=False, device_id=22222222),
|
||||
]
|
||||
|
||||
mock_device = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
device_info=device_info,
|
||||
entity_info=entity_info,
|
||||
states=states,
|
||||
)
|
||||
|
||||
# Verify both entities exist and have correct states
|
||||
state1 = hass.states.get("binary_sensor.sub_device_1_motion")
|
||||
assert state1 is not None
|
||||
assert state1.state == STATE_ON
|
||||
|
||||
state2 = hass.states.get("binary_sensor.sub_device_2_motion")
|
||||
assert state2 is not None
|
||||
assert state2.state == STATE_OFF
|
||||
|
||||
# Update states to verify they update independently
|
||||
mock_device.set_state(
|
||||
BinarySensorState(key=1, state=False, missing_state=False, device_id=11111111)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state1 = hass.states.get("binary_sensor.sub_device_1_motion")
|
||||
assert state1.state == STATE_OFF
|
||||
|
||||
# Sub device 2 should remain unchanged
|
||||
state2 = hass.states.get("binary_sensor.sub_device_2_motion")
|
||||
assert state2.state == STATE_OFF
|
||||
|
||||
# Update sub device 2
|
||||
mock_device.set_state(
|
||||
BinarySensorState(key=1, state=True, missing_state=False, device_id=22222222)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state2 = hass.states.get("binary_sensor.sub_device_2_motion")
|
||||
assert state2.state == STATE_ON
|
||||
|
||||
# Sub device 1 should remain unchanged
|
||||
state1 = hass.states.get("binary_sensor.sub_device_1_motion")
|
||||
assert state1.state == STATE_OFF
|
||||
|
||||
|
||||
async def test_binary_sensor_main_and_sub_device_same_key(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: MockESPHomeDeviceType,
|
||||
) -> None:
|
||||
"""Test binary sensor on main device and sub-device with same key."""
|
||||
# Create sub-device
|
||||
sub_devices = [
|
||||
SubDeviceInfo(device_id=11111111, name="Sub Device", area_id=0),
|
||||
]
|
||||
|
||||
device_info = {
|
||||
"name": "test",
|
||||
"devices": sub_devices,
|
||||
}
|
||||
|
||||
# Main device and sub-device both have a binary sensor with key=1
|
||||
entity_info = [
|
||||
BinarySensorInfo(
|
||||
object_id="main_sensor",
|
||||
key=1,
|
||||
name="Main Sensor",
|
||||
unique_id="main_1",
|
||||
device_id=0, # Main device
|
||||
),
|
||||
BinarySensorInfo(
|
||||
object_id="sub_sensor",
|
||||
key=1,
|
||||
name="Sub Sensor",
|
||||
unique_id="sub_1",
|
||||
device_id=11111111,
|
||||
),
|
||||
]
|
||||
|
||||
# States for both sensors
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False, device_id=0),
|
||||
BinarySensorState(key=1, state=False, missing_state=False, device_id=11111111),
|
||||
]
|
||||
|
||||
mock_device = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
device_info=device_info,
|
||||
entity_info=entity_info,
|
||||
states=states,
|
||||
)
|
||||
|
||||
# Verify both entities exist
|
||||
main_state = hass.states.get("binary_sensor.test_main_sensor")
|
||||
assert main_state is not None
|
||||
assert main_state.state == STATE_ON
|
||||
|
||||
sub_state = hass.states.get("binary_sensor.sub_device_sub_sensor")
|
||||
assert sub_state is not None
|
||||
assert sub_state.state == STATE_OFF
|
||||
|
||||
# Update main device sensor
|
||||
mock_device.set_state(
|
||||
BinarySensorState(key=1, state=False, missing_state=False, device_id=0)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
main_state = hass.states.get("binary_sensor.test_main_sensor")
|
||||
assert main_state.state == STATE_OFF
|
||||
|
||||
# Sub device sensor should remain unchanged
|
||||
sub_state = hass.states.get("binary_sensor.sub_device_sub_sensor")
|
||||
assert sub_state.state == STATE_OFF
|
||||
|
@ -754,9 +754,9 @@ async def test_entity_assignment_to_sub_device(
|
||||
]
|
||||
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False),
|
||||
BinarySensorState(key=2, state=False, missing_state=False),
|
||||
BinarySensorState(key=3, state=True, missing_state=False),
|
||||
BinarySensorState(key=1, state=True, missing_state=False, device_id=0),
|
||||
BinarySensorState(key=2, state=False, missing_state=False, device_id=11111111),
|
||||
BinarySensorState(key=3, state=True, missing_state=False, device_id=22222222),
|
||||
]
|
||||
|
||||
device = await mock_esphome_device(
|
||||
@ -938,7 +938,7 @@ async def test_entity_switches_between_devices(
|
||||
]
|
||||
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False),
|
||||
BinarySensorState(key=1, state=True, missing_state=False, device_id=0),
|
||||
]
|
||||
|
||||
device = await mock_esphome_device(
|
||||
@ -1507,7 +1507,7 @@ async def test_entity_device_id_rename_in_yaml(
|
||||
]
|
||||
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False),
|
||||
BinarySensorState(key=1, state=True, missing_state=False, device_id=11111111),
|
||||
]
|
||||
|
||||
device = await mock_esphome_device(
|
||||
|
Loading…
x
Reference in New Issue
Block a user