Add a timeout for async_add_entities (#38474)

This commit is contained in:
J. Nick Koston 2020-08-05 09:06:21 -07:00 committed by GitHub
parent d66ddeb69e
commit 7590af3930
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 3 deletions

View File

@ -1,5 +1,6 @@
"""Class to manage the entities for a single platform.""" """Class to manage the entities for a single platform."""
import asyncio import asyncio
from contextlib import suppress
from contextvars import ContextVar from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
from logging import Logger from logging import Logger
@ -23,6 +24,8 @@ if TYPE_CHECKING:
SLOW_SETUP_WARNING = 10 SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60 SLOW_SETUP_MAX_WAIT = 60
SLOW_ADD_ENTITIES_MAX_WAIT = 60
PLATFORM_NOT_READY_RETRIES = 10 PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform" DATA_ENTITY_PLATFORM = "entity_platform"
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
@ -282,9 +285,11 @@ class EntityPlatform:
device_registry = await hass.helpers.device_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry()
entity_registry = await hass.helpers.entity_registry.async_get_registry() entity_registry = await hass.helpers.entity_registry.async_get_registry()
tasks = [ tasks = [
asyncio.create_task(
self._async_add_entity( # type: ignore self._async_add_entity( # type: ignore
entity, update_before_add, entity_registry, device_registry entity, update_before_add, entity_registry, device_registry
) )
)
for entity in new_entities for entity in new_entities
] ]
@ -292,7 +297,24 @@ class EntityPlatform:
if not tasks: if not tasks:
return return
await asyncio.gather(*tasks) await asyncio.wait(tasks, timeout=SLOW_ADD_ENTITIES_MAX_WAIT)
for idx, entity in enumerate(new_entities):
task = tasks[idx]
if task.done():
await task
continue
self.logger.warning(
"Timed out adding entity %s for domain %s with platform %s after %ds.",
entity.entity_id,
self.domain,
self.platform_name,
SLOW_ADD_ENTITIES_MAX_WAIT,
)
task.cancel()
with suppress(asyncio.CancelledError):
await task
if self._async_unsub_polling is not None or not any( if self._async_unsub_polling is not None or not any(
entity.should_poll for entity in self.entities.values() entity.should_poll for entity in self.entities.values()

View File

@ -931,3 +931,39 @@ async def test_invalid_entity_id(hass):
await platform.async_add_entities([entity]) await platform.async_add_entities([entity])
assert entity.hass is None assert entity.hass is None
assert entity.platform is None assert entity.platform is None
class MockBlockingEntity(MockEntity):
"""Class to mock an entity that will block adding entities."""
async def async_added_to_hass(self):
"""Block for a long time."""
await asyncio.sleep(1000)
async def test_setup_entry_with_entities_that_block_forever(hass, caplog):
"""Test we cancel adding entities when we reach the timeout."""
registry = mock_registry(hass)
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities([MockBlockingEntity(name="test1", unique_id="unique")])
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)
config_entry = MockConfigEntry(entry_id="super-mock-id")
mock_entity_platform = MockEntityPlatform(
hass, platform_name=config_entry.domain, platform=platform
)
with patch.object(entity_platform, "SLOW_ADD_ENTITIES_MAX_WAIT", 0.01):
assert await mock_entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
full_name = f"{mock_entity_platform.domain}.{config_entry.domain}"
assert full_name in hass.config.components
assert len(hass.states.async_entity_ids()) == 0
assert len(registry.entities) == 1
assert "Timed out adding entity" in caplog.text
assert "test_domain.test1" in caplog.text
assert "test_domain" in caplog.text
assert "test" in caplog.text