Fix state overwrite race condition where two platforms request the same entity_id (#42151)

* Fix state overwrite race condition where two platforms request the same entity id

* fix test

* create reservations instead

* revert

* cannot use __slots__ because we patch async_all
This commit is contained in:
J. Nick Koston 2020-10-21 10:01:51 -05:00 committed by GitHub
parent bb641c23a9
commit df2ede6522
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 119 additions and 8 deletions

View File

@ -355,6 +355,9 @@ ATTR_STATE = "state"
ATTR_EDITABLE = "editable"
ATTR_OPTION = "option"
# The entity has been restored with restore state
ATTR_RESTORED = "restored"
# Bitfield of supported component features for the entity
ATTR_SUPPORTED_FEATURES = "supported_features"

View File

@ -973,6 +973,7 @@ class StateMachine:
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
"""Initialize state machine."""
self._states: Dict[str, State] = {}
self._reservations: Set[str] = set()
self._bus = bus
self._loop = loop
@ -1080,6 +1081,9 @@ class StateMachine:
entity_id = entity_id.lower()
old_state = self._states.pop(entity_id, None)
if entity_id in self._reservations:
self._reservations.remove(entity_id)
if old_state is None:
return False
@ -1116,6 +1120,29 @@ class StateMachine:
context,
).result()
@callback
def async_reserve(self, entity_id: str) -> None:
"""Reserve a state in the state machine for an entity being added.
This must not fire an event when the state is reserved.
This avoids a race condition where multiple entities with the same
entity_id are added.
"""
entity_id = entity_id.lower()
if entity_id in self._states or entity_id in self._reservations:
raise HomeAssistantError(
"async_reserve must not be called once the state is in the state machine."
)
self._reservations.add(entity_id)
@callback
def async_available(self, entity_id: str) -> bool:
"""Check to see if an entity_id is available to be used."""
entity_id = entity_id.lower()
return entity_id not in self._states and entity_id not in self._reservations
@callback
def async_set(
self,

View File

@ -77,7 +77,7 @@ def async_generate_entity_id(
test_string = preferred_string
tries = 1
while hass.states.get(test_string):
while not hass.states.async_available(test_string):
tries += 1
test_string = f"{preferred_string}_{tries}"

View File

@ -7,7 +7,7 @@ from types import ModuleType
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional
from homeassistant import config_entries
from homeassistant.const import DEVICE_DEFAULT_NAME
from homeassistant.const import ATTR_RESTORED, DEVICE_DEFAULT_NAME
from homeassistant.core import (
CALLBACK_TYPE,
ServiceCall,
@ -461,11 +461,15 @@ class EntityPlatform:
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
already_exists = entity.entity_id in self.entities
restored = False
if not already_exists:
if not already_exists and not self.hass.states.async_available(
entity.entity_id
):
existing = self.hass.states.get(entity.entity_id)
if existing and not existing.attributes.get("restored"):
if existing is not None and ATTR_RESTORED in existing.attributes:
restored = True
else:
already_exists = True
if already_exists:
@ -483,6 +487,15 @@ class EntityPlatform:
entity_id = entity.entity_id
self.entities[entity_id] = entity
if not restored:
# Reserve the state in the state machine
# because as soon as we return control to the event
# loop below, another entity could be added
# with the same id before `entity.add_to_platform_finish()`
# has a chance to finish.
self.hass.states.async_reserve(entity.entity_id)
entity.async_on_remove(lambda: self.entities.pop(entity_id))
await entity.add_to_platform_finish()

View File

@ -27,6 +27,7 @@ from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_FRIENDLY_NAME,
ATTR_ICON,
ATTR_RESTORED,
ATTR_SUPPORTED_FEATURES,
ATTR_UNIT_OF_MEASUREMENT,
EVENT_HOMEASSISTANT_START,
@ -56,8 +57,6 @@ DISABLED_HASS = "hass"
DISABLED_USER = "user"
DISABLED_INTEGRATION = "integration"
ATTR_RESTORED = "restored"
STORAGE_VERSION = 1
STORAGE_KEY = "core.entity_registry"
@ -183,7 +182,7 @@ class EntityRegistry:
while (
test_string in self.entities
or test_string in known_object_ids
or self.hass.states.get(test_string)
or not self.hass.states.async_available(test_string)
):
tries += 1
test_string = f"{preferred_string}_{tries}"

View File

@ -975,3 +975,49 @@ async def test_setup_entry_with_entities_that_block_forever(hass, caplog):
assert "test_domain.test1" in caplog.text
assert "test_domain" in caplog.text
assert "test" in caplog.text
async def test_two_platforms_add_same_entity(hass):
"""Test two platforms in the same domain adding an entity with the same name."""
entity_platform1 = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = SlowEntity(name="entity_1")
entity_platform2 = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity2 = SlowEntity(name="entity_1")
await asyncio.gather(
entity_platform1.async_add_entities([entity1]),
entity_platform2.async_add_entities([entity2]),
)
entities = []
@callback
def handle_service(entity, *_):
entities.append(entity)
entity_platform1.async_register_entity_service("hello", {}, handle_service)
await hass.services.async_call(
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
)
assert len(entities) == 2
assert {entity1.entity_id, entity2.entity_id} == {
"mock_integration.entity_1",
"mock_integration.entity_1_2",
}
assert entity1 in entities
assert entity2 in entities
class SlowEntity(MockEntity):
"""An entity that will sleep during add."""
async def async_added_to_hass(self):
"""Make sure control is returned to the event loop on add."""
await asyncio.sleep(0.1)
await super().async_added_to_hass()

View File

@ -1537,3 +1537,26 @@ async def test_hassjob_forbid_coroutine():
# To avoid warning about unawaited coro
await coro
async def test_reserving_states(hass):
"""Test we can reserve a state in the state machine."""
hass.states.async_reserve("light.bedroom")
assert hass.states.async_available("light.bedroom") is False
hass.states.async_set("light.bedroom", "on")
assert hass.states.async_available("light.bedroom") is False
with pytest.raises(ha.HomeAssistantError):
hass.states.async_reserve("light.bedroom")
hass.states.async_remove("light.bedroom")
assert hass.states.async_available("light.bedroom") is True
hass.states.async_set("light.bedroom", "on")
with pytest.raises(ha.HomeAssistantError):
hass.states.async_reserve("light.bedroom")
assert hass.states.async_available("light.bedroom") is False
hass.states.async_remove("light.bedroom")
assert hass.states.async_available("light.bedroom") is True