mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Set default parallel_update value should base on async_update (#22149)
* Set default parallel_update value should base on async_update * Set default parallel_update value should base on async_update * Delay the parallel_update_semaphore creation * Remove outdated comment
This commit is contained in:
parent
a62c116959
commit
e85b089eff
@ -27,7 +27,6 @@ class EntityPlatform:
|
||||
domain: str
|
||||
platform_name: str
|
||||
scan_interval: timedelta
|
||||
parallel_updates: int
|
||||
entity_namespace: str
|
||||
async_entities_added_callback: @callback method
|
||||
"""
|
||||
@ -52,22 +51,21 @@ class EntityPlatform:
|
||||
# which powers entity_component.add_entities
|
||||
if platform is None:
|
||||
self.parallel_updates = None
|
||||
self.parallel_updates_semaphore = None
|
||||
return
|
||||
|
||||
# Async platforms do all updates in parallel by default
|
||||
if hasattr(platform, 'async_setup_platform'):
|
||||
default_parallel_updates = 0
|
||||
else:
|
||||
default_parallel_updates = 1
|
||||
self.parallel_updates = getattr(platform, 'PARALLEL_UPDATES', None)
|
||||
# semaphore will be created on demand
|
||||
self.parallel_updates_semaphore = None
|
||||
|
||||
parallel_updates = getattr(platform, 'PARALLEL_UPDATES',
|
||||
default_parallel_updates)
|
||||
|
||||
if parallel_updates:
|
||||
self.parallel_updates = asyncio.Semaphore(
|
||||
parallel_updates, loop=hass.loop)
|
||||
else:
|
||||
self.parallel_updates = None
|
||||
def _get_parallel_updates_semaphore(self):
|
||||
"""Get or create a semaphore for parallel updates."""
|
||||
if self.parallel_updates_semaphore is None:
|
||||
self.parallel_updates_semaphore = asyncio.Semaphore(
|
||||
self.parallel_updates if self.parallel_updates else 1,
|
||||
loop=self.hass.loop
|
||||
)
|
||||
return self.parallel_updates_semaphore
|
||||
|
||||
async def async_setup(self, platform_config, discovery_info=None):
|
||||
"""Set up the platform from a config file."""
|
||||
@ -240,7 +238,22 @@ class EntityPlatform:
|
||||
|
||||
entity.hass = self.hass
|
||||
entity.platform = self
|
||||
entity.parallel_updates = self.parallel_updates
|
||||
|
||||
# Async entity
|
||||
# PARALLEL_UPDATE == None: entity.parallel_updates = None
|
||||
# PARALLEL_UPDATE == 0: entity.parallel_updates = None
|
||||
# PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p)
|
||||
# Sync entity
|
||||
# PARALLEL_UPDATE == None: entity.parallel_updates = Semaphore(1)
|
||||
# PARALLEL_UPDATE == 0: entity.parallel_updates = None
|
||||
# PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p)
|
||||
if hasattr(entity, 'async_update') and not self.parallel_updates:
|
||||
entity.parallel_updates = None
|
||||
elif (not hasattr(entity, 'async_update')
|
||||
and self.parallel_updates == 0):
|
||||
entity.parallel_updates = None
|
||||
else:
|
||||
entity.parallel_updates = self._get_parallel_updates_semaphore()
|
||||
|
||||
# Update properties before we generate the entity_id
|
||||
if update_before_add:
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test the entity helper."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
@ -225,11 +226,10 @@ def test_async_schedule_update_ha_state(hass):
|
||||
assert update_call is True
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_async_parallel_updates_with_zero(hass):
|
||||
async def test_async_parallel_updates_with_zero(hass):
|
||||
"""Test parallel updates with 0 (disabled)."""
|
||||
updates = []
|
||||
test_lock = asyncio.Event(loop=hass.loop)
|
||||
test_lock = asyncio.Event()
|
||||
|
||||
class AsyncEntity(entity.Entity):
|
||||
|
||||
@ -239,37 +239,73 @@ def test_async_parallel_updates_with_zero(hass):
|
||||
self.hass = hass
|
||||
self._count = count
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update(self):
|
||||
async def async_update(self):
|
||||
"""Test update."""
|
||||
updates.append(self._count)
|
||||
yield from test_lock.wait()
|
||||
await test_lock.wait()
|
||||
|
||||
ent_1 = AsyncEntity("sensor.test_1", 1)
|
||||
ent_2 = AsyncEntity("sensor.test_2", 2)
|
||||
|
||||
ent_1.async_schedule_update_ha_state(True)
|
||||
ent_2.async_schedule_update_ha_state(True)
|
||||
try:
|
||||
ent_1.async_schedule_update_ha_state(True)
|
||||
ent_2.async_schedule_update_ha_state(True)
|
||||
|
||||
while True:
|
||||
if len(updates) == 2:
|
||||
break
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
while True:
|
||||
if len(updates) >= 2:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates == [1, 2]
|
||||
|
||||
test_lock.set()
|
||||
assert len(updates) == 2
|
||||
assert updates == [1, 2]
|
||||
finally:
|
||||
test_lock.set()
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_async_parallel_updates_with_one(hass):
|
||||
async def test_async_parallel_updates_with_zero_on_sync_update(hass):
|
||||
"""Test parallel updates with 0 (disabled)."""
|
||||
updates = []
|
||||
test_lock = threading.Event()
|
||||
|
||||
class AsyncEntity(entity.Entity):
|
||||
|
||||
def __init__(self, entity_id, count):
|
||||
"""Initialize Async test entity."""
|
||||
self.entity_id = entity_id
|
||||
self.hass = hass
|
||||
self._count = count
|
||||
|
||||
def update(self):
|
||||
"""Test update."""
|
||||
updates.append(self._count)
|
||||
if not test_lock.wait(timeout=1):
|
||||
# if timeout populate more data to fail the test
|
||||
updates.append(self._count)
|
||||
|
||||
ent_1 = AsyncEntity("sensor.test_1", 1)
|
||||
ent_2 = AsyncEntity("sensor.test_2", 2)
|
||||
|
||||
try:
|
||||
ent_1.async_schedule_update_ha_state(True)
|
||||
ent_2.async_schedule_update_ha_state(True)
|
||||
|
||||
while True:
|
||||
if len(updates) >= 2:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates == [1, 2]
|
||||
finally:
|
||||
test_lock.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def test_async_parallel_updates_with_one(hass):
|
||||
"""Test parallel updates with 1 (sequential)."""
|
||||
updates = []
|
||||
test_lock = asyncio.Lock(loop=hass.loop)
|
||||
test_semaphore = asyncio.Semaphore(1, loop=hass.loop)
|
||||
|
||||
yield from test_lock.acquire()
|
||||
test_lock = asyncio.Lock()
|
||||
test_semaphore = asyncio.Semaphore(1)
|
||||
|
||||
class AsyncEntity(entity.Entity):
|
||||
|
||||
@ -280,59 +316,71 @@ def test_async_parallel_updates_with_one(hass):
|
||||
self._count = count
|
||||
self.parallel_updates = test_semaphore
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update(self):
|
||||
async def async_update(self):
|
||||
"""Test update."""
|
||||
updates.append(self._count)
|
||||
yield from test_lock.acquire()
|
||||
await test_lock.acquire()
|
||||
|
||||
ent_1 = AsyncEntity("sensor.test_1", 1)
|
||||
ent_2 = AsyncEntity("sensor.test_2", 2)
|
||||
ent_3 = AsyncEntity("sensor.test_3", 3)
|
||||
|
||||
ent_1.async_schedule_update_ha_state(True)
|
||||
ent_2.async_schedule_update_ha_state(True)
|
||||
ent_3.async_schedule_update_ha_state(True)
|
||||
await test_lock.acquire()
|
||||
|
||||
while True:
|
||||
if len(updates) == 1:
|
||||
break
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
try:
|
||||
ent_1.async_schedule_update_ha_state(True)
|
||||
ent_2.async_schedule_update_ha_state(True)
|
||||
ent_3.async_schedule_update_ha_state(True)
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates == [1]
|
||||
while True:
|
||||
if len(updates) >= 1:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
test_lock.release()
|
||||
assert len(updates) == 1
|
||||
assert updates == [1]
|
||||
|
||||
while True:
|
||||
if len(updates) == 2:
|
||||
break
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
updates.clear()
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates == [1, 2]
|
||||
while True:
|
||||
if len(updates) >= 1:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
test_lock.release()
|
||||
assert len(updates) == 1
|
||||
assert updates == [2]
|
||||
|
||||
while True:
|
||||
if len(updates) == 3:
|
||||
break
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
updates.clear()
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(updates) == 3
|
||||
assert updates == [1, 2, 3]
|
||||
while True:
|
||||
if len(updates) >= 1:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
test_lock.release()
|
||||
assert len(updates) == 1
|
||||
assert updates == [3]
|
||||
|
||||
updates.clear()
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
finally:
|
||||
# we may have more than one lock need to release in case test failed
|
||||
for _ in updates:
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
test_lock.release()
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_async_parallel_updates_with_two(hass):
|
||||
async def test_async_parallel_updates_with_two(hass):
|
||||
"""Test parallel updates with 2 (parallel)."""
|
||||
updates = []
|
||||
test_lock = asyncio.Lock(loop=hass.loop)
|
||||
test_semaphore = asyncio.Semaphore(2, loop=hass.loop)
|
||||
|
||||
yield from test_lock.acquire()
|
||||
test_lock = asyncio.Lock()
|
||||
test_semaphore = asyncio.Semaphore(2)
|
||||
|
||||
class AsyncEntity(entity.Entity):
|
||||
|
||||
@ -354,34 +402,48 @@ def test_async_parallel_updates_with_two(hass):
|
||||
ent_3 = AsyncEntity("sensor.test_3", 3)
|
||||
ent_4 = AsyncEntity("sensor.test_4", 4)
|
||||
|
||||
ent_1.async_schedule_update_ha_state(True)
|
||||
ent_2.async_schedule_update_ha_state(True)
|
||||
ent_3.async_schedule_update_ha_state(True)
|
||||
ent_4.async_schedule_update_ha_state(True)
|
||||
await test_lock.acquire()
|
||||
|
||||
while True:
|
||||
if len(updates) == 2:
|
||||
break
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
try:
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates == [1, 2]
|
||||
ent_1.async_schedule_update_ha_state(True)
|
||||
ent_2.async_schedule_update_ha_state(True)
|
||||
ent_3.async_schedule_update_ha_state(True)
|
||||
ent_4.async_schedule_update_ha_state(True)
|
||||
|
||||
test_lock.release()
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
test_lock.release()
|
||||
while True:
|
||||
if len(updates) >= 2:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
while True:
|
||||
if len(updates) == 4:
|
||||
break
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
assert len(updates) == 2
|
||||
assert updates == [1, 2]
|
||||
|
||||
assert len(updates) == 4
|
||||
assert updates == [1, 2, 3, 4]
|
||||
updates.clear()
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
test_lock.release()
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
test_lock.release()
|
||||
while True:
|
||||
if len(updates) >= 2:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates == [3, 4]
|
||||
|
||||
updates.clear()
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
finally:
|
||||
# we may have more than one lock need to release in case test failed
|
||||
for _ in updates:
|
||||
test_lock.release()
|
||||
await asyncio.sleep(0)
|
||||
test_lock.release()
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -251,80 +251,126 @@ def test_updated_state_used_for_entity_id(hass):
|
||||
assert entity_ids[0] == "test_domain.living_room"
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_parallel_updates_async_platform(hass):
|
||||
"""Warn we log when platform setup takes a long time."""
|
||||
async def test_parallel_updates_async_platform(hass):
|
||||
"""Test async platform does not have parallel_updates limit by default."""
|
||||
platform = MockPlatform()
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_update(*args, **kwargs):
|
||||
pass
|
||||
|
||||
platform.async_setup_platform = mock_update
|
||||
|
||||
loader.set_component(hass, 'test_domain.platform', platform)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component._platforms = {}
|
||||
|
||||
yield from component.async_setup({
|
||||
await component.async_setup({
|
||||
DOMAIN: {
|
||||
'platform': 'platform',
|
||||
}
|
||||
})
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
|
||||
assert handle.parallel_updates is None
|
||||
|
||||
class AsyncEntity(MockEntity):
|
||||
"""Mock entity that has async_update."""
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_parallel_updates_async_platform_with_constant(hass):
|
||||
"""Warn we log when platform setup takes a long time."""
|
||||
async def async_update(self):
|
||||
pass
|
||||
|
||||
entity = AsyncEntity()
|
||||
await handle.async_add_entities([entity])
|
||||
assert entity.parallel_updates is None
|
||||
|
||||
|
||||
async def test_parallel_updates_async_platform_with_constant(hass):
|
||||
"""Test async platform can set parallel_updates limit."""
|
||||
platform = MockPlatform()
|
||||
platform.PARALLEL_UPDATES = 2
|
||||
|
||||
loader.set_component(hass, 'test_domain.platform', platform)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component._platforms = {}
|
||||
|
||||
await component.async_setup({
|
||||
DOMAIN: {
|
||||
'platform': 'platform',
|
||||
}
|
||||
})
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
|
||||
assert handle.parallel_updates == 2
|
||||
|
||||
class AsyncEntity(MockEntity):
|
||||
"""Mock entity that has async_update."""
|
||||
|
||||
async def async_update(self):
|
||||
pass
|
||||
|
||||
entity = AsyncEntity()
|
||||
await handle.async_add_entities([entity])
|
||||
assert entity.parallel_updates is not None
|
||||
assert entity.parallel_updates._value == 2
|
||||
|
||||
|
||||
async def test_parallel_updates_sync_platform(hass):
|
||||
"""Test sync platform parallel_updates default set to 1."""
|
||||
platform = MockPlatform()
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_update(*args, **kwargs):
|
||||
pass
|
||||
|
||||
platform.async_setup_platform = mock_update
|
||||
platform.PARALLEL_UPDATES = 1
|
||||
|
||||
loader.set_component(hass, 'test_domain.platform', platform)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component._platforms = {}
|
||||
|
||||
yield from component.async_setup({
|
||||
await component.async_setup({
|
||||
DOMAIN: {
|
||||
'platform': 'platform',
|
||||
}
|
||||
})
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
assert handle.parallel_updates is None
|
||||
|
||||
assert handle.parallel_updates is not None
|
||||
class SyncEntity(MockEntity):
|
||||
"""Mock entity that has update."""
|
||||
|
||||
async def update(self):
|
||||
pass
|
||||
|
||||
entity = SyncEntity()
|
||||
await handle.async_add_entities([entity])
|
||||
assert entity.parallel_updates is not None
|
||||
assert entity.parallel_updates._value == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_parallel_updates_sync_platform(hass):
|
||||
"""Warn we log when platform setup takes a long time."""
|
||||
platform = MockPlatform(setup_platform=lambda *args: None)
|
||||
async def test_parallel_updates_sync_platform_with_constant(hass):
|
||||
"""Test sync platform can set parallel_updates limit."""
|
||||
platform = MockPlatform()
|
||||
platform.PARALLEL_UPDATES = 2
|
||||
|
||||
loader.set_component(hass, 'test_domain.platform', platform)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component._platforms = {}
|
||||
|
||||
yield from component.async_setup({
|
||||
await component.async_setup({
|
||||
DOMAIN: {
|
||||
'platform': 'platform',
|
||||
}
|
||||
})
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
assert handle.parallel_updates == 2
|
||||
|
||||
assert handle.parallel_updates is not None
|
||||
class SyncEntity(MockEntity):
|
||||
"""Mock entity that has update."""
|
||||
|
||||
async def update(self):
|
||||
pass
|
||||
|
||||
entity = SyncEntity()
|
||||
await handle.async_add_entities([entity])
|
||||
assert entity.parallel_updates is not None
|
||||
assert entity.parallel_updates._value == 2
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
Loading…
x
Reference in New Issue
Block a user