Replace verify_domain_control with verify_domain_entity_control

This commit is contained in:
epenet 2025-06-02 15:36:18 +00:00
parent 93b8cc38d8
commit 7ab62319f9
9 changed files with 199 additions and 16 deletions

View File

@ -36,7 +36,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.service import verify_domain_control
from homeassistant.helpers.service import verify_domain_entity_control
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.hass_dict import HassKey
@ -162,12 +162,12 @@ def setup_service_functions(
It appears that all TCC-compatible systems support the same three zones modes.
"""
@verify_domain_control(hass, DOMAIN)
@verify_domain_entity_control(DOMAIN)
async def force_refresh(call: ServiceCall) -> None:
"""Obtain the latest state data via the vendor's RESTful API."""
await coordinator.async_refresh()
@verify_domain_control(hass, DOMAIN)
@verify_domain_entity_control(DOMAIN)
async def set_system_mode(call: ServiceCall) -> None:
"""Set the system mode."""
assert coordinator.tcs is not None # mypy
@ -179,7 +179,7 @@ def setup_service_functions(
}
async_dispatcher_send(hass, DOMAIN, payload)
@verify_domain_control(hass, DOMAIN)
@verify_domain_entity_control(DOMAIN)
async def set_zone_override(call: ServiceCall) -> None:
"""Set the zone override (setpoint)."""
entity_id = call.data[ATTR_ENTITY_ID]

View File

@ -25,7 +25,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.service import verify_domain_control
from homeassistant.helpers.service import verify_domain_entity_control
from .const import DOMAIN
@ -124,7 +124,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: GeniusHubConfigEntry) ->
def setup_service_functions(hass: HomeAssistant, broker):
"""Set up the service functions."""
@verify_domain_control(hass, DOMAIN)
@verify_domain_entity_control(DOMAIN)
async def set_zone_mode(call: ServiceCall) -> None:
"""Set the system mode."""
entity_id = call.data[ATTR_ENTITY_ID]

View File

@ -18,7 +18,7 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.config_validation import comp_entity_ids
from homeassistant.helpers.service import (
async_register_admin_service,
verify_domain_control,
verify_domain_entity_control,
)
from .const import DOMAIN
@ -126,7 +126,7 @@ async def async_setup_services(hass: HomeAssistant) -> None:
if hass.services.async_services_for_domain(DOMAIN):
return
@verify_domain_control(hass, DOMAIN)
@verify_domain_entity_control(DOMAIN)
async def async_call_hmipc_service(service: ServiceCall) -> None:
"""Call correct HomematicIP Cloud service."""
service_name = service.service

View File

