diff --git a/homeassistant/components/evohome/__init__.py b/homeassistant/components/evohome/__init__.py index 9dce352df30..bfe8b128d2c 100644 --- a/homeassistant/components/evohome/__init__.py +++ b/homeassistant/components/evohome/__init__.py @@ -36,7 +36,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.discovery import async_load_platform from homeassistant.helpers.dispatcher import async_dispatcher_send -from homeassistant.helpers.service import verify_domain_control +from homeassistant.helpers.service import verify_domain_entity_control from homeassistant.helpers.typing import ConfigType from homeassistant.util.hass_dict import HassKey @@ -162,12 +162,12 @@ def setup_service_functions( It appears that all TCC-compatible systems support the same three zones modes. """ - @verify_domain_control(hass, DOMAIN) + @verify_domain_entity_control(DOMAIN) async def force_refresh(call: ServiceCall) -> None: """Obtain the latest state data via the vendor's RESTful API.""" await coordinator.async_refresh() - @verify_domain_control(hass, DOMAIN) + @verify_domain_entity_control(DOMAIN) async def set_system_mode(call: ServiceCall) -> None: """Set the system mode.""" assert coordinator.tcs is not None # mypy @@ -179,7 +179,7 @@ def setup_service_functions( } async_dispatcher_send(hass, DOMAIN, payload) - @verify_domain_control(hass, DOMAIN) + @verify_domain_entity_control(DOMAIN) async def set_zone_override(call: ServiceCall) -> None: """Set the zone override (setpoint).""" entity_id = call.data[ATTR_ENTITY_ID] diff --git a/homeassistant/components/geniushub/__init__.py b/homeassistant/components/geniushub/__init__.py index 9ca6ecfcfe0..a20fe62d113 100644 --- a/homeassistant/components/geniushub/__init__.py +++ b/homeassistant/components/geniushub/__init__.py @@ -25,7 +25,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_track_time_interval -from homeassistant.helpers.service import verify_domain_control +from homeassistant.helpers.service import verify_domain_entity_control from .const import DOMAIN @@ -124,7 +124,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: GeniusHubConfigEntry) -> def setup_service_functions(hass: HomeAssistant, broker): """Set up the service functions.""" - @verify_domain_control(hass, DOMAIN) + @verify_domain_entity_control(DOMAIN) async def set_zone_mode(call: ServiceCall) -> None: """Set the system mode.""" entity_id = call.data[ATTR_ENTITY_ID] diff --git a/homeassistant/components/homematicip_cloud/services.py b/homeassistant/components/homematicip_cloud/services.py index 2e76a0b7aac..a49d41a54a1 100644 --- a/homeassistant/components/homematicip_cloud/services.py +++ b/homeassistant/components/homematicip_cloud/services.py @@ -18,7 +18,7 @@ from homeassistant.helpers import config_validation as cv from homeassistant.helpers.config_validation import comp_entity_ids from homeassistant.helpers.service import ( async_register_admin_service, - verify_domain_control, + verify_domain_entity_control, ) from .const import DOMAIN @@ -126,7 +126,7 @@ async def async_setup_services(hass: HomeAssistant) -> None: if hass.services.async_services_for_domain(DOMAIN): return - @verify_domain_control(hass, DOMAIN) + @verify_domain_entity_control(DOMAIN) async def async_call_hmipc_service(service: ServiceCall) -> None: """Call correct HomematicIP Cloud service.""" service_name = service.service diff --git a/homeassistant/components/hue/services.py b/homeassistant/components/hue/services.py index 18dd19e3391..3981952f16f 100644 --- a/homeassistant/components/hue/services.py +++ b/homeassistant/components/hue/services.py @@ -10,7 +10,7 @@ import voluptuous as vol from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.service import verify_domain_control +from homeassistant.helpers.service import verify_domain_entity_control from .bridge import HueBridge, HueConfigEntry from .const import ( @@ -64,7 +64,7 @@ def async_register_services(hass: HomeAssistant) -> None: hass.services.async_register( DOMAIN, SERVICE_HUE_ACTIVATE_SCENE, - verify_domain_control(hass, DOMAIN)(hue_activate_scene), + verify_domain_entity_control(DOMAIN)(hue_activate_scene), schema=vol.Schema( { vol.Required(ATTR_GROUP_NAME): cv.string, diff --git a/homeassistant/components/monoprice/media_player.py b/homeassistant/components/monoprice/media_player.py index 9d678c16874..b5391014005 100644 --- a/homeassistant/components/monoprice/media_player.py +++ b/homeassistant/components/monoprice/media_player.py @@ -89,7 +89,7 @@ async def async_setup_entry( elif service_call.service == SERVICE_RESTORE: entity.restore() - @service.verify_domain_control(hass, DOMAIN) + @service.verify_domain_entity_control(DOMAIN) async def async_service_handle(service_call: core.ServiceCall) -> None: """Handle for services.""" entities = await platform.async_extract_from_service(service_call) diff --git a/homeassistant/components/simplisafe/__init__.py b/homeassistant/components/simplisafe/__init__.py index 8a75baa69c6..cabfc57d967 100644 --- a/homeassistant/components/simplisafe/__init__.py +++ b/homeassistant/components/simplisafe/__init__.py @@ -63,7 +63,7 @@ from homeassistant.helpers import ( from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.service import ( async_register_admin_service, - verify_domain_control, + verify_domain_entity_control, ) from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed @@ -290,7 +290,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up SimpliSafe as config entry.""" _async_standardize_config_entry(hass, entry) - _verify_domain_control = verify_domain_control(hass, DOMAIN) + _verify_domain_control = verify_domain_entity_control(DOMAIN) websession = aiohttp_client.async_get_clientsession(hass) try: diff --git a/homeassistant/components/sonos/media_player.py b/homeassistant/components/sonos/media_player.py index f1f95659469..b3a46cec239 100644 --- a/homeassistant/components/sonos/media_player.py +++ b/homeassistant/components/sonos/media_player.py @@ -120,7 +120,7 @@ async def async_setup_entry( _LOGGER.debug("Creating media_player on %s", speaker.zone_name) async_add_entities([SonosMediaPlayerEntity(speaker)]) - @service.verify_domain_control(hass, DOMAIN) + @service.verify_domain_entity_control(DOMAIN) async def async_service_handle(service_call: ServiceCall) -> None: """Handle dispatched services.""" assert platform is not None diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index f157e82bc53..5bd000fd09b 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -1151,6 +1151,18 @@ def async_register_admin_service( @callback def verify_domain_control( hass: HomeAssistant, domain: str +) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]: + """Ensure permission to access any entity under domain in service call. + + The use of this decorator is discouraged, and it should not be used + for new functions - please use `verify_domain_entity_control`. + """ + + return verify_domain_entity_control(domain) + + +def verify_domain_entity_control( + domain: str, ) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]: """Ensure permission to access any entity under domain in service call.""" @@ -1166,7 +1178,7 @@ def verify_domain_control( if not call.context.user_id: return await service_handler(call) - user = await hass.auth.async_get_user(call.context.user_id) + user = await call.hass.auth.async_get_user(call.context.user_id) if user is None: raise UnknownUser( @@ -1175,7 +1187,7 @@ def verify_domain_control( user_id=call.context.user_id, ) - reg = entity_registry.async_get(hass) + reg = entity_registry.async_get(call.hass) authorized = False diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 38e7e1ae452..c46a6905b55 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -1678,6 +1678,7 @@ async def test_register_admin_service_return_response( async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> None: """Test domain verification in a service call with an unknown user.""" + # Note: deprecated - replaced by test_domain_entity_control_not_async calls = [] def mock_service_log(call): @@ -1690,6 +1691,7 @@ async def test_domain_control_not_async(hass: HomeAssistant, mock_entities) -> N async def test_domain_control_unknown(hass: HomeAssistant, mock_entities) -> None: """Test domain verification in a service call with an unknown user.""" + # Note: deprecated - replaced by test_domain_entity_control_unknown calls = [] async def mock_service_log(call): @@ -1723,6 +1725,7 @@ async def test_domain_control_unauthorized( hass: HomeAssistant, hass_read_only_user: MockUser ) -> None: """Test domain verification in a service call with an unauthorized user.""" + # Note: deprecated - replaced by test_domain_entity_control_unauthorized mock_registry( hass, { @@ -1764,6 +1767,7 @@ async def test_domain_control_admin( hass: HomeAssistant, hass_admin_user: MockUser ) -> None: """Test domain verification in a service call with an admin user.""" + # Note: deprecated - replaced by test_domain_entity_control_admin mock_registry( hass, { @@ -1802,6 +1806,7 @@ async def test_domain_control_admin( async def test_domain_control_no_user(hass: HomeAssistant) -> None: """Test domain verification in a service call with no user.""" + # Note: deprecated - replaced by test_domain_entity_control_no_user mock_registry( hass, { @@ -1838,6 +1843,172 @@ async def test_domain_control_no_user(hass: HomeAssistant) -> None: assert len(calls) == 1 +async def test_domain_entity_control_not_async( + hass: HomeAssistant, mock_entities +) -> None: + """Test domain verification in a service call with an unknown user.""" + calls = [] + + def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + with pytest.raises(exceptions.HomeAssistantError): + service.verify_domain_entity_control(hass, "test_domain")(mock_service_log) + + +async def test_domain_entity_control_unknown( + hass: HomeAssistant, mock_entities +) -> None: + """Test domain verification in a service call with an unknown user.""" + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + with patch( + "homeassistant.helpers.entity_registry.async_get", + return_value=Mock(entities=mock_entities), + ): + protected_mock_service = service.verify_domain_entity_control( + hass, "test_domain" + )(mock_service_log) + + hass.services.async_register( + "test_domain", "test_service", protected_mock_service, schema=None + ) + + with pytest.raises(exceptions.UnknownUser): + await hass.services.async_call( + "test_domain", + "test_service", + {}, + blocking=True, + context=Context(user_id="fake_user_id"), + ) + assert len(calls) == 0 + + +async def test_domain_entity_control_unauthorized( + hass: HomeAssistant, hass_read_only_user: MockUser +) -> None: + """Test domain verification in a service call with an unauthorized user.""" + mock_registry( + hass, + { + "light.kitchen": RegistryEntryWithDefaults( + entity_id="light.kitchen", + unique_id="kitchen", + platform="test_domain", + ) + }, + ) + + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + protected_mock_service = service.verify_domain_entity_control(hass, "test_domain")( + mock_service_log + ) + + hass.services.async_register( + "test_domain", "test_service", protected_mock_service, schema=None + ) + + with pytest.raises(exceptions.Unauthorized): + await hass.services.async_call( + "test_domain", + "test_service", + {}, + blocking=True, + context=Context(user_id=hass_read_only_user.id), + ) + + assert len(calls) == 0 + + +async def test_domain_entity_control_admin( + hass: HomeAssistant, hass_admin_user: MockUser +) -> None: + """Test domain verification in a service call with an admin user.""" + mock_registry( + hass, + { + "light.kitchen": RegistryEntryWithDefaults( + entity_id="light.kitchen", + unique_id="kitchen", + platform="test_domain", + ) + }, + ) + + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + protected_mock_service = service.verify_domain_entity_control(hass, "test_domain")( + mock_service_log + ) + + hass.services.async_register( + "test_domain", "test_service", protected_mock_service, schema=None + ) + + await hass.services.async_call( + "test_domain", + "test_service", + {}, + blocking=True, + context=Context(user_id=hass_admin_user.id), + ) + + assert len(calls) == 1 + + +async def test_domain_entity_control_no_user(hass: HomeAssistant) -> None: + """Test domain verification in a service call with no user.""" + mock_registry( + hass, + { + "light.kitchen": RegistryEntryWithDefaults( + entity_id="light.kitchen", + unique_id="kitchen", + platform="test_domain", + ) + }, + ) + + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + protected_mock_service = service.verify_domain_entity_control(hass, "test_domain")( + mock_service_log + ) + + hass.services.async_register( + "test_domain", "test_service", protected_mock_service, schema=None + ) + + await hass.services.async_call( + "test_domain", + "test_service", + {}, + blocking=True, + context=Context(user_id=None), + ) + + assert len(calls) == 1 + + async def test_extract_from_service_available_device(hass: HomeAssistant) -> None: """Test the extraction of entity from service and device is available.""" entities = [