From a1e68336fcba009ffd205c17d472d41164c7858e Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 10 Sep 2025 12:21:31 +0200 Subject: [PATCH] Add service helper for registering platform services (#151965) Co-authored-by: Martin Hjelmare --- homeassistant/helpers/service.py | 81 ++++++-- tests/helpers/test_service.py | 330 +++++++++++++++++++++++++++++++ 2 files changed, 399 insertions(+), 12 deletions(-) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index a30d5c67cef..70bded4b599 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -760,7 +760,7 @@ def _get_permissible_entity_candidates( @bind_hass async def entity_service_call( hass: HomeAssistant, - registered_entities: dict[str, Entity], + registered_entities: dict[str, Entity] | Callable[[], dict[str, Entity]], func: str | HassJob, call: ServiceCall, required_features: Iterable[int] | None = None, @@ -799,10 +799,15 @@ async def entity_service_call( else: data = call + if callable(registered_entities): + _registered_entities = registered_entities() + else: + _registered_entities = registered_entities + # A list with entities to call the service on. entity_candidates = _get_permissible_entity_candidates( call, - registered_entities, + _registered_entities, entity_perms, target_all_entities, all_referenced, @@ -1112,6 +1117,23 @@ class ReloadServiceHelper[_T]: self._service_condition.notify_all() +def _validate_entity_service_schema( + schema: VolDictType | VolSchemaType | None, +) -> VolSchemaType: + """Validate that a schema is an entity service schema.""" + if schema is None or isinstance(schema, dict): + return cv.make_entity_service_schema(schema) + if not cv.is_entity_service_schema(schema): + from .frame import ReportBehavior, report_usage # noqa: PLC0415 + + report_usage( + "registers an entity service with a non entity service schema", + core_behavior=ReportBehavior.LOG, + breaks_in_ha_version="2025.9", + ) + return schema + + @callback def async_register_entity_service( hass: HomeAssistant, @@ -1131,16 +1153,7 @@ def async_register_entity_service( EntityPlatform.async_register_entity_service and should not be called directly by integrations. """ - if schema is None or isinstance(schema, dict): - schema = cv.make_entity_service_schema(schema) - elif not cv.is_entity_service_schema(schema): - from .frame import ReportBehavior, report_usage # noqa: PLC0415 - - report_usage( - "registers an entity service with a non entity service schema", - core_behavior=ReportBehavior.LOG, - breaks_in_ha_version="2025.9", - ) + schema = _validate_entity_service_schema(schema) service_func: str | HassJob[..., Any] service_func = func if isinstance(func, str) else HassJob(func) @@ -1159,3 +1172,47 @@ def async_register_entity_service( supports_response, job_type=job_type, ) + + +@callback +def async_register_platform_entity_service( + hass: HomeAssistant, + service_domain: str, + service_name: str, + *, + entity_domain: str, + func: str | Callable[..., Any], + required_features: Iterable[int] | None = None, + schema: VolDictType | VolSchemaType | None, + supports_response: SupportsResponse = SupportsResponse.NONE, +) -> None: + """Help registering a platform entity service.""" + from .entity_platform import DATA_DOMAIN_PLATFORM_ENTITIES # noqa: PLC0415 + + schema = _validate_entity_service_schema(schema) + + service_func: str | HassJob[..., Any] + service_func = func if isinstance(func, str) else HassJob(func) + + def get_entities() -> dict[str, Entity]: + entities = hass.data.get(DATA_DOMAIN_PLATFORM_ENTITIES, {}).get( + (entity_domain, service_domain) + ) + if entities is None: + return {} + return entities + + hass.services.async_register( + service_domain, + service_name, + partial( + entity_service_call, + hass, + get_entities, + service_func, + required_features=required_features, + ), + schema, + supports_response, + job_type=HassJobType.Coroutinefunction, + ) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index d41e46beba5..2af35fa95ec 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -5,6 +5,7 @@ from collections.abc import Iterable from copy import deepcopy import dataclasses import io +import logging from typing import Any from unittest.mock import AsyncMock, Mock, patch @@ -36,6 +37,7 @@ from homeassistant.core import ( ServiceCall, ServiceResponse, SupportsResponse, + callback, ) from homeassistant.helpers import ( area_registry as ar, @@ -55,6 +57,7 @@ from homeassistant.util.yaml.loader import parse_yaml from tests.common import ( MockEntity, + MockEntityPlatform, MockModule, MockUser, RegistryEntryWithDefaults, @@ -2461,3 +2464,330 @@ async def test_deprecated_async_extract_referenced_entity_ids( assert args[0][2] is False assert dataclasses.asdict(result) == dataclasses.asdict(mock_selected) + + +async def test_register_platform_entity_service( + hass: HomeAssistant, +) -> None: + """Test registering a platform entity service.""" + entities = [] + + @callback + def handle_service(entity, *_): + entities.append(entity) + + service.async_register_platform_entity_service( + hass, + "mock_platform", + "hello", + entity_domain="mock_integration", + schema={}, + func=handle_service, + ) + + await hass.services.async_call( + "mock_platform", "hello", {"entity_id": "all"}, blocking=True + ) + assert entities == [] + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity(entity_id="mock_integration.entity1") + entity2 = MockEntity(entity_id="mock_integration.entity2") + await entity_platform.async_add_entities([entity1, entity2]) + + await hass.services.async_call( + "mock_platform", "hello", {"entity_id": "all"}, blocking=True + ) + + assert entities == unordered([entity1, entity2]) + + +async def test_register_platform_entity_service_response_data( + hass: HomeAssistant, +) -> None: + """Test an entity service that supports response data.""" + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + return {"response-key": "response-value"} + + service.async_register_platform_entity_service( + hass, + "mock_platform", + "hello", + entity_domain="mock_integration", + schema={"some": str}, + func=generate_response, + supports_response=SupportsResponse.ONLY, + ) + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity = MockEntity(entity_id="mock_integration.entity") + await entity_platform.async_add_entities([entity]) + + response_data = await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity.entity_id]}, + blocking=True, + return_response=True, + ) + assert response_data == { + "mock_integration.entity": {"response-key": "response-value"} + } + + +async def test_register_platform_entity_service_response_data_multiple_matches( + hass: HomeAssistant, +) -> None: + """Test an entity service with response data and matching many entities.""" + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + return {"response-key": f"response-value-{target.entity_id}"} + + service.async_register_platform_entity_service( + hass, + "mock_platform", + "hello", + entity_domain="mock_integration", + schema={"some": str}, + func=generate_response, + supports_response=SupportsResponse.ONLY, + ) + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity(entity_id="mock_integration.entity1") + entity2 = MockEntity(entity_id="mock_integration.entity2") + await entity_platform.async_add_entities([entity1, entity2]) + + response_data = await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity1.entity_id, entity2.entity_id]}, + blocking=True, + return_response=True, + ) + assert response_data == { + "mock_integration.entity1": { + "response-key": "response-value-mock_integration.entity1" + }, + "mock_integration.entity2": { + "response-key": "response-value-mock_integration.entity2" + }, + } + + +async def test_register_platform_entity_service_response_data_multiple_matches_raises( + hass: HomeAssistant, +) -> None: + """Test entity service response matching many entities raises.""" + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + if target.entity_id == "mock_integration.entity1": + raise RuntimeError("Something went wrong") + return {"response-key": f"response-value-{target.entity_id}"} + + service.async_register_platform_entity_service( + hass, + "mock_platform", + "hello", + entity_domain="mock_integration", + schema={"some": str}, + func=generate_response, + supports_response=SupportsResponse.ONLY, + ) + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity(entity_id="mock_integration.entity1") + entity2 = MockEntity(entity_id="mock_integration.entity2") + await entity_platform.async_add_entities([entity1, entity2]) + + with pytest.raises(RuntimeError, match="Something went wrong"): + await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"entity_id": [entity1.entity_id, entity2.entity_id]}, + blocking=True, + return_response=True, + ) + + +async def test_register_platform_entity_service_limited_to_matching_platforms( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + area_registry: ar.AreaRegistry, +) -> None: + """Test entity services only target entities for the platform and domain.""" + + mock_area = area_registry.async_get_or_create("mock_area") + + entity1_entry = entity_registry.async_get_or_create( + "base_platform", "mock_platform", "1234", suggested_object_id="entity1" + ) + entity_registry.async_update_entity(entity1_entry.entity_id, area_id=mock_area.id) + entity2_entry = entity_registry.async_get_or_create( + "base_platform", "mock_platform", "5678", suggested_object_id="entity2" + ) + entity_registry.async_update_entity(entity2_entry.entity_id, area_id=mock_area.id) + entity3_entry = entity_registry.async_get_or_create( + "base_platform", "other_mock_platform", "7891", suggested_object_id="entity3" + ) + entity_registry.async_update_entity(entity3_entry.entity_id, area_id=mock_area.id) + entity4_entry = entity_registry.async_get_or_create( + "base_platform", "other_mock_platform", "1433", suggested_object_id="entity4" + ) + entity_registry.async_update_entity(entity4_entry.entity_id, area_id=mock_area.id) + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + return {"response-key": f"response-value-{target.entity_id}"} + + service.async_register_platform_entity_service( + hass, + "mock_platform", + "hello", + entity_domain="base_platform", + schema={"some": str}, + func=generate_response, + supports_response=SupportsResponse.ONLY, + ) + + entity_platform = MockEntityPlatform( + hass, domain="base_platform", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity( + entity_id=entity1_entry.entity_id, unique_id=entity1_entry.unique_id + ) + entity2 = MockEntity( + entity_id=entity2_entry.entity_id, unique_id=entity2_entry.unique_id + ) + await entity_platform.async_add_entities([entity1, entity2]) + + other_entity_platform = MockEntityPlatform( + hass, domain="base_platform", platform_name="other_mock_platform", platform=None + ) + entity3 = MockEntity( + entity_id=entity3_entry.entity_id, unique_id=entity3_entry.unique_id + ) + entity4 = MockEntity( + entity_id=entity4_entry.entity_id, unique_id=entity4_entry.unique_id + ) + await other_entity_platform.async_add_entities([entity3, entity4]) + + response_data = await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"area_id": [mock_area.id]}, + blocking=True, + return_response=True, + ) + # We should not target entity3 and entity4 even though they are in the area + # because they are only part of the domain and not the platform + assert response_data == { + "base_platform.entity1": { + "response-key": "response-value-base_platform.entity1" + }, + "base_platform.entity2": { + "response-key": "response-value-base_platform.entity2" + }, + } + + +async def test_register_platform_entity_service_none_schema( + hass: HomeAssistant, +) -> None: + """Test registering a service with schema set to None.""" + entities = [] + + @callback + def handle_service(entity, *_): + entities.append(entity) + + service.async_register_platform_entity_service( + hass, + "mock_platform", + "hello", + entity_domain="mock_integration", + schema=None, + func=handle_service, + ) + + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity(name="entity_1") + entity2 = MockEntity(name="entity_1") + await entity_platform.async_add_entities([entity1, entity2]) + + await hass.services.async_call( + "mock_platform", "hello", {"entity_id": "all"}, blocking=True + ) + + assert len(entities) == 2 + assert entity1 in entities + assert entity2 in entities + + +async def test_register_platform_entity_service_non_entity_service_schema( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test attempting to register a service with a non entity service schema.""" + expected_message = "registers an entity service with a non entity service schema" + + for idx, schema in enumerate( + ( + vol.Schema({"some": str}), + vol.All(vol.Schema({"some": str})), + vol.Any(vol.Schema({"some": str})), + ) + ): + service.async_register_platform_entity_service( + hass, + "mock_platform", + f"hello_{idx}", + entity_domain="mock_integration", + schema=schema, + func=Mock(), + ) + assert expected_message in caplog.text + caplog.clear() + + for idx, schema in enumerate( + ( + cv.make_entity_service_schema({"some": str}), + vol.Schema(cv.make_entity_service_schema({"some": str})), + vol.All(cv.make_entity_service_schema({"some": str})), + ) + ): + service.async_register_platform_entity_service( + hass, + "mock_platform", + f"test_service_{idx}", + entity_domain="mock_integration", + schema=schema, + func=Mock(), + ) + assert expected_message not in caplog.text + assert not any(x.levelno > logging.DEBUG for x in caplog.records)