@ -10,7 +10,7 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import verify_domain_control
from homeassistant.helpers.service import verify_domain_entity_control
from .bridge import HueBridge, HueConfigEntry
from .const import (
@ -64,7 +64,7 @@ def async_register_services(hass: HomeAssistant) -> None:
hass.services.async_register(
DOMAIN,
SERVICE_HUE_ACTIVATE_SCENE,
verify_domain_control(hass, DOMAIN)(hue_activate_scene),
verify_domain_entity_control(DOMAIN)(hue_activate_scene),
schema=vol.Schema(
{
vol.Required(ATTR_GROUP_NAME): cv.string,

View File

@ -89,7 +89,7 @@ async def async_setup_entry(
elif service_call.service == SERVICE_RESTORE:
entity.restore()
@service.verify_domain_control(hass, DOMAIN)
@service.verify_domain_entity_control(DOMAIN)
async def async_service_handle(service_call: core.ServiceCall) -> None:
"""Handle for services."""
entities = await platform.async_extract_from_service(service_call)

View File

@ -63,7 +63,7 @@ from homeassistant.helpers import (
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.service import (
async_register_admin_service,
verify_domain_control,
verify_domain_entity_control,
)
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
@ -290,7 +290,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up SimpliSafe as config entry."""
_async_standardize_config_entry(hass, entry)
_verify_domain_control = verify_domain_control(hass, DOMAIN)
_verify_domain_control = verify_domain_entity_control(DOMAIN)
websession = aiohttp_client.async_get_clientsession(hass)
try:

View File

@ -120,7 +120,7 @@ async def async_setup_entry(
_LOGGER.debug("Creating media_player on %s", speaker.zone_name)
async_add_entities([SonosMediaPlayerEntity(speaker)])
@service.verify_domain_control(hass, DOMAIN)
@service.verify_domain_entity_control(DOMAIN)
async def async_service_handle(service_call: ServiceCall) -> None:
"""Handle dispatched services."""
assert platform is not None

View File

@ -1151,6 +1151,18 @@ def async_register_admin_service(
@callback
def verify_domain_control(
hass: HomeAssistant, domain: str
) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]:
"""Ensure permission to access any entity under domain in service call.
The use of this decorator is discouraged, and it should not be used
for new functions - please use `verify_domain_entity_control`.
"""
return verify_domain_entity_control(domain)
def verify_domain_entity_control(
domain: str,
) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]:
"""Ensure permission to access any entity under domain in service call."""
@ -1166,7 +1178,7 @@ def verify_domain_control(
if not call.context.user_id:
return await service_handler(call)
user = await hass.auth.async_get_user(call.context.user_id)
user = await call.hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
@ -1175,7 +1187,7 @@ def verify_domain_control(
user_id=call.context.user_id,
)
reg = entity_registry.async_get(hass)
reg = entity_registry.async_get(call.hass)
authorized = False

View File

@ -1678,6 +1678,7 @@ async def test_register_admin_service_return_response(
async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> None:
"""Test domain verification in a service call with an unknown user."""
# Note: deprecated - replaced by test_domain_entity_control_not_async
calls = []
def mock_service_log(call):
@ -1690,6 +1691,7 @@ async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> N
async def test_domain_control_unknown(hass: HomeAssistant, mock_entities) -> None:
"""Test domain verification in a service call with an unknown user."""
# Note: deprecated - replaced by test_domain_entity_control_unknown
calls = []
async def mock_service_log(call):
@ -1723,6 +1725,7 @@ async def test_domain_control_unauthorized(
hass: HomeAssistant, hass_read_only_user: MockUser
) -> None:
"""Test domain verification in a service call with an unauthorized user."""
# Note: deprecated - replaced by test_domain_entity_control_unauthorized
mock_registry(
hass,
{
@ -1764,6 +1767,7 @@ async def test_domain_control_admin(
hass: HomeAssistant, hass_admin_user: MockUser
) -> None:
"""Test domain verification in a service call with an admin user."""
# Note: deprecated - replaced by test_domain_entity_control_admin
mock_registry(
hass,
{
@ -1802,6 +1806,7 @@ async def test_domain_control_admin(
async def test_domain_control_no_user(hass: HomeAssistant) -> None:
"""Test domain verification in a service call with no user."""
# Note: deprecated - replaced by test_domain_entity_control_no_user
mock_registry(
hass,
{
@ -1838,6 +1843,172 @@ async def test_domain_control_no_user(hass: HomeAssistant) -> None:
assert len(calls) == 1
async def test_domain_entity_control_not_async(
hass: HomeAssistant, mock_entities
) -> None:
"""Test domain verification in a service call with an unknown user."""
calls = []
def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with pytest.raises(exceptions.HomeAssistantError):
service.verify_domain_entity_control(hass, "test_domain")(mock_service_log)
async def test_domain_entity_control_unknown(
hass: HomeAssistant, mock_entities
) -> None:
"""Test domain verification in a service call with an unknown user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch(
"homeassistant.helpers.entity_registry.async_get",
return_value=Mock(entities=mock_entities),
):
protected_mock_service = service.verify_domain_entity_control(
hass, "test_domain"
)(mock_service_log)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
with pytest.raises(exceptions.UnknownUser):
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=Context(user_id="fake_user_id"),
)
assert len(calls) == 0
async def test_domain_entity_control_unauthorized(
hass: HomeAssistant, hass_read_only_user: MockUser
) -> None:
"""Test domain verification in a service call with an unauthorized user."""
mock_registry(
hass,
{
"light.kitchen": RegistryEntryWithDefaults(
entity_id="light.kitchen",
unique_id="kitchen",
platform="test_domain",
)
},
)
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
protected_mock_service = service.verify_domain_entity_control(hass, "test_domain")(
mock_service_log
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=Context(user_id=hass_read_only_user.id),
)
assert len(calls) == 0
async def test_domain_entity_control_admin(
hass: HomeAssistant, hass_admin_user: MockUser
) -> None:
"""Test domain verification in a service call with an admin user."""
mock_registry(
hass,
{
"light.kitchen": RegistryEntryWithDefaults(
entity_id="light.kitchen",
unique_id="kitchen",
platform="test_domain",
)
},
)
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
protected_mock_service = service.verify_domain_entity_control(hass, "test_domain")(
mock_service_log
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=Context(user_id=hass_admin_user.id),
)
assert len(calls) == 1
async def test_domain_entity_control_no_user(hass: HomeAssistant) -> None:
"""Test domain verification in a service call with no user."""
mock_registry(
hass,
{
"light.kitchen": RegistryEntryWithDefaults(
entity_id="light.kitchen",
unique_id="kitchen",
platform="test_domain",
)
},
)
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
protected_mock_service = service.verify_domain_entity_control(hass, "test_domain")(
mock_service_log
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=Context(user_id=None),
)
assert len(calls) == 1
async def test_extract_from_service_available_device(hass: HomeAssistant) -> None:
"""Test the extraction of entity from service and device is available."""
entities = [