diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index be335c7a40f..52d436ca997 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -136,7 +136,7 @@ class EntityPlatform: self._process_updates: asyncio.Lock | None = None self.parallel_updates: asyncio.Semaphore | None = None - self._update_in_parallel: bool = True + self._update_in_sequence: bool = False # Platform is None for the EntityComponent "catch-all" EntityPlatform # which powers entity_component.add_entities @@ -187,7 +187,7 @@ class EntityPlatform: if parallel_updates is not None: self.parallel_updates = asyncio.Semaphore(parallel_updates) - self._update_in_parallel = parallel_updates != 1 + self._update_in_sequence = parallel_updates == 1 return self.parallel_updates @@ -846,11 +846,13 @@ class EntityPlatform: return async with self._process_updates: - if self._update_in_parallel or len(self.entities) <= 1: - # If we know are going to update sequentially, we want to update - # to avoid scheduling the coroutines as tasks that will we know - # are going to wait on the semaphore lock. + if self._update_in_sequence or len(self.entities) <= 1: + # If we know we will update sequentially, we want to avoid scheduling + # the coroutines as tasks that will wait on the semaphore lock. for entity in list(self.entities.values()): + # If the entity is removed from hass during the previous + # entity being updated, we need to skip updating the + # entity. if entity.should_poll and entity.hass: await entity.async_update_ha_state(True) return diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index e6b864be09c..346a124424e 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -2,6 +2,7 @@ import asyncio from datetime import timedelta import logging +from typing import Any from unittest.mock import ANY, Mock, patch import pytest @@ -307,6 +308,7 @@ async def test_parallel_updates_async_platform(hass: HomeAssistant) -> None: entity = AsyncEntity() await handle.async_add_entities([entity]) assert entity.parallel_updates is None + assert handle._update_in_sequence is False async def test_parallel_updates_async_platform_with_constant( @@ -336,6 +338,7 @@ async def test_parallel_updates_async_platform_with_constant( await handle.async_add_entities([entity]) assert entity.parallel_updates is not None assert entity.parallel_updates._value == 2 + assert handle._update_in_sequence is False async def test_parallel_updates_sync_platform(hass: HomeAssistant) -> None: @@ -412,6 +415,104 @@ async def test_parallel_updates_sync_platform_with_constant( assert entity.parallel_updates._value == 2 +async def test_parallel_updates_async_platform_updates_in_parallel( + hass: HomeAssistant, +) -> None: + """Test an async platform is updated in parallel.""" + platform = MockPlatform() + + mock_entity_platform(hass, "test_domain.async_platform", platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + component._platforms = {} + + await component.async_setup({DOMAIN: {"platform": "async_platform"}}) + await hass.async_block_till_done() + + handle = list(component._platforms.values())[-1] + updating = [] + peak_update_count = 0 + + class AsyncEntity(MockEntity): + """Mock entity that has async_update.""" + + async def async_update(self): + pass + + async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None: + nonlocal peak_update_count + updating.append(self.entity_id) + await asyncio.sleep(0) + peak_update_count = max(len(updating), peak_update_count) + await asyncio.sleep(0) + updating.remove(self.entity_id) + + entity1 = AsyncEntity() + entity2 = AsyncEntity() + entity3 = AsyncEntity() + + await handle.async_add_entities([entity1, entity2, entity3]) + + assert entity1.parallel_updates is None + assert entity2.parallel_updates is None + assert entity3.parallel_updates is None + + assert handle._update_in_sequence is False + + await handle._update_entity_states(dt_util.utcnow()) + assert peak_update_count > 1 + + +async def test_parallel_updates_sync_platform_updates_in_sequence( + hass: HomeAssistant, +) -> None: + """Test a sync platform is updated in sequence.""" + platform = MockPlatform() + + mock_entity_platform(hass, "test_domain.platform", platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + component._platforms = {} + + await component.async_setup({DOMAIN: {"platform": "platform"}}) + await hass.async_block_till_done() + + handle = list(component._platforms.values())[-1] + updating = [] + peak_update_count = 0 + + class SyncEntity(MockEntity): + """Mock entity that has update.""" + + def update(self): + pass + + async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None: + nonlocal peak_update_count + updating.append(self.entity_id) + await asyncio.sleep(0) + peak_update_count = max(len(updating), peak_update_count) + await asyncio.sleep(0) + updating.remove(self.entity_id) + + entity1 = SyncEntity() + entity2 = SyncEntity() + entity3 = SyncEntity() + + await handle.async_add_entities([entity1, entity2, entity3]) + assert entity1.parallel_updates is not None + assert entity1.parallel_updates._value == 1 + assert entity2.parallel_updates is not None + assert entity2.parallel_updates._value == 1 + assert entity3.parallel_updates is not None + assert entity3.parallel_updates._value == 1 + + assert handle._update_in_sequence is True + + await handle._update_entity_states(dt_util.utcnow()) + assert peak_update_count == 1 + + async def test_raise_error_on_update(hass: HomeAssistant) -> None: """Test the add entity if they raise an error on update.""" updates = []