diff --git a/homeassistant/components/esphome/entity.py b/homeassistant/components/esphome/entity.py index b9f0125094a..a6267ba17a5 100644 --- a/homeassistant/components/esphome/entity.py +++ b/homeassistant/components/esphome/entity.py @@ -33,7 +33,12 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN # Import config flow so that it's added to the registry -from .entry_data import ESPHomeConfigEntry, RuntimeEntryData, build_device_unique_id +from .entry_data import ( + DeviceEntityKey, + ESPHomeConfigEntry, + RuntimeEntryData, + build_device_unique_id, +) from .enum_mapper import EsphomeEnumMapper _LOGGER = logging.getLogger(__name__) @@ -59,17 +64,32 @@ def async_static_info_updated( device_info = entry_data.device_info if TYPE_CHECKING: assert device_info is not None - new_infos: dict[int, EntityInfo] = {} + new_infos: dict[DeviceEntityKey, EntityInfo] = {} add_entities: list[_EntityT] = [] ent_reg = er.async_get(hass) dev_reg = dr.async_get(hass) + # Track info by (info.device_id, info.key) to properly handle entities + # moving between devices and support sub-devices with overlapping keys for info in infos: - new_infos[info.key] = info + info_key = (info.device_id, info.key) + new_infos[info_key] = info + + # Try to find existing entity - first with current device_id + old_info = current_infos.pop(info_key, None) + + # If not found, search for entity with same key but different device_id + # This handles the case where entity moved between devices + if not old_info: + for existing_device_id, existing_key in list(current_infos): + if existing_key == info.key: + # Found entity with same key but different device_id + old_info = current_infos.pop((existing_device_id, existing_key)) + break # Create new entity if it doesn't exist - if not (old_info := current_infos.pop(info.key, None)): + if not old_info: entity = entity_type(entry_data, platform.domain, info, state_type) add_entities.append(entity) continue @@ -78,7 +98,7 @@ def async_static_info_updated( if old_info.device_id == info.device_id: continue - # Entity has switched devices, need to migrate unique_id + # Entity has switched devices, need to migrate unique_id and handle state subscriptions old_unique_id = build_device_unique_id(device_info.mac_address, old_info) entity_id = ent_reg.async_get_entity_id(platform.domain, DOMAIN, old_unique_id) @@ -103,7 +123,7 @@ def async_static_info_updated( if old_unique_id != new_unique_id: updates["new_unique_id"] = new_unique_id - # Update device assignment + # Update device assignment in registry if info.device_id: # Entity now belongs to a sub device new_device = dev_reg.async_get_device( @@ -118,10 +138,32 @@ def async_static_info_updated( if new_device: updates["device_id"] = new_device.id - # Apply all updates at once + # Apply all registry updates at once if updates: ent_reg.async_update_entity(entity_id, **updates) + # IMPORTANT: The entity's device assignment in Home Assistant is only read when the entity + # is first added. Updating the registry alone won't move the entity to the new device + # in the UI. Additionally, the entity's state subscription is tied to the old device_id, + # so it won't receive state updates for the new device_id. + # + # We must remove the old entity and re-add it to ensure: + # 1. The entity appears under the correct device in the UI + # 2. The entity's state subscription is updated to use the new device_id + _LOGGER.debug( + "Entity %s moving from device_id %s to %s", + info.key, + old_info.device_id, + info.device_id, + ) + + # Signal the existing entity to remove itself + # The entity is registered with the old device_id, so we signal with that + entry_data.async_signal_entity_removal(info_type, old_info.device_id, info.key) + + # Create new entity with the new device_id + add_entities.append(entity_type(entry_data, platform.domain, info, state_type)) + # Anything still in current_infos is now gone if current_infos: entry_data.async_remove_entities( @@ -341,7 +383,10 @@ class EsphomeEntity(EsphomeBaseEntity, Generic[_InfoT, _StateT]): ) self.async_on_remove( entry_data.async_subscribe_state_update( - self._state_type, self._key, self._on_state_update + self._static_info.device_id, + self._state_type, + self._key, + self._on_state_update, ) ) self.async_on_remove( @@ -349,8 +394,29 @@ class EsphomeEntity(EsphomeBaseEntity, Generic[_InfoT, _StateT]): self._static_info, self._on_static_info_update ) ) + # Register to be notified when this entity should remove itself + # This happens when the entity moves to a different device + self.async_on_remove( + entry_data.async_register_entity_removal_callback( + type(self._static_info), + self._static_info.device_id, + self._key, + self._on_removal_signal, + ) + ) self._update_state_from_entry_data() + @callback + def _on_removal_signal(self) -> None: + """Handle signal to remove this entity.""" + _LOGGER.debug( + "Entity %s received removal signal due to device_id change", + self.entity_id, + ) + # Schedule the entity to be removed + # This must be done as a task since we're in a callback + self.hass.async_create_task(self.async_remove()) + @callback def _on_static_info_update(self, static_info: EntityInfo) -> None: """Save the static info for this entity when it changes. diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 71680873611..dddbb598a57 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -60,7 +60,9 @@ from .const import DOMAIN from .dashboard import async_get_dashboard type ESPHomeConfigEntry = ConfigEntry[RuntimeEntryData] - +type EntityStateKey = tuple[type[EntityState], int, int] # (state_type, device_id, key) +type EntityInfoKey = tuple[type[EntityInfo], int, int] # (info_type, device_id, key) +type DeviceEntityKey = tuple[int, int] # (device_id, key) INFO_TO_COMPONENT_TYPE: Final = {v: k for k, v in COMPONENT_TYPE_TO_INFO.items()} @@ -137,8 +139,10 @@ class RuntimeEntryData: # When the disconnect callback is called, we mark all states # as stale so we will always dispatch a state update when the # device reconnects. This is the same format as state_subscriptions. - stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set) - info: dict[type[EntityInfo], dict[int, EntityInfo]] = field(default_factory=dict) + stale_state: set[EntityStateKey] = field(default_factory=set) + info: dict[type[EntityInfo], dict[DeviceEntityKey, EntityInfo]] = field( + default_factory=dict + ) services: dict[int, UserService] = field(default_factory=dict) available: bool = False expected_disconnect: bool = False # Last disconnect was expected (e.g. deep sleep) @@ -147,7 +151,7 @@ class RuntimeEntryData: api_version: APIVersion = field(default_factory=APIVersion) cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list) disconnect_callbacks: set[CALLBACK_TYPE] = field(default_factory=set) - state_subscriptions: dict[tuple[type[EntityState], int], CALLBACK_TYPE] = field( + state_subscriptions: dict[EntityStateKey, CALLBACK_TYPE] = field( default_factory=dict ) device_update_subscriptions: set[CALLBACK_TYPE] = field(default_factory=set) @@ -164,7 +168,7 @@ class RuntimeEntryData: type[EntityInfo], list[Callable[[list[EntityInfo]], None]] ] = field(default_factory=dict) entity_info_key_updated_callbacks: dict[ - tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]] + EntityInfoKey, list[Callable[[EntityInfo], None]] ] = field(default_factory=dict) original_options: dict[str, Any] = field(default_factory=dict) media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field( @@ -177,6 +181,9 @@ class RuntimeEntryData: default_factory=list ) device_id_to_name: dict[int, str] = field(default_factory=dict) + entity_removal_callbacks: dict[EntityInfoKey, list[CALLBACK_TYPE]] = field( + default_factory=dict + ) @property def name(self) -> str: @@ -210,7 +217,7 @@ class RuntimeEntryData: callback_: Callable[[EntityInfo], None], ) -> CALLBACK_TYPE: """Register to receive callbacks when static info is updated for a specific key.""" - callback_key = (type(static_info), static_info.key) + callback_key = (type(static_info), static_info.device_id, static_info.key) callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, []) callbacks.append(callback_) return partial(callbacks.remove, callback_) @@ -250,7 +257,9 @@ class RuntimeEntryData: """Call static info updated callbacks.""" callbacks = self.entity_info_key_updated_callbacks for static_info in static_infos: - for callback_ in callbacks.get((type(static_info), static_info.key), ()): + for callback_ in callbacks.get( + (type(static_info), static_info.device_id, static_info.key), () + ): callback_(static_info) async def _ensure_platforms_loaded( @@ -342,12 +351,13 @@ class RuntimeEntryData: @callback def async_subscribe_state_update( self, + device_id: int, state_type: type[EntityState], state_key: int, entity_callback: CALLBACK_TYPE, ) -> CALLBACK_TYPE: """Subscribe to state updates.""" - subscription_key = (state_type, state_key) + subscription_key = (state_type, device_id, state_key) self.state_subscriptions[subscription_key] = entity_callback return partial(delitem, self.state_subscriptions, subscription_key) @@ -359,7 +369,7 @@ class RuntimeEntryData: stale_state = self.stale_state current_state_by_type = self.state[state_type] current_state = current_state_by_type.get(key, _SENTINEL) - subscription_key = (state_type, key) + subscription_key = (state_type, state.device_id, key) if ( current_state == state and subscription_key not in stale_state @@ -367,7 +377,7 @@ class RuntimeEntryData: and not ( state_type is SensorState and (platform_info := self.info.get(SensorInfo)) - and (entity_info := platform_info.get(state.key)) + and (entity_info := platform_info.get((state.device_id, state.key))) and (cast(SensorInfo, entity_info)).force_update ) ): @@ -520,3 +530,26 @@ class RuntimeEntryData: """Notify listeners that the Assist satellite wake word has been set.""" for callback_ in self.assist_satellite_set_wake_word_callbacks.copy(): callback_(wake_word_id) + + @callback + def async_register_entity_removal_callback( + self, + info_type: type[EntityInfo], + device_id: int, + key: int, + callback_: CALLBACK_TYPE, + ) -> CALLBACK_TYPE: + """Register to receive a callback when the entity should remove itself.""" + callback_key = (info_type, device_id, key) + callbacks = self.entity_removal_callbacks.setdefault(callback_key, []) + callbacks.append(callback_) + return partial(callbacks.remove, callback_) + + @callback + def async_signal_entity_removal( + self, info_type: type[EntityInfo], device_id: int, key: int + ) -> None: + """Signal that an entity should remove itself.""" + callback_key = (info_type, device_id, key) + for callback_ in self.entity_removal_callbacks.get(callback_key, []).copy(): + callback_() diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index 6c2da31e48b..5e9e11171af 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -588,7 +588,7 @@ class ESPHomeManager: # Mark state as stale so that we will always dispatch # the next state update of that type when the device reconnects entry_data.stale_state = { - (type(entity_state), key) + (type(entity_state), entity_state.device_id, key) for state_dict in entry_data.state.values() for key, entity_state in state_dict.items() } diff --git a/tests/components/esphome/test_binary_sensor.py b/tests/components/esphome/test_binary_sensor.py index d2cab36c672..d6e94e61766 100644 --- a/tests/components/esphome/test_binary_sensor.py +++ b/tests/components/esphome/test_binary_sensor.py @@ -1,6 +1,6 @@ """Test ESPHome binary sensors.""" -from aioesphomeapi import APIClient, BinarySensorInfo, BinarySensorState +from aioesphomeapi import APIClient, BinarySensorInfo, BinarySensorState, SubDeviceInfo import pytest from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNKNOWN @@ -127,3 +127,157 @@ async def test_binary_sensor_has_state_false( state = hass.states.get("binary_sensor.test_my_binary_sensor") assert state is not None assert state.state == STATE_ON + + +async def test_binary_sensors_same_key_different_device_id( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, +) -> None: + """Test binary sensors with same key but different device_id.""" + # Create sub-devices + sub_devices = [ + SubDeviceInfo(device_id=11111111, name="Sub Device 1", area_id=0), + SubDeviceInfo(device_id=22222222, name="Sub Device 2", area_id=0), + ] + + device_info = { + "name": "test", + "devices": sub_devices, + } + + # Both sub-devices have a binary sensor with key=1 + entity_info = [ + BinarySensorInfo( + object_id="sensor", + key=1, + name="Motion", + unique_id="motion_1", + device_id=11111111, + ), + BinarySensorInfo( + object_id="sensor", + key=1, + name="Motion", + unique_id="motion_2", + device_id=22222222, + ), + ] + + # States for both sensors with same key but different device_id + states = [ + BinarySensorState(key=1, state=True, missing_state=False, device_id=11111111), + BinarySensorState(key=1, state=False, missing_state=False, device_id=22222222), + ] + + mock_device = await mock_esphome_device( + mock_client=mock_client, + device_info=device_info, + entity_info=entity_info, + states=states, + ) + + # Verify both entities exist and have correct states + state1 = hass.states.get("binary_sensor.sub_device_1_motion") + assert state1 is not None + assert state1.state == STATE_ON + + state2 = hass.states.get("binary_sensor.sub_device_2_motion") + assert state2 is not None + assert state2.state == STATE_OFF + + # Update states to verify they update independently + mock_device.set_state( + BinarySensorState(key=1, state=False, missing_state=False, device_id=11111111) + ) + await hass.async_block_till_done() + + state1 = hass.states.get("binary_sensor.sub_device_1_motion") + assert state1.state == STATE_OFF + + # Sub device 2 should remain unchanged + state2 = hass.states.get("binary_sensor.sub_device_2_motion") + assert state2.state == STATE_OFF + + # Update sub device 2 + mock_device.set_state( + BinarySensorState(key=1, state=True, missing_state=False, device_id=22222222) + ) + await hass.async_block_till_done() + + state2 = hass.states.get("binary_sensor.sub_device_2_motion") + assert state2.state == STATE_ON + + # Sub device 1 should remain unchanged + state1 = hass.states.get("binary_sensor.sub_device_1_motion") + assert state1.state == STATE_OFF + + +async def test_binary_sensor_main_and_sub_device_same_key( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, +) -> None: + """Test binary sensor on main device and sub-device with same key.""" + # Create sub-device + sub_devices = [ + SubDeviceInfo(device_id=11111111, name="Sub Device", area_id=0), + ] + + device_info = { + "name": "test", + "devices": sub_devices, + } + + # Main device and sub-device both have a binary sensor with key=1 + entity_info = [ + BinarySensorInfo( + object_id="main_sensor", + key=1, + name="Main Sensor", + unique_id="main_1", + device_id=0, # Main device + ), + BinarySensorInfo( + object_id="sub_sensor", + key=1, + name="Sub Sensor", + unique_id="sub_1", + device_id=11111111, + ), + ] + + # States for both sensors + states = [ + BinarySensorState(key=1, state=True, missing_state=False, device_id=0), + BinarySensorState(key=1, state=False, missing_state=False, device_id=11111111), + ] + + mock_device = await mock_esphome_device( + mock_client=mock_client, + device_info=device_info, + entity_info=entity_info, + states=states, + ) + + # Verify both entities exist + main_state = hass.states.get("binary_sensor.test_main_sensor") + assert main_state is not None + assert main_state.state == STATE_ON + + sub_state = hass.states.get("binary_sensor.sub_device_sub_sensor") + assert sub_state is not None + assert sub_state.state == STATE_OFF + + # Update main device sensor + mock_device.set_state( + BinarySensorState(key=1, state=False, missing_state=False, device_id=0) + ) + await hass.async_block_till_done() + + main_state = hass.states.get("binary_sensor.test_main_sensor") + assert main_state.state == STATE_OFF + + # Sub device sensor should remain unchanged + sub_state = hass.states.get("binary_sensor.sub_device_sub_sensor") + assert sub_state.state == STATE_OFF diff --git a/tests/components/esphome/test_entity.py b/tests/components/esphome/test_entity.py index ba6a82bbd23..f364e1f528f 100644 --- a/tests/components/esphome/test_entity.py +++ b/tests/components/esphome/test_entity.py @@ -754,9 +754,9 @@ async def test_entity_assignment_to_sub_device( ] states = [ - BinarySensorState(key=1, state=True, missing_state=False), - BinarySensorState(key=2, state=False, missing_state=False), - BinarySensorState(key=3, state=True, missing_state=False), + BinarySensorState(key=1, state=True, missing_state=False, device_id=0), + BinarySensorState(key=2, state=False, missing_state=False, device_id=11111111), + BinarySensorState(key=3, state=True, missing_state=False, device_id=22222222), ] device = await mock_esphome_device( @@ -938,7 +938,7 @@ async def test_entity_switches_between_devices( ] states = [ - BinarySensorState(key=1, state=True, missing_state=False), + BinarySensorState(key=1, state=True, missing_state=False, device_id=0), ] device = await mock_esphome_device( @@ -1507,7 +1507,7 @@ async def test_entity_device_id_rename_in_yaml( ] states = [ - BinarySensorState(key=1, state=True, missing_state=False), + BinarySensorState(key=1, state=True, missing_state=False, device_id=11111111), ] device = await mock_esphome_device(