mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 08:47:10 +00:00
Avoid scheduling a task to add each entity when not using update_before_add (#110951)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
3aecec5082
commit
d9addc45f9
@ -346,11 +346,11 @@ class EntityPlatform:
|
|||||||
|
|
||||||
# Block till all entities are done
|
# Block till all entities are done
|
||||||
while self._tasks:
|
while self._tasks:
|
||||||
pending = [task for task in self._tasks if not task.done()]
|
# Await all tasks even if they are done
|
||||||
|
# to ensure exceptions are propagated
|
||||||
|
pending = self._tasks.copy()
|
||||||
self._tasks.clear()
|
self._tasks.clear()
|
||||||
|
await asyncio.gather(*pending)
|
||||||
if pending:
|
|
||||||
await asyncio.gather(*pending)
|
|
||||||
|
|
||||||
hass.config.components.add(full_name)
|
hass.config.components.add(full_name)
|
||||||
self._setup_complete = True
|
self._setup_complete = True
|
||||||
@ -505,6 +505,82 @@ class EntityPlatform:
|
|||||||
self.hass.loop,
|
self.hass.loop,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
|
async def _async_add_and_update_entities(
|
||||||
|
self,
|
||||||
|
coros: list[Coroutine[Any, Any, None]],
|
||||||
|
entities: list[Entity],
|
||||||
|
timeout: float,
|
||||||
|
) -> None:
|
||||||
|
"""Add entities for a single platform and update them.
|
||||||
|
|
||||||
|
Since we are updating the entities before adding them, we need to
|
||||||
|
schedule the coroutines as tasks so we can await them in the event
|
||||||
|
loop. This is because the update is likely to yield control to the
|
||||||
|
event loop and will finish faster if we run them concurrently.
|
||||||
|
"""
|
||||||
|
results: list[BaseException | None] | None = None
|
||||||
|
try:
|
||||||
|
async with self.hass.timeout.async_timeout(timeout, self.domain):
|
||||||
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
||||||
|
except TimeoutError:
|
||||||
|
self.logger.warning(
|
||||||
|
"Timed out adding entities for domain %s with platform %s after %ds",
|
||||||
|
self.domain,
|
||||||
|
self.platform_name,
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return
|
||||||
|
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
entity = entities[idx]
|
||||||
|
self.logger.exception(
|
||||||
|
"Error adding entity %s for domain %s with platform %s",
|
||||||
|
entity.entity_id,
|
||||||
|
self.domain,
|
||||||
|
self.platform_name,
|
||||||
|
exc_info=result,
|
||||||
|
)
|
||||||
|
elif isinstance(result, BaseException):
|
||||||
|
raise result
|
||||||
|
|
||||||
|
async def _async_add_entities(
|
||||||
|
self,
|
||||||
|
coros: list[Coroutine[Any, Any, None]],
|
||||||
|
entities: list[Entity],
|
||||||
|
timeout: float,
|
||||||
|
) -> None:
|
||||||
|
"""Add entities for a single platform without updating.
|
||||||
|
|
||||||
|
In this case we are not updating the entities before adding them
|
||||||
|
which means its unlikely that we will not have to yield control
|
||||||
|
to the event loop so we can await the coros directly without
|
||||||
|
scheduling them as tasks.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with self.hass.timeout.async_timeout(timeout, self.domain):
|
||||||
|
for idx, coro in enumerate(coros):
|
||||||
|
try:
|
||||||
|
await coro
|
||||||
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
|
entity = entities[idx]
|
||||||
|
self.logger.exception(
|
||||||
|
"Error adding entity %s for domain %s with platform %s",
|
||||||
|
entity.entity_id,
|
||||||
|
self.domain,
|
||||||
|
self.platform_name,
|
||||||
|
exc_info=ex,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
self.logger.warning(
|
||||||
|
"Timed out adding entities for domain %s with platform %s after %ds",
|
||||||
|
self.domain,
|
||||||
|
self.platform_name,
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
|
||||||
async def async_add_entities(
|
async def async_add_entities(
|
||||||
self, new_entities: Iterable[Entity], update_before_add: bool = False
|
self, new_entities: Iterable[Entity], update_before_add: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -517,40 +593,31 @@ class EntityPlatform:
|
|||||||
return
|
return
|
||||||
|
|
||||||
hass = self.hass
|
hass = self.hass
|
||||||
|
|
||||||
entity_registry = ent_reg.async_get(hass)
|
entity_registry = ent_reg.async_get(hass)
|
||||||
tasks = [
|
coros: list[Coroutine[Any, Any, None]] = []
|
||||||
self._async_add_entity(entity, update_before_add, entity_registry)
|
entities: list[Entity] = []
|
||||||
for entity in new_entities
|
for entity in new_entities:
|
||||||
]
|
coros.append(
|
||||||
|
self._async_add_entity(entity, update_before_add, entity_registry)
|
||||||
|
)
|
||||||
|
entities.append(entity)
|
||||||
|
|
||||||
# No entities for processing
|
# No entities for processing
|
||||||
if not tasks:
|
if not coros:
|
||||||
return
|
return
|
||||||
|
|
||||||
timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(tasks), SLOW_ADD_MIN_TIMEOUT)
|
timeout = max(SLOW_ADD_ENTITY_MAX_WAIT * len(coros), SLOW_ADD_MIN_TIMEOUT)
|
||||||
try:
|
if update_before_add:
|
||||||
async with self.hass.timeout.async_timeout(timeout, self.domain):
|
add_func = self._async_add_and_update_entities
|
||||||
await asyncio.gather(*tasks)
|
else:
|
||||||
except TimeoutError:
|
add_func = self._async_add_entities
|
||||||
self.logger.warning(
|
|
||||||
"Timed out adding entities for domain %s with platform %s after %ds",
|
await add_func(coros, entities, timeout)
|
||||||
self.domain,
|
|
||||||
self.platform_name,
|
|
||||||
timeout,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
self.logger.exception(
|
|
||||||
"Error adding entities for domain %s with platform %s",
|
|
||||||
self.domain,
|
|
||||||
self.platform_name,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(self.config_entry and self.config_entry.pref_disable_polling)
|
(self.config_entry and self.config_entry.pref_disable_polling)
|
||||||
or self._async_unsub_polling is not None
|
or self._async_unsub_polling is not None
|
||||||
or not any(entity.should_poll for entity in self.entities.values())
|
or not any(entity.should_poll for entity in entities)
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -150,6 +150,7 @@ async def test_state_changed_event_sends_message(
|
|||||||
platform = MockEntityPlatform(hass)
|
platform = MockEntityPlatform(hass)
|
||||||
entity = MockEntity(unique_id="1234")
|
entity = MockEntity(unique_id="1234")
|
||||||
await platform.async_add_entities([entity])
|
await platform.async_add_entities([entity])
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
mqtt_mock.async_publish.assert_called_with(
|
mqtt_mock.async_publish.assert_called_with(
|
||||||
"pub/test_domain/test_platform_1234/state", "unknown", 1, True
|
"pub/test_domain/test_platform_1234/state", "unknown", 1, True
|
||||||
|
@ -1992,7 +1992,7 @@ async def test_non_numeric_validation_raise(
|
|||||||
state = hass.states.get(entity0.entity_id)
|
state = hass.states.get(entity0.entity_id)
|
||||||
assert state is None
|
assert state is None
|
||||||
|
|
||||||
assert ("Error adding entities for domain sensor with platform test") in caplog.text
|
assert ("for domain sensor with platform test") in caplog.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -1747,22 +1747,26 @@ async def test_suggest_report_issue_custom_component(
|
|||||||
assert suggestion == "create a bug report at https://some_url"
|
assert suggestion == "create a bug report at https://some_url"
|
||||||
|
|
||||||
|
|
||||||
async def test_reuse_entity_object_after_abort(hass: HomeAssistant) -> None:
|
async def test_reuse_entity_object_after_abort(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
"""Test reuse entity object."""
|
"""Test reuse entity object."""
|
||||||
platform = MockEntityPlatform(hass, domain="test")
|
platform = MockEntityPlatform(hass, domain="test")
|
||||||
ent = entity.Entity()
|
ent = entity.Entity()
|
||||||
ent.entity_id = "invalid"
|
ent.entity_id = "invalid"
|
||||||
with pytest.raises(HomeAssistantError, match="Invalid entity ID: invalid"):
|
await platform.async_add_entities([ent])
|
||||||
await platform.async_add_entities([ent])
|
assert "Invalid entity ID: invalid" in caplog.text
|
||||||
with pytest.raises(
|
await platform.async_add_entities([ent])
|
||||||
HomeAssistantError,
|
assert (
|
||||||
match="Entity 'invalid' cannot be added a second time to an entity platform",
|
"Entity 'invalid' cannot be added a second time to an entity platform"
|
||||||
):
|
in caplog.text
|
||||||
await platform.async_add_entities([ent])
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_reuse_entity_object_after_entity_registry_remove(
|
async def test_reuse_entity_object_after_entity_registry_remove(
|
||||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
hass: HomeAssistant,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test reuse entity object."""
|
"""Test reuse entity object."""
|
||||||
entry = entity_registry.async_get_or_create("test", "test", "5678")
|
entry = entity_registry.async_get_or_create("test", "test", "5678")
|
||||||
@ -1777,15 +1781,15 @@ async def test_reuse_entity_object_after_entity_registry_remove(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(hass.states.async_entity_ids()) == 0
|
assert len(hass.states.async_entity_ids()) == 0
|
||||||
|
|
||||||
with pytest.raises(
|
await platform.async_add_entities([ent])
|
||||||
HomeAssistantError,
|
assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
|
||||||
match="Entity 'test.test_5678' cannot be added a second time",
|
assert len(hass.states.async_entity_ids()) == 0
|
||||||
):
|
|
||||||
await platform.async_add_entities([ent])
|
|
||||||
|
|
||||||
|
|
||||||
async def test_reuse_entity_object_after_entity_registry_disabled(
|
async def test_reuse_entity_object_after_entity_registry_disabled(
|
||||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
hass: HomeAssistant,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test reuse entity object."""
|
"""Test reuse entity object."""
|
||||||
entry = entity_registry.async_get_or_create("test", "test", "5678")
|
entry = entity_registry.async_get_or_create("test", "test", "5678")
|
||||||
@ -1802,11 +1806,9 @@ async def test_reuse_entity_object_after_entity_registry_disabled(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(hass.states.async_entity_ids()) == 0
|
assert len(hass.states.async_entity_ids()) == 0
|
||||||
|
|
||||||
with pytest.raises(
|
await platform.async_add_entities([ent])
|
||||||
HomeAssistantError,
|
assert len(hass.states.async_entity_ids()) == 0
|
||||||
match="Entity 'test.test_5678' cannot be added a second time",
|
assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
|
||||||
):
|
|
||||||
await platform.async_add_entities([ent])
|
|
||||||
|
|
||||||
|
|
||||||
async def test_change_entity_id(
|
async def test_change_entity_id(
|
||||||
|
@ -1710,14 +1710,23 @@ async def test_register_entity_service_limited_to_matching_platforms(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_entity_id(hass: HomeAssistant) -> None:
|
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||||
|
async def test_invalid_entity_id(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, update_before_add: bool
|
||||||
|
) -> None:
|
||||||
"""Test specifying an invalid entity id."""
|
"""Test specifying an invalid entity id."""
|
||||||
platform = MockEntityPlatform(hass)
|
platform = MockEntityPlatform(hass)
|
||||||
entity = MockEntity(entity_id="invalid_entity_id")
|
entity = MockEntity(entity_id="invalid_entity_id")
|
||||||
with pytest.raises(HomeAssistantError):
|
entity2 = MockEntity(entity_id="valid.entity_id")
|
||||||
await platform.async_add_entities([entity])
|
await platform.async_add_entities(
|
||||||
|
[entity, entity2], update_before_add=update_before_add
|
||||||
|
)
|
||||||
assert entity.hass is None
|
assert entity.hass is None
|
||||||
assert entity.platform is None
|
assert entity.platform is None
|
||||||
|
assert "Invalid entity ID: invalid_entity_id" in caplog.text
|
||||||
|
# Ensure the valid entity was still added
|
||||||
|
assert entity2.hass is not None
|
||||||
|
assert entity2.platform is not None
|
||||||
|
|
||||||
|
|
||||||
class MockBlockingEntity(MockEntity):
|
class MockBlockingEntity(MockEntity):
|
||||||
@ -1728,16 +1737,21 @@ class MockBlockingEntity(MockEntity):
|
|||||||
await asyncio.sleep(1000)
|
await asyncio.sleep(1000)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||||
async def test_setup_entry_with_entities_that_block_forever(
|
async def test_setup_entry_with_entities_that_block_forever(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
|
update_before_add: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we cancel adding entities when we reach the timeout."""
|
"""Test we cancel adding entities when we reach the timeout."""
|
||||||
|
|
||||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
"""Mock setup entry method."""
|
"""Mock setup entry method."""
|
||||||
async_add_entities([MockBlockingEntity(name="test1", unique_id="unique")])
|
async_add_entities(
|
||||||
|
[MockBlockingEntity(name="test1", unique_id="unique")],
|
||||||
|
update_before_add=update_before_add,
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||||
@ -1761,7 +1775,47 @@ async def test_setup_entry_with_entities_that_block_forever(
|
|||||||
assert "test" in caplog.text
|
assert "test" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None:
|
class MockCancellingEntity(MockEntity):
|
||||||
|
"""Class to mock an entity get cancelled while adding."""
|
||||||
|
|
||||||
|
async def async_added_to_hass(self):
|
||||||
|
"""Mock cancellation."""
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||||
|
async def test_cancellation_is_not_blocked(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
update_before_add: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Test cancellation is not blocked while adding entities."""
|
||||||
|
|
||||||
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
|
"""Mock setup entry method."""
|
||||||
|
async_add_entities(
|
||||||
|
[MockCancellingEntity(name="test1", unique_id="unique")],
|
||||||
|
update_before_add=update_before_add,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||||
|
config_entry = MockConfigEntry(entry_id="super-mock-id")
|
||||||
|
platform = MockEntityPlatform(
|
||||||
|
hass, platform_name=config_entry.domain, platform=platform
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
assert await platform.async_setup_entry(config_entry)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
full_name = f"{config_entry.domain}.{platform.domain}"
|
||||||
|
assert full_name not in hass.config.components
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||||
|
async def test_two_platforms_add_same_entity(
|
||||||
|
hass: HomeAssistant, update_before_add: bool
|
||||||
|
) -> None:
|
||||||
"""Test two platforms in the same domain adding an entity with the same name."""
|
"""Test two platforms in the same domain adding an entity with the same name."""
|
||||||
entity_platform1 = MockEntityPlatform(
|
entity_platform1 = MockEntityPlatform(
|
||||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
@ -1774,8 +1828,12 @@ async def test_two_platforms_add_same_entity(hass: HomeAssistant) -> None:
|
|||||||
entity2 = SlowEntity(name="entity_1")
|
entity2 = SlowEntity(name="entity_1")
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
entity_platform1.async_add_entities([entity1]),
|
entity_platform1.async_add_entities(
|
||||||
entity_platform2.async_add_entities([entity2]),
|
[entity1], update_before_add=update_before_add
|
||||||
|
),
|
||||||
|
entity_platform2.async_add_entities(
|
||||||
|
[entity2], update_before_add=update_before_add
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
entities = []
|
entities = []
|
||||||
@ -1816,12 +1874,14 @@ class SlowEntity(MockEntity):
|
|||||||
(True, None, "test_domain.device_bla"),
|
(True, None, "test_domain.device_bla"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||||
async def test_entity_name_influences_entity_id(
|
async def test_entity_name_influences_entity_id(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
has_entity_name: bool,
|
has_entity_name: bool,
|
||||||
entity_name: str | None,
|
entity_name: str | None,
|
||||||
expected_entity_id: str,
|
expected_entity_id: str,
|
||||||
|
update_before_add: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test entity_id is influenced by entity name."""
|
"""Test entity_id is influenced by entity name."""
|
||||||
|
|
||||||
@ -1839,7 +1899,8 @@ async def test_entity_name_influences_entity_id(
|
|||||||
has_entity_name=has_entity_name,
|
has_entity_name=has_entity_name,
|
||||||
name=entity_name,
|
name=entity_name,
|
||||||
),
|
),
|
||||||
]
|
],
|
||||||
|
update_before_add=update_before_add,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -1867,12 +1928,14 @@ async def test_entity_name_influences_entity_id(
|
|||||||
("cn", True, "test_domain.device_bla_english_name"),
|
("cn", True, "test_domain.device_bla_english_name"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("update_before_add", (True, False))
|
||||||
async def test_translated_entity_name_influences_entity_id(
|
async def test_translated_entity_name_influences_entity_id(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
language: str,
|
language: str,
|
||||||
has_entity_name: bool,
|
has_entity_name: bool,
|
||||||
expected_entity_id: str,
|
expected_entity_id: str,
|
||||||
|
update_before_add: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test entity_id is influenced by translated entity name."""
|
"""Test entity_id is influenced by translated entity name."""
|
||||||
|
|
||||||
@ -1909,7 +1972,9 @@ async def test_translated_entity_name_influences_entity_id(
|
|||||||
|
|
||||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
"""Mock setup entry method."""
|
"""Mock setup entry method."""
|
||||||
async_add_entities([TranslatedEntity(has_entity_name)])
|
async_add_entities(
|
||||||
|
[TranslatedEntity(has_entity_name)], update_before_add=update_before_add
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user