Prevent polling from recreating an entity after removal (#67750)

This commit is contained in:
J. Nick Koston 2022-03-08 05:42:16 +01:00 committed by GitHub
parent aed2c1cce8
commit a75bbc79a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 6 deletions

View File

@ -6,6 +6,7 @@ import asyncio
from collections.abc import Awaitable, Iterable, Mapping, MutableMapping from collections.abc import Awaitable, Iterable, Mapping, MutableMapping
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum, auto
import functools as ft import functools as ft
import logging import logging
import math import math
@ -207,6 +208,19 @@ class EntityCategory(StrEnum):
SYSTEM = "system" SYSTEM = "system"
class EntityPlatformState(Enum):
"""The platform state of an entity."""
# Not Added: Not yet added to a platform, polling updates are written to the state machine
NOT_ADDED = auto()
# Added: Added to a platform, polling updates are written to the state machine
ADDED = auto()
# Removed: Removed from a platform, polling updates are not written to the state machine
REMOVED = auto()
def convert_to_entity_category( def convert_to_entity_category(
value: EntityCategory | str | None, raise_report: bool = True value: EntityCategory | str | None, raise_report: bool = True
) -> EntityCategory | None: ) -> EntityCategory | None:
@ -294,7 +308,7 @@ class Entity(ABC):
_context_set: datetime | None = None _context_set: datetime | None = None
# If entity is added to an entity platform # If entity is added to an entity platform
_added = False _platform_state = EntityPlatformState.NOT_ADDED
# Entity Properties # Entity Properties
_attr_assumed_state: bool = False _attr_assumed_state: bool = False
@ -553,6 +567,10 @@ class Entity(ABC):
@callback @callback
def _async_write_ha_state(self) -> None: def _async_write_ha_state(self) -> None:
"""Write the state to the state machine.""" """Write the state to the state machine."""
if self._platform_state == EntityPlatformState.REMOVED:
# Polling returned after the entity has already been removed
return
if self.registry_entry and self.registry_entry.disabled_by: if self.registry_entry and self.registry_entry.disabled_by:
if not self._disabled_reported: if not self._disabled_reported:
self._disabled_reported = True self._disabled_reported = True
@ -758,7 +776,7 @@ class Entity(ABC):
parallel_updates: asyncio.Semaphore | None, parallel_updates: asyncio.Semaphore | None,
) -> None: ) -> None:
"""Start adding an entity to a platform.""" """Start adding an entity to a platform."""
if self._added: if self._platform_state == EntityPlatformState.ADDED:
raise HomeAssistantError( raise HomeAssistantError(
f"Entity {self.entity_id} cannot be added a second time to an entity platform" f"Entity {self.entity_id} cannot be added a second time to an entity platform"
) )
@ -766,7 +784,7 @@ class Entity(ABC):
self.hass = hass self.hass = hass
self.platform = platform self.platform = platform
self.parallel_updates = parallel_updates self.parallel_updates = parallel_updates
self._added = True self._platform_state = EntityPlatformState.ADDED
@callback @callback
def add_to_platform_abort(self) -> None: def add_to_platform_abort(self) -> None:
@ -774,7 +792,7 @@ class Entity(ABC):
self.hass = None # type: ignore[assignment] self.hass = None # type: ignore[assignment]
self.platform = None self.platform = None
self.parallel_updates = None self.parallel_updates = None
self._added = False self._platform_state = EntityPlatformState.NOT_ADDED
async def add_to_platform_finish(self) -> None: async def add_to_platform_finish(self) -> None:
"""Finish adding an entity to a platform.""" """Finish adding an entity to a platform."""
@ -792,12 +810,12 @@ class Entity(ABC):
If the entity doesn't have a non disabled entry in the entity registry, If the entity doesn't have a non disabled entry in the entity registry,
or if force_remove=True, its state will be removed. or if force_remove=True, its state will be removed.
""" """
if self.platform and not self._added: if self.platform and self._platform_state != EntityPlatformState.ADDED:
raise HomeAssistantError( raise HomeAssistantError(
f"Entity {self.entity_id} async_remove called twice" f"Entity {self.entity_id} async_remove called twice"
) )
self._added = False self._platform_state = EntityPlatformState.REMOVED
if self._on_remove is not None: if self._on_remove is not None:
while self._on_remove: while self._on_remove:

View File

@ -545,6 +545,22 @@ async def test_async_remove_runs_callbacks(hass):
assert len(result) == 1 assert len(result) == 1
async def test_async_remove_ignores_in_flight_polling(hass):
"""Test in flight polling is ignored after removing."""
result = []
ent = entity.Entity()
ent.hass = hass
ent.entity_id = "test.test"
ent.async_on_remove(lambda: result.append(1))
ent.async_write_ha_state()
assert hass.states.get("test.test").state == STATE_UNKNOWN
await ent.async_remove()
assert len(result) == 1
assert hass.states.get("test.test") is None
ent.async_write_ha_state()
async def test_set_context(hass): async def test_set_context(hass):
"""Test setting context.""" """Test setting context."""
context = Context() context = Context()

View File

@ -390,6 +390,30 @@ async def test_async_remove_with_platform(hass):
assert len(hass.states.async_entity_ids()) == 0 assert len(hass.states.async_entity_ids()) == 0
async def test_async_remove_with_platform_update_finishes(hass):
"""Remove an entity when an update finishes after its been removed."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = MockEntity(name="test_1")
async def _delayed_update(*args, **kwargs):
await asyncio.sleep(0.01)
entity1.async_update = _delayed_update
# Add, remove, add, remove and make sure no updates
# cause the entity to reappear after removal
for i in range(2):
await component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1
entity1.async_write_ha_state()
assert hass.states.get(entity1.entity_id) is not None
task = asyncio.create_task(entity1.async_update_ha_state(True))
await entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0
await task
assert len(hass.states.async_entity_ids()) == 0
async def test_not_adding_duplicate_entities_with_unique_id(hass, caplog): async def test_not_adding_duplicate_entities_with_unique_id(hass, caplog):
"""Test for not adding duplicate entities.""" """Test for not adding duplicate entities."""
caplog.set_level(logging.ERROR) caplog.set_level(logging.ERROR)