Speed up entity service calls (#94731)

* Speed up entity service calls

- Avoid permissions check if the caller is an admin
- Use set intersection instead of linear search of entity platforms to find entities

* tweak

* fix light test to not use an admin user
This commit is contained in:
J. Nick Koston 2023-06-16 20:07:57 -05:00 committed by GitHub
parent 3778e1cd77
commit 68cf796be8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 25 deletions

View File

@ -680,15 +680,13 @@ async def entity_service_call( # noqa: C901
Calls all platforms simultaneously. Calls all platforms simultaneously.
""" """
entity_perms: None | (Callable[[str, str], bool]) = None
if call.context.user_id: if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id) user = await hass.auth.async_get_user(call.context.user_id)
if user is None: if user is None:
raise UnknownUser(context=call.context) raise UnknownUser(context=call.context)
entity_perms: None | ( if not user.is_admin:
Callable[[str, str], bool] entity_perms = user.permissions.check_entity
) = user.permissions.check_entity
else:
entity_perms = None
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
@ -714,15 +712,15 @@ async def entity_service_call( # noqa: C901
if entity_perms is None: if entity_perms is None:
for platform in platforms: for platform in platforms:
platform_entities = platform.entities
if target_all_entities: if target_all_entities:
entity_candidates.extend(platform.entities.values()) entity_candidates.extend(platform_entities.values())
else: else:
assert all_referenced is not None assert all_referenced is not None
entity_candidates.extend( entity_candidates.extend(
[ [
entity platform_entities[entity_id]
for entity in platform.entities.values() for entity_id in all_referenced.intersection(platform_entities)
if entity.entity_id in all_referenced
] ]
) )
@ -742,21 +740,20 @@ async def entity_service_call( # noqa: C901
assert all_referenced is not None assert all_referenced is not None
for platform in platforms: for platform in platforms:
platform_entities = [] platform_entities = platform.entities
for entity in platform.entities.values(): platform_entity_candidates = []
if entity.entity_id not in all_referenced: entity_id_matches = all_referenced.intersection(platform_entities)
continue for entity_id in entity_id_matches:
if not entity_perms(entity_id, POLICY_CONTROL):
if not entity_perms(entity.entity_id, POLICY_CONTROL):
raise Unauthorized( raise Unauthorized(
context=call.context, context=call.context,
entity_id=entity.entity_id, entity_id=entity_id,
permission=POLICY_CONTROL, permission=POLICY_CONTROL,
) )
platform_entities.append(entity) platform_entity_candidates.append(platform_entities[entity_id])
entity_candidates.extend(platform_entities) entity_candidates.extend(platform_entity_candidates)
if not target_all_entities: if not target_all_entities:
assert referenced is not None assert referenced is not None
@ -769,7 +766,7 @@ async def entity_service_call( # noqa: C901
referenced.log_missing(missing) referenced.log_missing(missing)
entities = [] entities: list[Entity] = []
for entity in entity_candidates: for entity in entity_candidates:
if not entity.available: if not entity.available:
@ -810,7 +807,7 @@ async def entity_service_call( # noqa: C901
for future in done: for future in done:
future.result() # pop exception if have future.result() # pop exception if have
tasks = [] tasks: list[asyncio.Task[None]] = []
for entity in entities: for entity in entities:
if not entity.should_poll: if not entity.should_poll:

View File

@ -872,7 +872,7 @@ async def test_light_context(
async def test_light_turn_on_auth( async def test_light_turn_on_auth(
hass: HomeAssistant, hass_admin_user: MockUser, enable_custom_integrations: None hass: HomeAssistant, hass_read_only_user: MockUser, enable_custom_integrations: None
) -> None: ) -> None:
"""Test that light context works.""" """Test that light context works."""
platform = getattr(hass.components, "test.light") platform = getattr(hass.components, "test.light")
@ -883,7 +883,7 @@ async def test_light_turn_on_auth(
state = hass.states.get("light.ceiling") state = hass.states.get("light.ceiling")
assert state is not None assert state is not None
hass_admin_user.mock_policy({}) hass_read_only_user.mock_policy({})
with pytest.raises(Unauthorized): with pytest.raises(Unauthorized):
await hass.services.async_call( await hass.services.async_call(
@ -891,7 +891,7 @@ async def test_light_turn_on_auth(
"turn_on", "turn_on",
{"entity_id": state.entity_id}, {"entity_id": state.entity_id},
blocking=True, blocking=True,
context=core.Context(user_id=hass_admin_user.id), context=core.Context(user_id=hass_read_only_user.id),
) )

View File

@ -714,7 +714,8 @@ async def test_call_context_target_all(
return_value=Mock( return_value=Mock(
permissions=PolicyPermissions( permissions=PolicyPermissions(
{"entities": {"entity_ids": {"light.kitchen": True}}}, None {"entities": {"entity_ids": {"light.kitchen": True}}}, None
) ),
is_admin=False,
), ),
): ):
await service.entity_service_call( await service.entity_service_call(
@ -767,7 +768,7 @@ async def test_call_context_target_specific_no_auth(
"""Check targeting specific entities without auth.""" """Check targeting specific entities without auth."""
with pytest.raises(exceptions.Unauthorized) as err, patch( with pytest.raises(exceptions.Unauthorized) as err, patch(
"homeassistant.auth.AuthManager.async_get_user", "homeassistant.auth.AuthManager.async_get_user",
return_value=Mock(permissions=PolicyPermissions({}, None)), return_value=Mock(permissions=PolicyPermissions({}, None), is_admin=False),
): ):
await service.entity_service_call( await service.entity_service_call(
hass, hass,