mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 20:27:08 +00:00
Fix state overwrite race condition where two platforms request the same entity_id (#42151)
* Fix state overwrite race condition where two platforms request the same entity id * fix test * create reservations instead * revert * cannot use __slots__ because we patch async_all
This commit is contained in:
parent
bb641c23a9
commit
df2ede6522
@ -355,6 +355,9 @@ ATTR_STATE = "state"
|
||||
ATTR_EDITABLE = "editable"
|
||||
ATTR_OPTION = "option"
|
||||
|
||||
# The entity has been restored with restore state
|
||||
ATTR_RESTORED = "restored"
|
||||
|
||||
# Bitfield of supported component features for the entity
|
||||
ATTR_SUPPORTED_FEATURES = "supported_features"
|
||||
|
||||
|
@ -973,6 +973,7 @@ class StateMachine:
|
||||
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
|
||||
"""Initialize state machine."""
|
||||
self._states: Dict[str, State] = {}
|
||||
self._reservations: Set[str] = set()
|
||||
self._bus = bus
|
||||
self._loop = loop
|
||||
|
||||
@ -1080,6 +1081,9 @@ class StateMachine:
|
||||
entity_id = entity_id.lower()
|
||||
old_state = self._states.pop(entity_id, None)
|
||||
|
||||
if entity_id in self._reservations:
|
||||
self._reservations.remove(entity_id)
|
||||
|
||||
if old_state is None:
|
||||
return False
|
||||
|
||||
@ -1116,6 +1120,29 @@ class StateMachine:
|
||||
context,
|
||||
).result()
|
||||
|
||||
@callback
|
||||
def async_reserve(self, entity_id: str) -> None:
|
||||
"""Reserve a state in the state machine for an entity being added.
|
||||
|
||||
This must not fire an event when the state is reserved.
|
||||
|
||||
This avoids a race condition where multiple entities with the same
|
||||
entity_id are added.
|
||||
"""
|
||||
entity_id = entity_id.lower()
|
||||
if entity_id in self._states or entity_id in self._reservations:
|
||||
raise HomeAssistantError(
|
||||
"async_reserve must not be called once the state is in the state machine."
|
||||
)
|
||||
|
||||
self._reservations.add(entity_id)
|
||||
|
||||
@callback
|
||||
def async_available(self, entity_id: str) -> bool:
|
||||
"""Check to see if an entity_id is available to be used."""
|
||||
entity_id = entity_id.lower()
|
||||
return entity_id not in self._states and entity_id not in self._reservations
|
||||
|
||||
@callback
|
||||
def async_set(
|
||||
self,
|
||||
|
@ -77,7 +77,7 @@ def async_generate_entity_id(
|
||||
|
||||
test_string = preferred_string
|
||||
tries = 1
|
||||
while hass.states.get(test_string):
|
||||
while not hass.states.async_available(test_string):
|
||||
tries += 1
|
||||
test_string = f"{preferred_string}_{tries}"
|
||||
|
||||
|
@ -7,7 +7,7 @@ from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import DEVICE_DEFAULT_NAME
|
||||
from homeassistant.const import ATTR_RESTORED, DEVICE_DEFAULT_NAME
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
ServiceCall,
|
||||
@ -461,11 +461,15 @@ class EntityPlatform:
|
||||
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
|
||||
|
||||
already_exists = entity.entity_id in self.entities
|
||||
restored = False
|
||||
|
||||
if not already_exists:
|
||||
if not already_exists and not self.hass.states.async_available(
|
||||
entity.entity_id
|
||||
):
|
||||
existing = self.hass.states.get(entity.entity_id)
|
||||
|
||||
if existing and not existing.attributes.get("restored"):
|
||||
if existing is not None and ATTR_RESTORED in existing.attributes:
|
||||
restored = True
|
||||
else:
|
||||
already_exists = True
|
||||
|
||||
if already_exists:
|
||||
@ -483,6 +487,15 @@ class EntityPlatform:
|
||||
|
||||
entity_id = entity.entity_id
|
||||
self.entities[entity_id] = entity
|
||||
|
||||
if not restored:
|
||||
# Reserve the state in the state machine
|
||||
# because as soon as we return control to the event
|
||||
# loop below, another entity could be added
|
||||
# with the same id before `entity.add_to_platform_finish()`
|
||||
# has a chance to finish.
|
||||
self.hass.states.async_reserve(entity.entity_id)
|
||||
|
||||
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
||||
|
||||
await entity.add_to_platform_finish()
|
||||
|
@ -27,6 +27,7 @@ from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
ATTR_FRIENDLY_NAME,
|
||||
ATTR_ICON,
|
||||
ATTR_RESTORED,
|
||||
ATTR_SUPPORTED_FEATURES,
|
||||
ATTR_UNIT_OF_MEASUREMENT,
|
||||
EVENT_HOMEASSISTANT_START,
|
||||
@ -56,8 +57,6 @@ DISABLED_HASS = "hass"
|
||||
DISABLED_USER = "user"
|
||||
DISABLED_INTEGRATION = "integration"
|
||||
|
||||
ATTR_RESTORED = "restored"
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "core.entity_registry"
|
||||
|
||||
@ -183,7 +182,7 @@ class EntityRegistry:
|
||||
while (
|
||||
test_string in self.entities
|
||||
or test_string in known_object_ids
|
||||
or self.hass.states.get(test_string)
|
||||
or not self.hass.states.async_available(test_string)
|
||||
):
|
||||
tries += 1
|
||||
test_string = f"{preferred_string}_{tries}"
|
||||
|
@ -975,3 +975,49 @@ async def test_setup_entry_with_entities_that_block_forever(hass, caplog):
|
||||
assert "test_domain.test1" in caplog.text
|
||||
assert "test_domain" in caplog.text
|
||||
assert "test" in caplog.text
|
||||
|
||||
|
||||
async def test_two_platforms_add_same_entity(hass):
|
||||
"""Test two platforms in the same domain adding an entity with the same name."""
|
||||
entity_platform1 = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
entity1 = SlowEntity(name="entity_1")
|
||||
|
||||
entity_platform2 = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
entity2 = SlowEntity(name="entity_1")
|
||||
|
||||
await asyncio.gather(
|
||||
entity_platform1.async_add_entities([entity1]),
|
||||
entity_platform2.async_add_entities([entity2]),
|
||||
)
|
||||
|
||||
entities = []
|
||||
|
||||
@callback
|
||||
def handle_service(entity, *_):
|
||||
entities.append(entity)
|
||||
|
||||
entity_platform1.async_register_entity_service("hello", {}, handle_service)
|
||||
await hass.services.async_call(
|
||||
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
|
||||
)
|
||||
|
||||
assert len(entities) == 2
|
||||
assert {entity1.entity_id, entity2.entity_id} == {
|
||||
"mock_integration.entity_1",
|
||||
"mock_integration.entity_1_2",
|
||||
}
|
||||
assert entity1 in entities
|
||||
assert entity2 in entities
|
||||
|
||||
|
||||
class SlowEntity(MockEntity):
|
||||
"""An entity that will sleep during add."""
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Make sure control is returned to the event loop on add."""
|
||||
await asyncio.sleep(0.1)
|
||||
await super().async_added_to_hass()
|
||||
|
@ -1537,3 +1537,26 @@ async def test_hassjob_forbid_coroutine():
|
||||
|
||||
# To avoid warning about unawaited coro
|
||||
await coro
|
||||
|
||||
|
||||
async def test_reserving_states(hass):
|
||||
"""Test we can reserve a state in the state machine."""
|
||||
|
||||
hass.states.async_reserve("light.bedroom")
|
||||
assert hass.states.async_available("light.bedroom") is False
|
||||
hass.states.async_set("light.bedroom", "on")
|
||||
assert hass.states.async_available("light.bedroom") is False
|
||||
|
||||
with pytest.raises(ha.HomeAssistantError):
|
||||
hass.states.async_reserve("light.bedroom")
|
||||
|
||||
hass.states.async_remove("light.bedroom")
|
||||
assert hass.states.async_available("light.bedroom") is True
|
||||
hass.states.async_set("light.bedroom", "on")
|
||||
|
||||
with pytest.raises(ha.HomeAssistantError):
|
||||
hass.states.async_reserve("light.bedroom")
|
||||
|
||||
assert hass.states.async_available("light.bedroom") is False
|
||||
hass.states.async_remove("light.bedroom")
|
||||
assert hass.states.async_available("light.bedroom") is True
|
||||
|
Loading…
x
Reference in New Issue
Block a user