Fix ESPHome not fully removing entities when entity info changes (#108823)

This commit is contained in:
J. Nick Koston 2024-01-24 17:29:11 -10:00 committed by GitHub
parent 7f56330e3b
commit d588ec8202
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 196 additions and 36 deletions

View File

@ -37,6 +37,51 @@ _EntityT = TypeVar("_EntityT", bound="EsphomeEntity[Any,Any]")
_StateT = TypeVar("_StateT", bound=EntityState) _StateT = TypeVar("_StateT", bound=EntityState)
@callback
def async_static_info_updated(
hass: HomeAssistant,
entry_data: RuntimeEntryData,
platform: entity_platform.EntityPlatform,
async_add_entities: AddEntitiesCallback,
info_type: type[_InfoT],
entity_type: type[_EntityT],
state_type: type[_StateT],
infos: list[EntityInfo],
) -> None:
"""Update entities of this platform when entities are listed."""
current_infos = entry_data.info[info_type]
new_infos: dict[int, EntityInfo] = {}
add_entities: list[_EntityT] = []
for info in infos:
if not current_infos.pop(info.key, None):
# Create new entity
entity = entity_type(entry_data, platform.domain, info, state_type)
add_entities.append(entity)
new_infos[info.key] = info
# Anything still in current_infos is now gone
if current_infos:
device_info = entry_data.device_info
if TYPE_CHECKING:
assert device_info is not None
hass.async_create_task(
entry_data.async_remove_entities(
hass, current_infos.values(), device_info.mac_address
)
)
# Then update the actual info
entry_data.info[info_type] = new_infos
if new_infos:
entry_data.async_update_entity_infos(new_infos.values())
if add_entities:
# Add entities to Home Assistant
async_add_entities(add_entities)
async def platform_async_setup_entry( async def platform_async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
entry: ConfigEntry, entry: ConfigEntry,
@ -55,39 +100,21 @@ async def platform_async_setup_entry(
entry_data.info[info_type] = {} entry_data.info[info_type] = {}
entry_data.state.setdefault(state_type, {}) entry_data.state.setdefault(state_type, {})
platform = entity_platform.async_get_current_platform() platform = entity_platform.async_get_current_platform()
on_static_info_update = functools.partial(
@callback async_static_info_updated,
def async_list_entities(infos: list[EntityInfo]) -> None: hass,
"""Update entities of this platform when entities are listed.""" entry_data,
current_infos = entry_data.info[info_type] platform,
new_infos: dict[int, EntityInfo] = {} async_add_entities,
add_entities: list[_EntityT] = [] info_type,
entity_type,
for info in infos: state_type,
if not current_infos.pop(info.key, None):
# Create new entity
entity = entity_type(entry_data, platform.domain, info, state_type)
add_entities.append(entity)
new_infos[info.key] = info
# Anything still in current_infos is now gone
if current_infos:
hass.async_create_task(
entry_data.async_remove_entities(current_infos.values())
) )
# Then update the actual info
entry_data.info[info_type] = new_infos
if new_infos:
entry_data.async_update_entity_infos(new_infos.values())
if add_entities:
# Add entities to Home Assistant
async_add_entities(add_entities)
entry_data.cleanup_callbacks.append( entry_data.cleanup_callbacks.append(
entry_data.async_register_static_info_callback(info_type, async_list_entities) entry_data.async_register_static_info_callback(
info_type,
on_static_info_update,
)
) )

View File

@ -243,8 +243,18 @@ class RuntimeEntryData:
"""Unsubscribe to assist pipeline updates.""" """Unsubscribe to assist pipeline updates."""
self.assist_pipeline_update_callbacks.remove(update_callback) self.assist_pipeline_update_callbacks.remove(update_callback)
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None: async def async_remove_entities(
self, hass: HomeAssistant, static_infos: Iterable[EntityInfo], mac: str
) -> None:
"""Schedule the removal of an entity.""" """Schedule the removal of an entity."""
# Remove from entity registry first so the entity is fully removed
ent_reg = er.async_get(hass)
for info in static_infos:
if entry := ent_reg.async_get_entity_id(
INFO_TYPE_TO_PLATFORM[type(info)], DOMAIN, build_unique_id(mac, info)
):
ent_reg.async_remove(entry)
callbacks: list[Coroutine[Any, Any, None]] = [] callbacks: list[Coroutine[Any, Any, None]] = []
for static_info in static_infos: for static_info in static_infos:
callback_key = (type(static_info), static_info.key) callback_key = (type(static_info), static_info.key)

View File

@ -177,9 +177,10 @@ async def mock_dashboard(hass):
class MockESPHomeDevice: class MockESPHomeDevice:
"""Mock an esphome device.""" """Mock an esphome device."""
def __init__(self, entry: MockConfigEntry) -> None: def __init__(self, entry: MockConfigEntry, client: APIClient) -> None:
"""Init the mock.""" """Init the mock."""
self.entry = entry self.entry = entry
self.client = client
self.state_callback: Callable[[EntityState], None] self.state_callback: Callable[[EntityState], None]
self.service_call_callback: Callable[[HomeassistantServiceCall], None] self.service_call_callback: Callable[[HomeassistantServiceCall], None]
self.on_disconnect: Callable[[bool], None] self.on_disconnect: Callable[[bool], None]
@ -258,7 +259,7 @@ async def _mock_generic_device_entry(
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
mock_device = MockESPHomeDevice(entry) mock_device = MockESPHomeDevice(entry, mock_client)
default_device_info = { default_device_info = {
"name": "test", "name": "test",

View File

@ -1,6 +1,7 @@
"""Test ESPHome binary sensors.""" """Test ESPHome binary sensors."""
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
from unittest.mock import AsyncMock
from aioesphomeapi import ( from aioesphomeapi import (
APIClient, APIClient,
@ -21,6 +22,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from .conftest import MockESPHomeDevice from .conftest import MockESPHomeDevice
@ -34,7 +36,8 @@ async def test_entities_removed(
Awaitable[MockESPHomeDevice], Awaitable[MockESPHomeDevice],
], ],
) -> None: ) -> None:
"""Test a generic binary_sensor where has_state is false.""" """Test entities are removed when static info changes."""
ent_reg = er.async_get(hass)
entity_info = [ entity_info = [
BinarySensorInfo( BinarySensorInfo(
object_id="mybinary_sensor", object_id="mybinary_sensor",
@ -80,6 +83,8 @@ async def test_entities_removed(
assert state.attributes[ATTR_RESTORED] is True assert state.attributes[ATTR_RESTORED] is True
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None assert state is not None
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert state.attributes[ATTR_RESTORED] is True assert state.attributes[ATTR_RESTORED] is True
entity_info = [ entity_info = [
@ -106,11 +111,128 @@ async def test_entities_removed(
assert state.state == STATE_ON assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is None assert state is None
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is None
await hass.config_entries.async_unload(entry.entry_id) await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1 assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1
async def test_entities_removed_after_reload(
hass: HomeAssistant,
mock_client: APIClient,
hass_storage: dict[str, Any],
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test entities and their registry entry are removed when static info changes after a reload."""
ent_reg = er.async_get(hass)
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
BinarySensorInfo(
object_id="mybinary_sensor_to_be_removed",
key=2,
name="my binary_sensor to be removed",
unique_id="mybinary_sensor_to_be_removed",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
BinarySensorState(key=2, state=True, missing_state=False),
]
user_service = []
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
)
entry = mock_device.entry
entry_id = entry.entry_id
storage_key = f"esphome.{entry_id}"
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None
assert state.state == STATE_ON
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert ATTR_RESTORED not in state.attributes
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None
assert ATTR_RESTORED not in state.attributes
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
]
mock_device.client.list_entities_services = AsyncMock(
return_value=(entity_info, user_service)
)
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert mock_device.entry.entry_id == entry_id
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is None
await hass.async_block_till_done()
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is None
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1
async def test_entity_info_object_ids( async def test_entity_info_object_ids(
hass: HomeAssistant, hass: HomeAssistant,
mock_client: APIClient, mock_client: APIClient,