mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Fix state overwrite race condition where two platforms request the same entity_id (#42151)
* Fix state overwrite race condition where two platforms request the same entity id * fix test * create reservations instead * revert * cannot use __slots__ because we patch async_all
This commit is contained in:
parent
bb641c23a9
commit
df2ede6522
@ -355,6 +355,9 @@ ATTR_STATE = "state"
|
|||||||
ATTR_EDITABLE = "editable"
|
ATTR_EDITABLE = "editable"
|
||||||
ATTR_OPTION = "option"
|
ATTR_OPTION = "option"
|
||||||
|
|
||||||
|
# The entity has been restored with restore state
|
||||||
|
ATTR_RESTORED = "restored"
|
||||||
|
|
||||||
# Bitfield of supported component features for the entity
|
# Bitfield of supported component features for the entity
|
||||||
ATTR_SUPPORTED_FEATURES = "supported_features"
|
ATTR_SUPPORTED_FEATURES = "supported_features"
|
||||||
|
|
||||||
|
@ -973,6 +973,7 @@ class StateMachine:
|
|||||||
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
|
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
|
||||||
"""Initialize state machine."""
|
"""Initialize state machine."""
|
||||||
self._states: Dict[str, State] = {}
|
self._states: Dict[str, State] = {}
|
||||||
|
self._reservations: Set[str] = set()
|
||||||
self._bus = bus
|
self._bus = bus
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
|
|
||||||
@ -1080,6 +1081,9 @@ class StateMachine:
|
|||||||
entity_id = entity_id.lower()
|
entity_id = entity_id.lower()
|
||||||
old_state = self._states.pop(entity_id, None)
|
old_state = self._states.pop(entity_id, None)
|
||||||
|
|
||||||
|
if entity_id in self._reservations:
|
||||||
|
self._reservations.remove(entity_id)
|
||||||
|
|
||||||
if old_state is None:
|
if old_state is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1116,6 +1120,29 @@ class StateMachine:
|
|||||||
context,
|
context,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_reserve(self, entity_id: str) -> None:
|
||||||
|
"""Reserve a state in the state machine for an entity being added.
|
||||||
|
|
||||||
|
This must not fire an event when the state is reserved.
|
||||||
|
|
||||||
|
This avoids a race condition where multiple entities with the same
|
||||||
|
entity_id are added.
|
||||||
|
"""
|
||||||
|
entity_id = entity_id.lower()
|
||||||
|
if entity_id in self._states or entity_id in self._reservations:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"async_reserve must not be called once the state is in the state machine."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._reservations.add(entity_id)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_available(self, entity_id: str) -> bool:
|
||||||
|
"""Check to see if an entity_id is available to be used."""
|
||||||
|
entity_id = entity_id.lower()
|
||||||
|
return entity_id not in self._states and entity_id not in self._reservations
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set(
|
def async_set(
|
||||||
self,
|
self,
|
||||||
|
@ -77,7 +77,7 @@ def async_generate_entity_id(
|
|||||||
|
|
||||||
test_string = preferred_string
|
test_string = preferred_string
|
||||||
tries = 1
|
tries = 1
|
||||||
while hass.states.get(test_string):
|
while not hass.states.async_available(test_string):
|
||||||
tries += 1
|
tries += 1
|
||||||
test_string = f"{preferred_string}_{tries}"
|
test_string = f"{preferred_string}_{tries}"
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from types import ModuleType
|
|||||||
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional
|
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.const import DEVICE_DEFAULT_NAME
|
from homeassistant.const import ATTR_RESTORED, DEVICE_DEFAULT_NAME
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
CALLBACK_TYPE,
|
CALLBACK_TYPE,
|
||||||
ServiceCall,
|
ServiceCall,
|
||||||
@ -461,11 +461,15 @@ class EntityPlatform:
|
|||||||
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
|
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
|
||||||
|
|
||||||
already_exists = entity.entity_id in self.entities
|
already_exists = entity.entity_id in self.entities
|
||||||
|
restored = False
|
||||||
|
|
||||||
if not already_exists:
|
if not already_exists and not self.hass.states.async_available(
|
||||||
|
entity.entity_id
|
||||||
|
):
|
||||||
existing = self.hass.states.get(entity.entity_id)
|
existing = self.hass.states.get(entity.entity_id)
|
||||||
|
if existing is not None and ATTR_RESTORED in existing.attributes:
|
||||||
if existing and not existing.attributes.get("restored"):
|
restored = True
|
||||||
|
else:
|
||||||
already_exists = True
|
already_exists = True
|
||||||
|
|
||||||
if already_exists:
|
if already_exists:
|
||||||
@ -483,6 +487,15 @@ class EntityPlatform:
|
|||||||
|
|
||||||
entity_id = entity.entity_id
|
entity_id = entity.entity_id
|
||||||
self.entities[entity_id] = entity
|
self.entities[entity_id] = entity
|
||||||
|
|
||||||
|
if not restored:
|
||||||
|
# Reserve the state in the state machine
|
||||||
|
# because as soon as we return control to the event
|
||||||
|
# loop below, another entity could be added
|
||||||
|
# with the same id before `entity.add_to_platform_finish()`
|
||||||
|
# has a chance to finish.
|
||||||
|
self.hass.states.async_reserve(entity.entity_id)
|
||||||
|
|
||||||
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
||||||
|
|
||||||
await entity.add_to_platform_finish()
|
await entity.add_to_platform_finish()
|
||||||
|
@ -27,6 +27,7 @@ from homeassistant.const import (
|
|||||||
ATTR_DEVICE_CLASS,
|
ATTR_DEVICE_CLASS,
|
||||||
ATTR_FRIENDLY_NAME,
|
ATTR_FRIENDLY_NAME,
|
||||||
ATTR_ICON,
|
ATTR_ICON,
|
||||||
|
ATTR_RESTORED,
|
||||||
ATTR_SUPPORTED_FEATURES,
|
ATTR_SUPPORTED_FEATURES,
|
||||||
ATTR_UNIT_OF_MEASUREMENT,
|
ATTR_UNIT_OF_MEASUREMENT,
|
||||||
EVENT_HOMEASSISTANT_START,
|
EVENT_HOMEASSISTANT_START,
|
||||||
@ -56,8 +57,6 @@ DISABLED_HASS = "hass"
|
|||||||
DISABLED_USER = "user"
|
DISABLED_USER = "user"
|
||||||
DISABLED_INTEGRATION = "integration"
|
DISABLED_INTEGRATION = "integration"
|
||||||
|
|
||||||
ATTR_RESTORED = "restored"
|
|
||||||
|
|
||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION = 1
|
||||||
STORAGE_KEY = "core.entity_registry"
|
STORAGE_KEY = "core.entity_registry"
|
||||||
|
|
||||||
@ -183,7 +182,7 @@ class EntityRegistry:
|
|||||||
while (
|
while (
|
||||||
test_string in self.entities
|
test_string in self.entities
|
||||||
or test_string in known_object_ids
|
or test_string in known_object_ids
|
||||||
or self.hass.states.get(test_string)
|
or not self.hass.states.async_available(test_string)
|
||||||
):
|
):
|
||||||
tries += 1
|
tries += 1
|
||||||
test_string = f"{preferred_string}_{tries}"
|
test_string = f"{preferred_string}_{tries}"
|
||||||
|
@ -975,3 +975,49 @@ async def test_setup_entry_with_entities_that_block_forever(hass, caplog):
|
|||||||
assert "test_domain.test1" in caplog.text
|
assert "test_domain.test1" in caplog.text
|
||||||
assert "test_domain" in caplog.text
|
assert "test_domain" in caplog.text
|
||||||
assert "test" in caplog.text
|
assert "test" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_two_platforms_add_same_entity(hass):
|
||||||
|
"""Test two platforms in the same domain adding an entity with the same name."""
|
||||||
|
entity_platform1 = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = SlowEntity(name="entity_1")
|
||||||
|
|
||||||
|
entity_platform2 = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity2 = SlowEntity(name="entity_1")
|
||||||
|
|
||||||
|
await asyncio.gather(
|
||||||
|
entity_platform1.async_add_entities([entity1]),
|
||||||
|
entity_platform2.async_add_entities([entity2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
entities = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def handle_service(entity, *_):
|
||||||
|
entities.append(entity)
|
||||||
|
|
||||||
|
entity_platform1.async_register_entity_service("hello", {}, handle_service)
|
||||||
|
await hass.services.async_call(
|
||||||
|
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(entities) == 2
|
||||||
|
assert {entity1.entity_id, entity2.entity_id} == {
|
||||||
|
"mock_integration.entity_1",
|
||||||
|
"mock_integration.entity_1_2",
|
||||||
|
}
|
||||||
|
assert entity1 in entities
|
||||||
|
assert entity2 in entities
|
||||||
|
|
||||||
|
|
||||||
|
class SlowEntity(MockEntity):
|
||||||
|
"""An entity that will sleep during add."""
|
||||||
|
|
||||||
|
async def async_added_to_hass(self):
|
||||||
|
"""Make sure control is returned to the event loop on add."""
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
@ -1537,3 +1537,26 @@ async def test_hassjob_forbid_coroutine():
|
|||||||
|
|
||||||
# To avoid warning about unawaited coro
|
# To avoid warning about unawaited coro
|
||||||
await coro
|
await coro
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reserving_states(hass):
|
||||||
|
"""Test we can reserve a state in the state machine."""
|
||||||
|
|
||||||
|
hass.states.async_reserve("light.bedroom")
|
||||||
|
assert hass.states.async_available("light.bedroom") is False
|
||||||
|
hass.states.async_set("light.bedroom", "on")
|
||||||
|
assert hass.states.async_available("light.bedroom") is False
|
||||||
|
|
||||||
|
with pytest.raises(ha.HomeAssistantError):
|
||||||
|
hass.states.async_reserve("light.bedroom")
|
||||||
|
|
||||||
|
hass.states.async_remove("light.bedroom")
|
||||||
|
assert hass.states.async_available("light.bedroom") is True
|
||||||
|
hass.states.async_set("light.bedroom", "on")
|
||||||
|
|
||||||
|
with pytest.raises(ha.HomeAssistantError):
|
||||||
|
hass.states.async_reserve("light.bedroom")
|
||||||
|
|
||||||
|
assert hass.states.async_available("light.bedroom") is False
|
||||||
|
hass.states.async_remove("light.bedroom")
|
||||||
|
assert hass.states.async_available("light.bedroom") is True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user