From e85b089effbf584af3ca10eb118dd613f8509f27 Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Mon, 25 Mar 2019 23:53:36 -0700 Subject: [PATCH] Set default parallel_update value should base on async_update (#22149) * Set default parallel_update value should base on async_update * Set default parallel_update value should base on async_update * Delay the parallel_update_semaphore creation * Remove outdated comment --- homeassistant/helpers/entity_platform.py | 43 +++-- tests/helpers/test_entity.py | 216 +++++++++++++++-------- tests/helpers/test_entity_platform.py | 104 ++++++++--- 3 files changed, 242 insertions(+), 121 deletions(-) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 87cc4d4fd90..a092c89405e 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -27,7 +27,6 @@ class EntityPlatform: domain: str platform_name: str scan_interval: timedelta - parallel_updates: int entity_namespace: str async_entities_added_callback: @callback method """ @@ -52,22 +51,21 @@ class EntityPlatform: # which powers entity_component.add_entities if platform is None: self.parallel_updates = None + self.parallel_updates_semaphore = None return - # Async platforms do all updates in parallel by default - if hasattr(platform, 'async_setup_platform'): - default_parallel_updates = 0 - else: - default_parallel_updates = 1 + self.parallel_updates = getattr(platform, 'PARALLEL_UPDATES', None) + # semaphore will be created on demand + self.parallel_updates_semaphore = None - parallel_updates = getattr(platform, 'PARALLEL_UPDATES', - default_parallel_updates) - - if parallel_updates: - self.parallel_updates = asyncio.Semaphore( - parallel_updates, loop=hass.loop) - else: - self.parallel_updates = None + def _get_parallel_updates_semaphore(self): + """Get or create a semaphore for parallel updates.""" + if self.parallel_updates_semaphore is None: + self.parallel_updates_semaphore = asyncio.Semaphore( + self.parallel_updates if self.parallel_updates else 1, + loop=self.hass.loop + ) + return self.parallel_updates_semaphore async def async_setup(self, platform_config, discovery_info=None): """Set up the platform from a config file.""" @@ -240,7 +238,22 @@ class EntityPlatform: entity.hass = self.hass entity.platform = self - entity.parallel_updates = self.parallel_updates + + # Async entity + # PARALLEL_UPDATE == None: entity.parallel_updates = None + # PARALLEL_UPDATE == 0: entity.parallel_updates = None + # PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p) + # Sync entity + # PARALLEL_UPDATE == None: entity.parallel_updates = Semaphore(1) + # PARALLEL_UPDATE == 0: entity.parallel_updates = None + # PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p) + if hasattr(entity, 'async_update') and not self.parallel_updates: + entity.parallel_updates = None + elif (not hasattr(entity, 'async_update') + and self.parallel_updates == 0): + entity.parallel_updates = None + else: + entity.parallel_updates = self._get_parallel_updates_semaphore() # Update properties before we generate the entity_id if update_before_add: diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index d79f84d416d..383cd05a009 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -1,6 +1,7 @@ """Test the entity helper.""" # pylint: disable=protected-access import asyncio +import threading from datetime import timedelta from unittest.mock import MagicMock, patch, PropertyMock @@ -225,11 +226,10 @@ def test_async_schedule_update_ha_state(hass): assert update_call is True -@asyncio.coroutine -def test_async_parallel_updates_with_zero(hass): +async def test_async_parallel_updates_with_zero(hass): """Test parallel updates with 0 (disabled).""" updates = [] - test_lock = asyncio.Event(loop=hass.loop) + test_lock = asyncio.Event() class AsyncEntity(entity.Entity): @@ -239,37 +239,73 @@ def test_async_parallel_updates_with_zero(hass): self.hass = hass self._count = count - @asyncio.coroutine - def async_update(self): + async def async_update(self): """Test update.""" updates.append(self._count) - yield from test_lock.wait() + await test_lock.wait() ent_1 = AsyncEntity("sensor.test_1", 1) ent_2 = AsyncEntity("sensor.test_2", 2) - ent_1.async_schedule_update_ha_state(True) - ent_2.async_schedule_update_ha_state(True) + try: + ent_1.async_schedule_update_ha_state(True) + ent_2.async_schedule_update_ha_state(True) - while True: - if len(updates) == 2: - break - yield from asyncio.sleep(0, loop=hass.loop) + while True: + if len(updates) >= 2: + break + await asyncio.sleep(0) - assert len(updates) == 2 - assert updates == [1, 2] - - test_lock.set() + assert len(updates) == 2 + assert updates == [1, 2] + finally: + test_lock.set() -@asyncio.coroutine -def test_async_parallel_updates_with_one(hass): +async def test_async_parallel_updates_with_zero_on_sync_update(hass): + """Test parallel updates with 0 (disabled).""" + updates = [] + test_lock = threading.Event() + + class AsyncEntity(entity.Entity): + + def __init__(self, entity_id, count): + """Initialize Async test entity.""" + self.entity_id = entity_id + self.hass = hass + self._count = count + + def update(self): + """Test update.""" + updates.append(self._count) + if not test_lock.wait(timeout=1): + # if timeout populate more data to fail the test + updates.append(self._count) + + ent_1 = AsyncEntity("sensor.test_1", 1) + ent_2 = AsyncEntity("sensor.test_2", 2) + + try: + ent_1.async_schedule_update_ha_state(True) + ent_2.async_schedule_update_ha_state(True) + + while True: + if len(updates) >= 2: + break + await asyncio.sleep(0) + + assert len(updates) == 2 + assert updates == [1, 2] + finally: + test_lock.set() + await asyncio.sleep(0) + + +async def test_async_parallel_updates_with_one(hass): """Test parallel updates with 1 (sequential).""" updates = [] - test_lock = asyncio.Lock(loop=hass.loop) - test_semaphore = asyncio.Semaphore(1, loop=hass.loop) - - yield from test_lock.acquire() + test_lock = asyncio.Lock() + test_semaphore = asyncio.Semaphore(1) class AsyncEntity(entity.Entity): @@ -280,59 +316,71 @@ def test_async_parallel_updates_with_one(hass): self._count = count self.parallel_updates = test_semaphore - @asyncio.coroutine - def async_update(self): + async def async_update(self): """Test update.""" updates.append(self._count) - yield from test_lock.acquire() + await test_lock.acquire() ent_1 = AsyncEntity("sensor.test_1", 1) ent_2 = AsyncEntity("sensor.test_2", 2) ent_3 = AsyncEntity("sensor.test_3", 3) - ent_1.async_schedule_update_ha_state(True) - ent_2.async_schedule_update_ha_state(True) - ent_3.async_schedule_update_ha_state(True) + await test_lock.acquire() - while True: - if len(updates) == 1: - break - yield from asyncio.sleep(0, loop=hass.loop) + try: + ent_1.async_schedule_update_ha_state(True) + ent_2.async_schedule_update_ha_state(True) + ent_3.async_schedule_update_ha_state(True) - assert len(updates) == 1 - assert updates == [1] + while True: + if len(updates) >= 1: + break + await asyncio.sleep(0) - test_lock.release() + assert len(updates) == 1 + assert updates == [1] - while True: - if len(updates) == 2: - break - yield from asyncio.sleep(0, loop=hass.loop) + updates.clear() + test_lock.release() + await asyncio.sleep(0) - assert len(updates) == 2 - assert updates == [1, 2] + while True: + if len(updates) >= 1: + break + await asyncio.sleep(0) - test_lock.release() + assert len(updates) == 1 + assert updates == [2] - while True: - if len(updates) == 3: - break - yield from asyncio.sleep(0, loop=hass.loop) + updates.clear() + test_lock.release() + await asyncio.sleep(0) - assert len(updates) == 3 - assert updates == [1, 2, 3] + while True: + if len(updates) >= 1: + break + await asyncio.sleep(0) - test_lock.release() + assert len(updates) == 1 + assert updates == [3] + + updates.clear() + test_lock.release() + await asyncio.sleep(0) + + finally: + # we may have more than one lock need to release in case test failed + for _ in updates: + test_lock.release() + await asyncio.sleep(0) + test_lock.release() -@asyncio.coroutine -def test_async_parallel_updates_with_two(hass): +async def test_async_parallel_updates_with_two(hass): """Test parallel updates with 2 (parallel).""" updates = [] - test_lock = asyncio.Lock(loop=hass.loop) - test_semaphore = asyncio.Semaphore(2, loop=hass.loop) - - yield from test_lock.acquire() + test_lock = asyncio.Lock() + test_semaphore = asyncio.Semaphore(2) class AsyncEntity(entity.Entity): @@ -354,34 +402,48 @@ def test_async_parallel_updates_with_two(hass): ent_3 = AsyncEntity("sensor.test_3", 3) ent_4 = AsyncEntity("sensor.test_4", 4) - ent_1.async_schedule_update_ha_state(True) - ent_2.async_schedule_update_ha_state(True) - ent_3.async_schedule_update_ha_state(True) - ent_4.async_schedule_update_ha_state(True) + await test_lock.acquire() - while True: - if len(updates) == 2: - break - yield from asyncio.sleep(0, loop=hass.loop) + try: - assert len(updates) == 2 - assert updates == [1, 2] + ent_1.async_schedule_update_ha_state(True) + ent_2.async_schedule_update_ha_state(True) + ent_3.async_schedule_update_ha_state(True) + ent_4.async_schedule_update_ha_state(True) - test_lock.release() - yield from asyncio.sleep(0, loop=hass.loop) - test_lock.release() + while True: + if len(updates) >= 2: + break + await asyncio.sleep(0) - while True: - if len(updates) == 4: - break - yield from asyncio.sleep(0, loop=hass.loop) + assert len(updates) == 2 + assert updates == [1, 2] - assert len(updates) == 4 - assert updates == [1, 2, 3, 4] + updates.clear() + test_lock.release() + await asyncio.sleep(0) + test_lock.release() + await asyncio.sleep(0) - test_lock.release() - yield from asyncio.sleep(0, loop=hass.loop) - test_lock.release() + while True: + if len(updates) >= 2: + break + await asyncio.sleep(0) + + assert len(updates) == 2 + assert updates == [3, 4] + + updates.clear() + test_lock.release() + await asyncio.sleep(0) + test_lock.release() + await asyncio.sleep(0) + finally: + # we may have more than one lock need to release in case test failed + for _ in updates: + test_lock.release() + await asyncio.sleep(0) + test_lock.release() @asyncio.coroutine diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index e985771e486..6cf0bb0eeeb 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -251,80 +251,126 @@ def test_updated_state_used_for_entity_id(hass): assert entity_ids[0] == "test_domain.living_room" -@asyncio.coroutine -def test_parallel_updates_async_platform(hass): - """Warn we log when platform setup takes a long time.""" +async def test_parallel_updates_async_platform(hass): + """Test async platform does not have parallel_updates limit by default.""" platform = MockPlatform() - @asyncio.coroutine - def mock_update(*args, **kwargs): - pass - - platform.async_setup_platform = mock_update - loader.set_component(hass, 'test_domain.platform', platform) component = EntityComponent(_LOGGER, DOMAIN, hass) component._platforms = {} - yield from component.async_setup({ + await component.async_setup({ DOMAIN: { 'platform': 'platform', } }) handle = list(component._platforms.values())[-1] - assert handle.parallel_updates is None + class AsyncEntity(MockEntity): + """Mock entity that has async_update.""" -@asyncio.coroutine -def test_parallel_updates_async_platform_with_constant(hass): - """Warn we log when platform setup takes a long time.""" + async def async_update(self): + pass + + entity = AsyncEntity() + await handle.async_add_entities([entity]) + assert entity.parallel_updates is None + + +async def test_parallel_updates_async_platform_with_constant(hass): + """Test async platform can set parallel_updates limit.""" + platform = MockPlatform() + platform.PARALLEL_UPDATES = 2 + + loader.set_component(hass, 'test_domain.platform', platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + component._platforms = {} + + await component.async_setup({ + DOMAIN: { + 'platform': 'platform', + } + }) + + handle = list(component._platforms.values())[-1] + + assert handle.parallel_updates == 2 + + class AsyncEntity(MockEntity): + """Mock entity that has async_update.""" + + async def async_update(self): + pass + + entity = AsyncEntity() + await handle.async_add_entities([entity]) + assert entity.parallel_updates is not None + assert entity.parallel_updates._value == 2 + + +async def test_parallel_updates_sync_platform(hass): + """Test sync platform parallel_updates default set to 1.""" platform = MockPlatform() - @asyncio.coroutine - def mock_update(*args, **kwargs): - pass - - platform.async_setup_platform = mock_update - platform.PARALLEL_UPDATES = 1 - loader.set_component(hass, 'test_domain.platform', platform) component = EntityComponent(_LOGGER, DOMAIN, hass) component._platforms = {} - yield from component.async_setup({ + await component.async_setup({ DOMAIN: { 'platform': 'platform', } }) handle = list(component._platforms.values())[-1] + assert handle.parallel_updates is None - assert handle.parallel_updates is not None + class SyncEntity(MockEntity): + """Mock entity that has update.""" + + async def update(self): + pass + + entity = SyncEntity() + await handle.async_add_entities([entity]) + assert entity.parallel_updates is not None + assert entity.parallel_updates._value == 1 -@asyncio.coroutine -def test_parallel_updates_sync_platform(hass): - """Warn we log when platform setup takes a long time.""" - platform = MockPlatform(setup_platform=lambda *args: None) +async def test_parallel_updates_sync_platform_with_constant(hass): + """Test sync platform can set parallel_updates limit.""" + platform = MockPlatform() + platform.PARALLEL_UPDATES = 2 loader.set_component(hass, 'test_domain.platform', platform) component = EntityComponent(_LOGGER, DOMAIN, hass) component._platforms = {} - yield from component.async_setup({ + await component.async_setup({ DOMAIN: { 'platform': 'platform', } }) handle = list(component._platforms.values())[-1] + assert handle.parallel_updates == 2 - assert handle.parallel_updates is not None + class SyncEntity(MockEntity): + """Mock entity that has update.""" + + async def update(self): + pass + + entity = SyncEntity() + await handle.async_add_entities([entity]) + assert entity.parallel_updates is not None + assert entity.parallel_updates._value == 2 @asyncio.coroutine