mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Fix refactoring error with updating polling entities in sequence (#93693)
* Fix refactoring error with updating in sequence see #93649 * coverage * make sure entities are being updated in parallel * make sure entities are being updated in sequence
This commit is contained in:
parent
49c3a8886f
commit
083cf7a38b
@ -136,7 +136,7 @@ class EntityPlatform:
|
||||
self._process_updates: asyncio.Lock | None = None
|
||||
|
||||
self.parallel_updates: asyncio.Semaphore | None = None
|
||||
self._update_in_parallel: bool = True
|
||||
self._update_in_sequence: bool = False
|
||||
|
||||
# Platform is None for the EntityComponent "catch-all" EntityPlatform
|
||||
# which powers entity_component.add_entities
|
||||
@ -187,7 +187,7 @@ class EntityPlatform:
|
||||
|
||||
if parallel_updates is not None:
|
||||
self.parallel_updates = asyncio.Semaphore(parallel_updates)
|
||||
self._update_in_parallel = parallel_updates != 1
|
||||
self._update_in_sequence = parallel_updates == 1
|
||||
|
||||
return self.parallel_updates
|
||||
|
||||
@ -846,11 +846,13 @@ class EntityPlatform:
|
||||
return
|
||||
|
||||
async with self._process_updates:
|
||||
if self._update_in_parallel or len(self.entities) <= 1:
|
||||
# If we know are going to update sequentially, we want to update
|
||||
# to avoid scheduling the coroutines as tasks that will we know
|
||||
# are going to wait on the semaphore lock.
|
||||
if self._update_in_sequence or len(self.entities) <= 1:
|
||||
# If we know we will update sequentially, we want to avoid scheduling
|
||||
# the coroutines as tasks that will wait on the semaphore lock.
|
||||
for entity in list(self.entities.values()):
|
||||
# If the entity is removed from hass during the previous
|
||||
# entity being updated, we need to skip updating the
|
||||
# entity.
|
||||
if entity.should_poll and entity.hass:
|
||||
await entity.async_update_ha_state(True)
|
||||
return
|
||||
|
@ -2,6 +2,7 @@
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -307,6 +308,7 @@ async def test_parallel_updates_async_platform(hass: HomeAssistant) -> None:
|
||||
entity = AsyncEntity()
|
||||
await handle.async_add_entities([entity])
|
||||
assert entity.parallel_updates is None
|
||||
assert handle._update_in_sequence is False
|
||||
|
||||
|
||||
async def test_parallel_updates_async_platform_with_constant(
|
||||
@ -336,6 +338,7 @@ async def test_parallel_updates_async_platform_with_constant(
|
||||
await handle.async_add_entities([entity])
|
||||
assert entity.parallel_updates is not None
|
||||
assert entity.parallel_updates._value == 2
|
||||
assert handle._update_in_sequence is False
|
||||
|
||||
|
||||
async def test_parallel_updates_sync_platform(hass: HomeAssistant) -> None:
|
||||
@ -412,6 +415,104 @@ async def test_parallel_updates_sync_platform_with_constant(
|
||||
assert entity.parallel_updates._value == 2
|
||||
|
||||
|
||||
async def test_parallel_updates_async_platform_updates_in_parallel(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test an async platform is updated in parallel."""
|
||||
platform = MockPlatform()
|
||||
|
||||
mock_entity_platform(hass, "test_domain.async_platform", platform)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component._platforms = {}
|
||||
|
||||
await component.async_setup({DOMAIN: {"platform": "async_platform"}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
updating = []
|
||||
peak_update_count = 0
|
||||
|
||||
class AsyncEntity(MockEntity):
|
||||
"""Mock entity that has async_update."""
|
||||
|
||||
async def async_update(self):
|
||||
pass
|
||||
|
||||
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
|
||||
nonlocal peak_update_count
|
||||
updating.append(self.entity_id)
|
||||
await asyncio.sleep(0)
|
||||
peak_update_count = max(len(updating), peak_update_count)
|
||||
await asyncio.sleep(0)
|
||||
updating.remove(self.entity_id)
|
||||
|
||||
entity1 = AsyncEntity()
|
||||
entity2 = AsyncEntity()
|
||||
entity3 = AsyncEntity()
|
||||
|
||||
await handle.async_add_entities([entity1, entity2, entity3])
|
||||
|
||||
assert entity1.parallel_updates is None
|
||||
assert entity2.parallel_updates is None
|
||||
assert entity3.parallel_updates is None
|
||||
|
||||
assert handle._update_in_sequence is False
|
||||
|
||||
await handle._update_entity_states(dt_util.utcnow())
|
||||
assert peak_update_count > 1
|
||||
|
||||
|
||||
async def test_parallel_updates_sync_platform_updates_in_sequence(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test a sync platform is updated in sequence."""
|
||||
platform = MockPlatform()
|
||||
|
||||
mock_entity_platform(hass, "test_domain.platform", platform)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component._platforms = {}
|
||||
|
||||
await component.async_setup({DOMAIN: {"platform": "platform"}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
updating = []
|
||||
peak_update_count = 0
|
||||
|
||||
class SyncEntity(MockEntity):
|
||||
"""Mock entity that has update."""
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
|
||||
nonlocal peak_update_count
|
||||
updating.append(self.entity_id)
|
||||
await asyncio.sleep(0)
|
||||
peak_update_count = max(len(updating), peak_update_count)
|
||||
await asyncio.sleep(0)
|
||||
updating.remove(self.entity_id)
|
||||
|
||||
entity1 = SyncEntity()
|
||||
entity2 = SyncEntity()
|
||||
entity3 = SyncEntity()
|
||||
|
||||
await handle.async_add_entities([entity1, entity2, entity3])
|
||||
assert entity1.parallel_updates is not None
|
||||
assert entity1.parallel_updates._value == 1
|
||||
assert entity2.parallel_updates is not None
|
||||
assert entity2.parallel_updates._value == 1
|
||||
assert entity3.parallel_updates is not None
|
||||
assert entity3.parallel_updates._value == 1
|
||||
|
||||
assert handle._update_in_sequence is True
|
||||
|
||||
await handle._update_entity_states(dt_util.utcnow())
|
||||
assert peak_update_count == 1
|
||||
|
||||
|
||||
async def test_raise_error_on_update(hass: HomeAssistant) -> None:
|
||||
"""Test the add entity if they raise an error on update."""
|
||||
updates = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user