Fix refactoring error with updating polling entities in sequence (#93693)

* Fix refactoring error with updating in sequence

see #93649

* coverage

* make sure entities are being updated in parallel

* make sure entities are being updated in sequence
This commit is contained in:
J. Nick Koston 2023-05-28 09:20:48 -05:00 committed by GitHub
parent 49c3a8886f
commit 083cf7a38b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 6 deletions

View File

@ -136,7 +136,7 @@ class EntityPlatform:
self._process_updates: asyncio.Lock | None = None
self.parallel_updates: asyncio.Semaphore | None = None
self._update_in_parallel: bool = True
self._update_in_sequence: bool = False
# Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities
@ -187,7 +187,7 @@ class EntityPlatform:
if parallel_updates is not None:
self.parallel_updates = asyncio.Semaphore(parallel_updates)
self._update_in_parallel = parallel_updates != 1
self._update_in_sequence = parallel_updates == 1
return self.parallel_updates
@ -846,11 +846,13 @@ class EntityPlatform:
return
async with self._process_updates:
if self._update_in_parallel or len(self.entities) <= 1:
# If we know are going to update sequentially, we want to update
# to avoid scheduling the coroutines as tasks that will we know
# are going to wait on the semaphore lock.
if self._update_in_sequence or len(self.entities) <= 1:
# If we know we will update sequentially, we want to avoid scheduling
# the coroutines as tasks that will wait on the semaphore lock.
for entity in list(self.entities.values()):
# If the entity is removed from hass during the previous
# entity being updated, we need to skip updating the
# entity.
if entity.should_poll and entity.hass:
await entity.async_update_ha_state(True)
return

View File

@ -2,6 +2,7 @@
import asyncio
from datetime import timedelta
import logging
from typing import Any
from unittest.mock import ANY, Mock, patch
import pytest
@ -307,6 +308,7 @@ async def test_parallel_updates_async_platform(hass: HomeAssistant) -> None:
entity = AsyncEntity()
await handle.async_add_entities([entity])
assert entity.parallel_updates is None
assert handle._update_in_sequence is False
async def test_parallel_updates_async_platform_with_constant(
@ -336,6 +338,7 @@ async def test_parallel_updates_async_platform_with_constant(
await handle.async_add_entities([entity])
assert entity.parallel_updates is not None
assert entity.parallel_updates._value == 2
assert handle._update_in_sequence is False
async def test_parallel_updates_sync_platform(hass: HomeAssistant) -> None:
@ -412,6 +415,104 @@ async def test_parallel_updates_sync_platform_with_constant(
assert entity.parallel_updates._value == 2
async def test_parallel_updates_async_platform_updates_in_parallel(
hass: HomeAssistant,
) -> None:
"""Test an async platform is updated in parallel."""
platform = MockPlatform()
mock_entity_platform(hass, "test_domain.async_platform", platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
await component.async_setup({DOMAIN: {"platform": "async_platform"}})
await hass.async_block_till_done()
handle = list(component._platforms.values())[-1]
updating = []
peak_update_count = 0
class AsyncEntity(MockEntity):
"""Mock entity that has async_update."""
async def async_update(self):
pass
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
nonlocal peak_update_count
updating.append(self.entity_id)
await asyncio.sleep(0)
peak_update_count = max(len(updating), peak_update_count)
await asyncio.sleep(0)
updating.remove(self.entity_id)
entity1 = AsyncEntity()
entity2 = AsyncEntity()
entity3 = AsyncEntity()
await handle.async_add_entities([entity1, entity2, entity3])
assert entity1.parallel_updates is None
assert entity2.parallel_updates is None
assert entity3.parallel_updates is None
assert handle._update_in_sequence is False
await handle._update_entity_states(dt_util.utcnow())
assert peak_update_count > 1
async def test_parallel_updates_sync_platform_updates_in_sequence(
hass: HomeAssistant,
) -> None:
"""Test a sync platform is updated in sequence."""
platform = MockPlatform()
mock_entity_platform(hass, "test_domain.platform", platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
await component.async_setup({DOMAIN: {"platform": "platform"}})
await hass.async_block_till_done()
handle = list(component._platforms.values())[-1]
updating = []
peak_update_count = 0
class SyncEntity(MockEntity):
"""Mock entity that has update."""
def update(self):
pass
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
nonlocal peak_update_count
updating.append(self.entity_id)
await asyncio.sleep(0)
peak_update_count = max(len(updating), peak_update_count)
await asyncio.sleep(0)
updating.remove(self.entity_id)
entity1 = SyncEntity()
entity2 = SyncEntity()
entity3 = SyncEntity()
await handle.async_add_entities([entity1, entity2, entity3])
assert entity1.parallel_updates is not None
assert entity1.parallel_updates._value == 1
assert entity2.parallel_updates is not None
assert entity2.parallel_updates._value == 1
assert entity3.parallel_updates is not None
assert entity3.parallel_updates._value == 1
assert handle._update_in_sequence is True
await handle._update_entity_states(dt_util.utcnow())
assert peak_update_count == 1
async def test_raise_error_on_update(hass: HomeAssistant) -> None:
"""Test the add entity if they raise an error on update."""
updates = []