mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Speed up entity service calls (#94731)
* Speed up entity service calls - Avoid permissions check if the caller is an admin - Use set intersection instead of linear search of entity platforms to find entities * tweak * fix light test to not use an admin user
This commit is contained in:
parent
3778e1cd77
commit
68cf796be8
@ -680,15 +680,13 @@ async def entity_service_call( # noqa: C901
|
|||||||
|
|
||||||
Calls all platforms simultaneously.
|
Calls all platforms simultaneously.
|
||||||
"""
|
"""
|
||||||
|
entity_perms: None | (Callable[[str, str], bool]) = None
|
||||||
if call.context.user_id:
|
if call.context.user_id:
|
||||||
user = await hass.auth.async_get_user(call.context.user_id)
|
user = await hass.auth.async_get_user(call.context.user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise UnknownUser(context=call.context)
|
raise UnknownUser(context=call.context)
|
||||||
entity_perms: None | (
|
if not user.is_admin:
|
||||||
Callable[[str, str], bool]
|
entity_perms = user.permissions.check_entity
|
||||||
) = user.permissions.check_entity
|
|
||||||
else:
|
|
||||||
entity_perms = None
|
|
||||||
|
|
||||||
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
|
target_all_entities = call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL
|
||||||
|
|
||||||
@ -714,15 +712,15 @@ async def entity_service_call( # noqa: C901
|
|||||||
|
|
||||||
if entity_perms is None:
|
if entity_perms is None:
|
||||||
for platform in platforms:
|
for platform in platforms:
|
||||||
|
platform_entities = platform.entities
|
||||||
if target_all_entities:
|
if target_all_entities:
|
||||||
entity_candidates.extend(platform.entities.values())
|
entity_candidates.extend(platform_entities.values())
|
||||||
else:
|
else:
|
||||||
assert all_referenced is not None
|
assert all_referenced is not None
|
||||||
entity_candidates.extend(
|
entity_candidates.extend(
|
||||||
[
|
[
|
||||||
entity
|
platform_entities[entity_id]
|
||||||
for entity in platform.entities.values()
|
for entity_id in all_referenced.intersection(platform_entities)
|
||||||
if entity.entity_id in all_referenced
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -742,21 +740,20 @@ async def entity_service_call( # noqa: C901
|
|||||||
assert all_referenced is not None
|
assert all_referenced is not None
|
||||||
|
|
||||||
for platform in platforms:
|
for platform in platforms:
|
||||||
platform_entities = []
|
platform_entities = platform.entities
|
||||||
for entity in platform.entities.values():
|
platform_entity_candidates = []
|
||||||
if entity.entity_id not in all_referenced:
|
entity_id_matches = all_referenced.intersection(platform_entities)
|
||||||
continue
|
for entity_id in entity_id_matches:
|
||||||
|
if not entity_perms(entity_id, POLICY_CONTROL):
|
||||||
if not entity_perms(entity.entity_id, POLICY_CONTROL):
|
|
||||||
raise Unauthorized(
|
raise Unauthorized(
|
||||||
context=call.context,
|
context=call.context,
|
||||||
entity_id=entity.entity_id,
|
entity_id=entity_id,
|
||||||
permission=POLICY_CONTROL,
|
permission=POLICY_CONTROL,
|
||||||
)
|
)
|
||||||
|
|
||||||
platform_entities.append(entity)
|
platform_entity_candidates.append(platform_entities[entity_id])
|
||||||
|
|
||||||
entity_candidates.extend(platform_entities)
|
entity_candidates.extend(platform_entity_candidates)
|
||||||
|
|
||||||
if not target_all_entities:
|
if not target_all_entities:
|
||||||
assert referenced is not None
|
assert referenced is not None
|
||||||
@ -769,7 +766,7 @@ async def entity_service_call( # noqa: C901
|
|||||||
|
|
||||||
referenced.log_missing(missing)
|
referenced.log_missing(missing)
|
||||||
|
|
||||||
entities = []
|
entities: list[Entity] = []
|
||||||
|
|
||||||
for entity in entity_candidates:
|
for entity in entity_candidates:
|
||||||
if not entity.available:
|
if not entity.available:
|
||||||
@ -810,7 +807,7 @@ async def entity_service_call( # noqa: C901
|
|||||||
for future in done:
|
for future in done:
|
||||||
future.result() # pop exception if have
|
future.result() # pop exception if have
|
||||||
|
|
||||||
tasks = []
|
tasks: list[asyncio.Task[None]] = []
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if not entity.should_poll:
|
if not entity.should_poll:
|
||||||
|
@ -872,7 +872,7 @@ async def test_light_context(
|
|||||||
|
|
||||||
|
|
||||||
async def test_light_turn_on_auth(
|
async def test_light_turn_on_auth(
|
||||||
hass: HomeAssistant, hass_admin_user: MockUser, enable_custom_integrations: None
|
hass: HomeAssistant, hass_read_only_user: MockUser, enable_custom_integrations: None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that light context works."""
|
"""Test that light context works."""
|
||||||
platform = getattr(hass.components, "test.light")
|
platform = getattr(hass.components, "test.light")
|
||||||
@ -883,7 +883,7 @@ async def test_light_turn_on_auth(
|
|||||||
state = hass.states.get("light.ceiling")
|
state = hass.states.get("light.ceiling")
|
||||||
assert state is not None
|
assert state is not None
|
||||||
|
|
||||||
hass_admin_user.mock_policy({})
|
hass_read_only_user.mock_policy({})
|
||||||
|
|
||||||
with pytest.raises(Unauthorized):
|
with pytest.raises(Unauthorized):
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
@ -891,7 +891,7 @@ async def test_light_turn_on_auth(
|
|||||||
"turn_on",
|
"turn_on",
|
||||||
{"entity_id": state.entity_id},
|
{"entity_id": state.entity_id},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
context=core.Context(user_id=hass_admin_user.id),
|
context=core.Context(user_id=hass_read_only_user.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -714,7 +714,8 @@ async def test_call_context_target_all(
|
|||||||
return_value=Mock(
|
return_value=Mock(
|
||||||
permissions=PolicyPermissions(
|
permissions=PolicyPermissions(
|
||||||
{"entities": {"entity_ids": {"light.kitchen": True}}}, None
|
{"entities": {"entity_ids": {"light.kitchen": True}}}, None
|
||||||
)
|
),
|
||||||
|
is_admin=False,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
await service.entity_service_call(
|
await service.entity_service_call(
|
||||||
@ -767,7 +768,7 @@ async def test_call_context_target_specific_no_auth(
|
|||||||
"""Check targeting specific entities without auth."""
|
"""Check targeting specific entities without auth."""
|
||||||
with pytest.raises(exceptions.Unauthorized) as err, patch(
|
with pytest.raises(exceptions.Unauthorized) as err, patch(
|
||||||
"homeassistant.auth.AuthManager.async_get_user",
|
"homeassistant.auth.AuthManager.async_get_user",
|
||||||
return_value=Mock(permissions=PolicyPermissions({}, None)),
|
return_value=Mock(permissions=PolicyPermissions({}, None), is_admin=False),
|
||||||
):
|
):
|
||||||
await service.entity_service_call(
|
await service.entity_service_call(
|
||||||
hass,
|
hass,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user