mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 10:17:51 +00:00
Set default parallel_update value should base on async_update (#22149)
* Set default parallel_update value should base on async_update * Set default parallel_update value should base on async_update * Delay the parallel_update_semaphore creation * Remove outdated comment
This commit is contained in:
parent
a62c116959
commit
e85b089eff
@ -27,7 +27,6 @@ class EntityPlatform:
|
|||||||
domain: str
|
domain: str
|
||||||
platform_name: str
|
platform_name: str
|
||||||
scan_interval: timedelta
|
scan_interval: timedelta
|
||||||
parallel_updates: int
|
|
||||||
entity_namespace: str
|
entity_namespace: str
|
||||||
async_entities_added_callback: @callback method
|
async_entities_added_callback: @callback method
|
||||||
"""
|
"""
|
||||||
@ -52,22 +51,21 @@ class EntityPlatform:
|
|||||||
# which powers entity_component.add_entities
|
# which powers entity_component.add_entities
|
||||||
if platform is None:
|
if platform is None:
|
||||||
self.parallel_updates = None
|
self.parallel_updates = None
|
||||||
|
self.parallel_updates_semaphore = None
|
||||||
return
|
return
|
||||||
|
|
||||||
# Async platforms do all updates in parallel by default
|
self.parallel_updates = getattr(platform, 'PARALLEL_UPDATES', None)
|
||||||
if hasattr(platform, 'async_setup_platform'):
|
# semaphore will be created on demand
|
||||||
default_parallel_updates = 0
|
self.parallel_updates_semaphore = None
|
||||||
else:
|
|
||||||
default_parallel_updates = 1
|
|
||||||
|
|
||||||
parallel_updates = getattr(platform, 'PARALLEL_UPDATES',
|
def _get_parallel_updates_semaphore(self):
|
||||||
default_parallel_updates)
|
"""Get or create a semaphore for parallel updates."""
|
||||||
|
if self.parallel_updates_semaphore is None:
|
||||||
if parallel_updates:
|
self.parallel_updates_semaphore = asyncio.Semaphore(
|
||||||
self.parallel_updates = asyncio.Semaphore(
|
self.parallel_updates if self.parallel_updates else 1,
|
||||||
parallel_updates, loop=hass.loop)
|
loop=self.hass.loop
|
||||||
else:
|
)
|
||||||
self.parallel_updates = None
|
return self.parallel_updates_semaphore
|
||||||
|
|
||||||
async def async_setup(self, platform_config, discovery_info=None):
|
async def async_setup(self, platform_config, discovery_info=None):
|
||||||
"""Set up the platform from a config file."""
|
"""Set up the platform from a config file."""
|
||||||
@ -240,7 +238,22 @@ class EntityPlatform:
|
|||||||
|
|
||||||
entity.hass = self.hass
|
entity.hass = self.hass
|
||||||
entity.platform = self
|
entity.platform = self
|
||||||
entity.parallel_updates = self.parallel_updates
|
|
||||||
|
# Async entity
|
||||||
|
# PARALLEL_UPDATE == None: entity.parallel_updates = None
|
||||||
|
# PARALLEL_UPDATE == 0: entity.parallel_updates = None
|
||||||
|
# PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p)
|
||||||
|
# Sync entity
|
||||||
|
# PARALLEL_UPDATE == None: entity.parallel_updates = Semaphore(1)
|
||||||
|
# PARALLEL_UPDATE == 0: entity.parallel_updates = None
|
||||||
|
# PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p)
|
||||||
|
if hasattr(entity, 'async_update') and not self.parallel_updates:
|
||||||
|
entity.parallel_updates = None
|
||||||
|
elif (not hasattr(entity, 'async_update')
|
||||||
|
and self.parallel_updates == 0):
|
||||||
|
entity.parallel_updates = None
|
||||||
|
else:
|
||||||
|
entity.parallel_updates = self._get_parallel_updates_semaphore()
|
||||||
|
|
||||||
# Update properties before we generate the entity_id
|
# Update properties before we generate the entity_id
|
||||||
if update_before_add:
|
if update_before_add:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Test the entity helper."""
|
"""Test the entity helper."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import threading
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from unittest.mock import MagicMock, patch, PropertyMock
|
from unittest.mock import MagicMock, patch, PropertyMock
|
||||||
|
|
||||||
@ -225,11 +226,10 @@ def test_async_schedule_update_ha_state(hass):
|
|||||||
assert update_call is True
|
assert update_call is True
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_async_parallel_updates_with_zero(hass):
|
||||||
def test_async_parallel_updates_with_zero(hass):
|
|
||||||
"""Test parallel updates with 0 (disabled)."""
|
"""Test parallel updates with 0 (disabled)."""
|
||||||
updates = []
|
updates = []
|
||||||
test_lock = asyncio.Event(loop=hass.loop)
|
test_lock = asyncio.Event()
|
||||||
|
|
||||||
class AsyncEntity(entity.Entity):
|
class AsyncEntity(entity.Entity):
|
||||||
|
|
||||||
@ -239,37 +239,73 @@ def test_async_parallel_updates_with_zero(hass):
|
|||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._count = count
|
self._count = count
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_update(self):
|
||||||
def async_update(self):
|
|
||||||
"""Test update."""
|
"""Test update."""
|
||||||
updates.append(self._count)
|
updates.append(self._count)
|
||||||
yield from test_lock.wait()
|
await test_lock.wait()
|
||||||
|
|
||||||
ent_1 = AsyncEntity("sensor.test_1", 1)
|
ent_1 = AsyncEntity("sensor.test_1", 1)
|
||||||
ent_2 = AsyncEntity("sensor.test_2", 2)
|
ent_2 = AsyncEntity("sensor.test_2", 2)
|
||||||
|
|
||||||
|
try:
|
||||||
ent_1.async_schedule_update_ha_state(True)
|
ent_1.async_schedule_update_ha_state(True)
|
||||||
ent_2.async_schedule_update_ha_state(True)
|
ent_2.async_schedule_update_ha_state(True)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if len(updates) == 2:
|
if len(updates) >= 2:
|
||||||
break
|
break
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert len(updates) == 2
|
assert len(updates) == 2
|
||||||
assert updates == [1, 2]
|
assert updates == [1, 2]
|
||||||
|
finally:
|
||||||
test_lock.set()
|
test_lock.set()
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_async_parallel_updates_with_zero_on_sync_update(hass):
|
||||||
def test_async_parallel_updates_with_one(hass):
|
"""Test parallel updates with 0 (disabled)."""
|
||||||
|
updates = []
|
||||||
|
test_lock = threading.Event()
|
||||||
|
|
||||||
|
class AsyncEntity(entity.Entity):
|
||||||
|
|
||||||
|
def __init__(self, entity_id, count):
|
||||||
|
"""Initialize Async test entity."""
|
||||||
|
self.entity_id = entity_id
|
||||||
|
self.hass = hass
|
||||||
|
self._count = count
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""Test update."""
|
||||||
|
updates.append(self._count)
|
||||||
|
if not test_lock.wait(timeout=1):
|
||||||
|
# if timeout populate more data to fail the test
|
||||||
|
updates.append(self._count)
|
||||||
|
|
||||||
|
ent_1 = AsyncEntity("sensor.test_1", 1)
|
||||||
|
ent_2 = AsyncEntity("sensor.test_2", 2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ent_1.async_schedule_update_ha_state(True)
|
||||||
|
ent_2.async_schedule_update_ha_state(True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if len(updates) >= 2:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(updates) == 2
|
||||||
|
assert updates == [1, 2]
|
||||||
|
finally:
|
||||||
|
test_lock.set()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_parallel_updates_with_one(hass):
|
||||||
"""Test parallel updates with 1 (sequential)."""
|
"""Test parallel updates with 1 (sequential)."""
|
||||||
updates = []
|
updates = []
|
||||||
test_lock = asyncio.Lock(loop=hass.loop)
|
test_lock = asyncio.Lock()
|
||||||
test_semaphore = asyncio.Semaphore(1, loop=hass.loop)
|
test_semaphore = asyncio.Semaphore(1)
|
||||||
|
|
||||||
yield from test_lock.acquire()
|
|
||||||
|
|
||||||
class AsyncEntity(entity.Entity):
|
class AsyncEntity(entity.Entity):
|
||||||
|
|
||||||
@ -280,59 +316,71 @@ def test_async_parallel_updates_with_one(hass):
|
|||||||
self._count = count
|
self._count = count
|
||||||
self.parallel_updates = test_semaphore
|
self.parallel_updates = test_semaphore
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_update(self):
|
||||||
def async_update(self):
|
|
||||||
"""Test update."""
|
"""Test update."""
|
||||||
updates.append(self._count)
|
updates.append(self._count)
|
||||||
yield from test_lock.acquire()
|
await test_lock.acquire()
|
||||||
|
|
||||||
ent_1 = AsyncEntity("sensor.test_1", 1)
|
ent_1 = AsyncEntity("sensor.test_1", 1)
|
||||||
ent_2 = AsyncEntity("sensor.test_2", 2)
|
ent_2 = AsyncEntity("sensor.test_2", 2)
|
||||||
ent_3 = AsyncEntity("sensor.test_3", 3)
|
ent_3 = AsyncEntity("sensor.test_3", 3)
|
||||||
|
|
||||||
|
await test_lock.acquire()
|
||||||
|
|
||||||
|
try:
|
||||||
ent_1.async_schedule_update_ha_state(True)
|
ent_1.async_schedule_update_ha_state(True)
|
||||||
ent_2.async_schedule_update_ha_state(True)
|
ent_2.async_schedule_update_ha_state(True)
|
||||||
ent_3.async_schedule_update_ha_state(True)
|
ent_3.async_schedule_update_ha_state(True)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if len(updates) == 1:
|
if len(updates) >= 1:
|
||||||
break
|
break
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert len(updates) == 1
|
assert len(updates) == 1
|
||||||
assert updates == [1]
|
assert updates == [1]
|
||||||
|
|
||||||
|
updates.clear()
|
||||||
test_lock.release()
|
test_lock.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if len(updates) == 2:
|
if len(updates) >= 1:
|
||||||
break
|
break
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert len(updates) == 2
|
assert len(updates) == 1
|
||||||
assert updates == [1, 2]
|
assert updates == [2]
|
||||||
|
|
||||||
|
updates.clear()
|
||||||
test_lock.release()
|
test_lock.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if len(updates) == 3:
|
if len(updates) >= 1:
|
||||||
break
|
break
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert len(updates) == 3
|
assert len(updates) == 1
|
||||||
assert updates == [1, 2, 3]
|
assert updates == [3]
|
||||||
|
|
||||||
|
updates.clear()
|
||||||
|
test_lock.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# we may have more than one lock need to release in case test failed
|
||||||
|
for _ in updates:
|
||||||
|
test_lock.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
test_lock.release()
|
test_lock.release()
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_async_parallel_updates_with_two(hass):
|
||||||
def test_async_parallel_updates_with_two(hass):
|
|
||||||
"""Test parallel updates with 2 (parallel)."""
|
"""Test parallel updates with 2 (parallel)."""
|
||||||
updates = []
|
updates = []
|
||||||
test_lock = asyncio.Lock(loop=hass.loop)
|
test_lock = asyncio.Lock()
|
||||||
test_semaphore = asyncio.Semaphore(2, loop=hass.loop)
|
test_semaphore = asyncio.Semaphore(2)
|
||||||
|
|
||||||
yield from test_lock.acquire()
|
|
||||||
|
|
||||||
class AsyncEntity(entity.Entity):
|
class AsyncEntity(entity.Entity):
|
||||||
|
|
||||||
@ -354,33 +402,47 @@ def test_async_parallel_updates_with_two(hass):
|
|||||||
ent_3 = AsyncEntity("sensor.test_3", 3)
|
ent_3 = AsyncEntity("sensor.test_3", 3)
|
||||||
ent_4 = AsyncEntity("sensor.test_4", 4)
|
ent_4 = AsyncEntity("sensor.test_4", 4)
|
||||||
|
|
||||||
|
await test_lock.acquire()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
ent_1.async_schedule_update_ha_state(True)
|
ent_1.async_schedule_update_ha_state(True)
|
||||||
ent_2.async_schedule_update_ha_state(True)
|
ent_2.async_schedule_update_ha_state(True)
|
||||||
ent_3.async_schedule_update_ha_state(True)
|
ent_3.async_schedule_update_ha_state(True)
|
||||||
ent_4.async_schedule_update_ha_state(True)
|
ent_4.async_schedule_update_ha_state(True)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if len(updates) == 2:
|
if len(updates) >= 2:
|
||||||
break
|
break
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert len(updates) == 2
|
assert len(updates) == 2
|
||||||
assert updates == [1, 2]
|
assert updates == [1, 2]
|
||||||
|
|
||||||
|
updates.clear()
|
||||||
test_lock.release()
|
test_lock.release()
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
test_lock.release()
|
test_lock.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if len(updates) == 4:
|
if len(updates) >= 2:
|
||||||
break
|
break
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert len(updates) == 4
|
assert len(updates) == 2
|
||||||
assert updates == [1, 2, 3, 4]
|
assert updates == [3, 4]
|
||||||
|
|
||||||
|
updates.clear()
|
||||||
test_lock.release()
|
test_lock.release()
|
||||||
yield from asyncio.sleep(0, loop=hass.loop)
|
await asyncio.sleep(0)
|
||||||
|
test_lock.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
finally:
|
||||||
|
# we may have more than one lock need to release in case test failed
|
||||||
|
for _ in updates:
|
||||||
|
test_lock.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
test_lock.release()
|
test_lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
@ -251,51 +251,46 @@ def test_updated_state_used_for_entity_id(hass):
|
|||||||
assert entity_ids[0] == "test_domain.living_room"
|
assert entity_ids[0] == "test_domain.living_room"
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_parallel_updates_async_platform(hass):
|
||||||
def test_parallel_updates_async_platform(hass):
|
"""Test async platform does not have parallel_updates limit by default."""
|
||||||
"""Warn we log when platform setup takes a long time."""
|
|
||||||
platform = MockPlatform()
|
platform = MockPlatform()
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def mock_update(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
platform.async_setup_platform = mock_update
|
|
||||||
|
|
||||||
loader.set_component(hass, 'test_domain.platform', platform)
|
loader.set_component(hass, 'test_domain.platform', platform)
|
||||||
|
|
||||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||||
component._platforms = {}
|
component._platforms = {}
|
||||||
|
|
||||||
yield from component.async_setup({
|
await component.async_setup({
|
||||||
DOMAIN: {
|
DOMAIN: {
|
||||||
'platform': 'platform',
|
'platform': 'platform',
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
handle = list(component._platforms.values())[-1]
|
handle = list(component._platforms.values())[-1]
|
||||||
|
|
||||||
assert handle.parallel_updates is None
|
assert handle.parallel_updates is None
|
||||||
|
|
||||||
|
class AsyncEntity(MockEntity):
|
||||||
|
"""Mock entity that has async_update."""
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_update(self):
|
||||||
def test_parallel_updates_async_platform_with_constant(hass):
|
|
||||||
"""Warn we log when platform setup takes a long time."""
|
|
||||||
platform = MockPlatform()
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def mock_update(*args, **kwargs):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
platform.async_setup_platform = mock_update
|
entity = AsyncEntity()
|
||||||
platform.PARALLEL_UPDATES = 1
|
await handle.async_add_entities([entity])
|
||||||
|
assert entity.parallel_updates is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_parallel_updates_async_platform_with_constant(hass):
|
||||||
|
"""Test async platform can set parallel_updates limit."""
|
||||||
|
platform = MockPlatform()
|
||||||
|
platform.PARALLEL_UPDATES = 2
|
||||||
|
|
||||||
loader.set_component(hass, 'test_domain.platform', platform)
|
loader.set_component(hass, 'test_domain.platform', platform)
|
||||||
|
|
||||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||||
component._platforms = {}
|
component._platforms = {}
|
||||||
|
|
||||||
yield from component.async_setup({
|
await component.async_setup({
|
||||||
DOMAIN: {
|
DOMAIN: {
|
||||||
'platform': 'platform',
|
'platform': 'platform',
|
||||||
}
|
}
|
||||||
@ -303,28 +298,79 @@ def test_parallel_updates_async_platform_with_constant(hass):
|
|||||||
|
|
||||||
handle = list(component._platforms.values())[-1]
|
handle = list(component._platforms.values())[-1]
|
||||||
|
|
||||||
assert handle.parallel_updates is not None
|
assert handle.parallel_updates == 2
|
||||||
|
|
||||||
|
class AsyncEntity(MockEntity):
|
||||||
|
"""Mock entity that has async_update."""
|
||||||
|
|
||||||
|
async def async_update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
entity = AsyncEntity()
|
||||||
|
await handle.async_add_entities([entity])
|
||||||
|
assert entity.parallel_updates is not None
|
||||||
|
assert entity.parallel_updates._value == 2
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def test_parallel_updates_sync_platform(hass):
|
||||||
def test_parallel_updates_sync_platform(hass):
|
"""Test sync platform parallel_updates default set to 1."""
|
||||||
"""Warn we log when platform setup takes a long time."""
|
platform = MockPlatform()
|
||||||
platform = MockPlatform(setup_platform=lambda *args: None)
|
|
||||||
|
|
||||||
loader.set_component(hass, 'test_domain.platform', platform)
|
loader.set_component(hass, 'test_domain.platform', platform)
|
||||||
|
|
||||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||||
component._platforms = {}
|
component._platforms = {}
|
||||||
|
|
||||||
yield from component.async_setup({
|
await component.async_setup({
|
||||||
DOMAIN: {
|
DOMAIN: {
|
||||||
'platform': 'platform',
|
'platform': 'platform',
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
handle = list(component._platforms.values())[-1]
|
handle = list(component._platforms.values())[-1]
|
||||||
|
assert handle.parallel_updates is None
|
||||||
|
|
||||||
assert handle.parallel_updates is not None
|
class SyncEntity(MockEntity):
|
||||||
|
"""Mock entity that has update."""
|
||||||
|
|
||||||
|
async def update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
entity = SyncEntity()
|
||||||
|
await handle.async_add_entities([entity])
|
||||||
|
assert entity.parallel_updates is not None
|
||||||
|
assert entity.parallel_updates._value == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_parallel_updates_sync_platform_with_constant(hass):
|
||||||
|
"""Test sync platform can set parallel_updates limit."""
|
||||||
|
platform = MockPlatform()
|
||||||
|
platform.PARALLEL_UPDATES = 2
|
||||||
|
|
||||||
|
loader.set_component(hass, 'test_domain.platform', platform)
|
||||||
|
|
||||||
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||||
|
component._platforms = {}
|
||||||
|
|
||||||
|
await component.async_setup({
|
||||||
|
DOMAIN: {
|
||||||
|
'platform': 'platform',
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
handle = list(component._platforms.values())[-1]
|
||||||
|
assert handle.parallel_updates == 2
|
||||||
|
|
||||||
|
class SyncEntity(MockEntity):
|
||||||
|
"""Mock entity that has update."""
|
||||||
|
|
||||||
|
async def update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
entity = SyncEntity()
|
||||||
|
await handle.async_add_entities([entity])
|
||||||
|
assert entity.parallel_updates is not None
|
||||||
|
assert entity.parallel_updates._value == 2
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
Loading…
x
Reference in New Issue
Block a user