Drop hass argument from verify_domain_control

This commit is contained in:
epenet 2025-07-02 10:49:25 +00:00
parent c75b34a911
commit 07b953d6d4
3 changed files with 135 additions and 23 deletions

View File

@ -138,6 +138,32 @@ def deprecated_function[**_P, _R](
return deprecated_decorator return deprecated_decorator
def deprecate_hass_binding[**_P, _T](
breaks_in_ha_version: str | None = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Decorate function to indicate that first argument hass will be ignored."""
def _decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(func)
def _inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
from homeassistant.core import HomeAssistant # noqa: PLC0415
if isinstance(args[0], HomeAssistant):
_print_deprecation_warning(
func,
"without hass",
"argument",
"called with hass as the first argument",
breaks_in_ha_version,
)
args = args[1:] # type: ignore[assignment]
return func(*args, **kwargs)
return _inner
return _decorator
def _print_deprecation_warning( def _print_deprecation_warning(
obj: Any, obj: Any,
replacement: str, replacement: str,

View File

@ -9,7 +9,7 @@ from enum import Enum
from functools import cache, partial from functools import cache, partial
import logging import logging
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, TypedDict, cast, override from typing import TYPE_CHECKING, Any, TypedDict, cast, overload, override
import voluptuous as vol import voluptuous as vol
@ -57,7 +57,7 @@ from . import (
template, template,
translation, translation,
) )
from .deprecation import deprecated_class, deprecated_function from .deprecation import deprecate_hass_binding, deprecated_class, deprecated_function
from .selector import TargetSelector from .selector import TargetSelector
from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType
@ -986,10 +986,22 @@ def async_register_admin_service(
) )
@bind_hass # Overloads can be dropped when all core calls have been updated to drop hass argument
@overload
def verify_domain_control(
domain: str,
) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]: ...
@overload
def verify_domain_control(
hass: HomeAssistant,
domain: str,
) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]: ...
@deprecate_hass_binding(breaks_in_ha_version="2026.2") # type: ignore[misc]
@callback @callback
def verify_domain_control( def verify_domain_control(
hass: HomeAssistant, domain: str domain: str,
) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]: ) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]:
"""Ensure permission to access any entity under domain in service call.""" """Ensure permission to access any entity under domain in service call."""
@ -1005,6 +1017,7 @@ def verify_domain_control(
if not call.context.user_id: if not call.context.user_id:
return await service_handler(call) return await service_handler(call)
hass = call.hass
user = await hass.auth.async_get_user(call.context.user_id) user = await hass.auth.async_get_user(call.context.user_id)
if user is None: if user is None:

View File

@ -1,7 +1,7 @@
"""Test service helpers.""" """Test service helpers."""
import asyncio import asyncio
from collections.abc import Iterable from collections.abc import Callable, Iterable
from copy import deepcopy from copy import deepcopy
import dataclasses import dataclasses
import io import io
@ -1711,7 +1711,27 @@ async def test_register_admin_service_return_response(
assert result == {"test-reply": "test-value1"} assert result == {"test-reply": "test-value1"}
async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> None: _DEPRECATED_VERIFY_DOMAIN_CONTROL_MESSAGE = (
"verify_domain_control is a deprecated argument which will be removed"
" in HA Core 2026.2. Use without hass instead"
)
@pytest.mark.parametrize(
# Check that with or without hass behaves the same
("decorator", "in_caplog"),
[
(service.verify_domain_control, True), # deprecated with hass
(lambda _, domain: service.verify_domain_control(domain), False),
],
)
async def test_domain_control_not_async(
hass: HomeAssistant,
mock_entities,
decorator: Callable[[HomeAssistant, str], Any],
in_caplog: bool,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test domain verification in a service call with an unknown user.""" """Test domain verification in a service call with an unknown user."""
calls = [] calls = []
@ -1720,10 +1740,26 @@ async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> N
calls.append(call) calls.append(call)
with pytest.raises(exceptions.HomeAssistantError): with pytest.raises(exceptions.HomeAssistantError):
service.verify_domain_control(hass, "test_domain")(mock_service_log) decorator(hass, "test_domain")(mock_service_log)
assert (_DEPRECATED_VERIFY_DOMAIN_CONTROL_MESSAGE in caplog.text) == in_caplog
async def test_domain_control_unknown(hass: HomeAssistant, mock_entities) -> None: @pytest.mark.parametrize(
# Check that with or without hass behaves the same
("decorator", "in_caplog"),
[
(service.verify_domain_control, True), # deprecated with hass
(lambda _, domain: service.verify_domain_control(domain), False),
],
)
async def test_domain_control_unknown(
hass: HomeAssistant,
mock_entities,
decorator: Callable[[HomeAssistant, str], Any],
in_caplog: bool,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test domain verification in a service call with an unknown user.""" """Test domain verification in a service call with an unknown user."""
calls = [] calls = []
@ -1735,9 +1771,7 @@ async def test_domain_control_unknown(hass: HomeAssistant, mock_entities) -> Non
"homeassistant.helpers.entity_registry.async_get", "homeassistant.helpers.entity_registry.async_get",
return_value=Mock(entities=mock_entities), return_value=Mock(entities=mock_entities),
): ):
protected_mock_service = service.verify_domain_control(hass, "test_domain")( protected_mock_service = decorator(hass, "test_domain")(mock_service_log)
mock_service_log
)
hass.services.async_register( hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None "test_domain", "test_service", protected_mock_service, schema=None
@ -1753,9 +1787,23 @@ async def test_domain_control_unknown(hass: HomeAssistant, mock_entities) -> Non
) )
assert len(calls) == 0 assert len(calls) == 0
assert (_DEPRECATED_VERIFY_DOMAIN_CONTROL_MESSAGE in caplog.text) == in_caplog
@pytest.mark.parametrize(
# Check that with or without hass behaves the same
("decorator", "in_caplog"),
[
(service.verify_domain_control, True), # deprecated with hass
(lambda _, domain: service.verify_domain_control(domain), False),
],
)
async def test_domain_control_unauthorized( async def test_domain_control_unauthorized(
hass: HomeAssistant, hass_read_only_user: MockUser hass: HomeAssistant,
hass_read_only_user: MockUser,
decorator: Callable[[HomeAssistant, str], Any],
in_caplog: bool,
caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test domain verification in a service call with an unauthorized user.""" """Test domain verification in a service call with an unauthorized user."""
mock_registry( mock_registry(
@ -1775,9 +1823,7 @@ async def test_domain_control_unauthorized(
"""Define a protected service.""" """Define a protected service."""
calls.append(call) calls.append(call)
protected_mock_service = service.verify_domain_control(hass, "test_domain")( protected_mock_service = decorator(hass, "test_domain")(mock_service_log)
mock_service_log
)
hass.services.async_register( hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None "test_domain", "test_service", protected_mock_service, schema=None
@ -1794,9 +1840,23 @@ async def test_domain_control_unauthorized(
assert len(calls) == 0 assert len(calls) == 0
assert (_DEPRECATED_VERIFY_DOMAIN_CONTROL_MESSAGE in caplog.text) == in_caplog
@pytest.mark.parametrize(
# Check that with or without hass behaves the same
("decorator", "in_caplog"),
[
(service.verify_domain_control, True), # deprecated with hass
(lambda _, domain: service.verify_domain_control(domain), False),
],
)
async def test_domain_control_admin( async def test_domain_control_admin(
hass: HomeAssistant, hass_admin_user: MockUser hass: HomeAssistant,
hass_admin_user: MockUser,
decorator: Callable[[HomeAssistant, str], Any],
in_caplog: bool,
caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test domain verification in a service call with an admin user.""" """Test domain verification in a service call with an admin user."""
mock_registry( mock_registry(
@ -1816,9 +1876,7 @@ async def test_domain_control_admin(
"""Define a protected service.""" """Define a protected service."""
calls.append(call) calls.append(call)
protected_mock_service = service.verify_domain_control(hass, "test_domain")( protected_mock_service = decorator(hass, "test_domain")(mock_service_log)
mock_service_log
)
hass.services.async_register( hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None "test_domain", "test_service", protected_mock_service, schema=None
@ -1834,8 +1892,23 @@ async def test_domain_control_admin(
assert len(calls) == 1 assert len(calls) == 1
assert (_DEPRECATED_VERIFY_DOMAIN_CONTROL_MESSAGE in caplog.text) == in_caplog
async def test_domain_control_no_user(hass: HomeAssistant) -> None:
@pytest.mark.parametrize(
# Check that with or without hass behaves the same
("decorator", "in_caplog"),
[
(service.verify_domain_control, True), # deprecated with hass
(lambda _, domain: service.verify_domain_control(domain), False),
],
)
async def test_domain_control_no_user(
hass: HomeAssistant,
decorator: Callable[[HomeAssistant, str], Any],
in_caplog: bool,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test domain verification in a service call with no user.""" """Test domain verification in a service call with no user."""
mock_registry( mock_registry(
hass, hass,
@ -1854,9 +1927,7 @@ async def test_domain_control_no_user(hass: HomeAssistant) -> None:
"""Define a protected service.""" """Define a protected service."""
calls.append(call) calls.append(call)
protected_mock_service = service.verify_domain_control(hass, "test_domain")( protected_mock_service = decorator(hass, "test_domain")(mock_service_log)
mock_service_log
)
hass.services.async_register( hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None "test_domain", "test_service", protected_mock_service, schema=None
@ -1872,6 +1943,8 @@ async def test_domain_control_no_user(hass: HomeAssistant) -> None:
assert len(calls) == 1 assert len(calls) == 1
assert (_DEPRECATED_VERIFY_DOMAIN_CONTROL_MESSAGE in caplog.text) == in_caplog
async def test_extract_from_service_available_device(hass: HomeAssistant) -> None: async def test_extract_from_service_available_device(hass: HomeAssistant) -> None:
"""Test the extraction of entity from service and device is available.""" """Test the extraction of entity from service and device is available."""