Add service helper for registering platform services (#151965)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery
2025-09-10 12:21:31 +02:00
committed by GitHub
parent 07392e3ff7
commit a1e68336fc
2 changed files with 399 additions and 12 deletions

View File

@@ -760,7 +760,7 @@ def _get_permissible_entity_candidates(
@bind_hass @bind_hass
async def entity_service_call( async def entity_service_call(
hass: HomeAssistant, hass: HomeAssistant,
registered_entities: dict[str, Entity], registered_entities: dict[str, Entity] | Callable[[], dict[str, Entity]],
func: str | HassJob, func: str | HassJob,
call: ServiceCall, call: ServiceCall,
required_features: Iterable[int] | None = None, required_features: Iterable[int] | None = None,
@@ -799,10 +799,15 @@ async def entity_service_call(
else: else:
data = call data = call
if callable(registered_entities):
_registered_entities = registered_entities()
else:
_registered_entities = registered_entities
# A list with entities to call the service on. # A list with entities to call the service on.
entity_candidates = _get_permissible_entity_candidates( entity_candidates = _get_permissible_entity_candidates(
call, call,
registered_entities, _registered_entities,
entity_perms, entity_perms,
target_all_entities, target_all_entities,
all_referenced, all_referenced,
@@ -1112,6 +1117,23 @@ class ReloadServiceHelper[_T]:
self._service_condition.notify_all() self._service_condition.notify_all()
def _validate_entity_service_schema(
schema: VolDictType | VolSchemaType | None,
) -> VolSchemaType:
"""Validate that a schema is an entity service schema."""
if schema is None or isinstance(schema, dict):
return cv.make_entity_service_schema(schema)
if not cv.is_entity_service_schema(schema):
from .frame import ReportBehavior, report_usage # noqa: PLC0415
report_usage(
"registers an entity service with a non entity service schema",
core_behavior=ReportBehavior.LOG,
breaks_in_ha_version="2025.9",
)
return schema
@callback @callback
def async_register_entity_service( def async_register_entity_service(
hass: HomeAssistant, hass: HomeAssistant,
@@ -1131,16 +1153,7 @@ def async_register_entity_service(
EntityPlatform.async_register_entity_service and should not be called EntityPlatform.async_register_entity_service and should not be called
directly by integrations. directly by integrations.
""" """
if schema is None or isinstance(schema, dict): schema = _validate_entity_service_schema(schema)
schema = cv.make_entity_service_schema(schema)
elif not cv.is_entity_service_schema(schema):
from .frame import ReportBehavior, report_usage # noqa: PLC0415
report_usage(
"registers an entity service with a non entity service schema",
core_behavior=ReportBehavior.LOG,
breaks_in_ha_version="2025.9",
)
service_func: str | HassJob[..., Any] service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func) service_func = func if isinstance(func, str) else HassJob(func)
@@ -1159,3 +1172,47 @@ def async_register_entity_service(
supports_response, supports_response,
job_type=job_type, job_type=job_type,
) )
@callback
def async_register_platform_entity_service(
hass: HomeAssistant,
service_domain: str,
service_name: str,
*,
entity_domain: str,
func: str | Callable[..., Any],
required_features: Iterable[int] | None = None,
schema: VolDictType | VolSchemaType | None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Help registering a platform entity service."""
from .entity_platform import DATA_DOMAIN_PLATFORM_ENTITIES # noqa: PLC0415
schema = _validate_entity_service_schema(schema)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
def get_entities() -> dict[str, Entity]:
entities = hass.data.get(DATA_DOMAIN_PLATFORM_ENTITIES, {}).get(
(entity_domain, service_domain)
)
if entities is None:
return {}
return entities
hass.services.async_register(
service_domain,
service_name,
partial(
entity_service_call,
hass,
get_entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
job_type=HassJobType.Coroutinefunction,
)

View File

@@ -5,6 +5,7 @@ from collections.abc import Iterable
from copy import deepcopy from copy import deepcopy
import dataclasses import dataclasses
import io import io
import logging
from typing import Any from typing import Any
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
@@ -36,6 +37,7 @@ from homeassistant.core import (
ServiceCall, ServiceCall,
ServiceResponse, ServiceResponse,
SupportsResponse, SupportsResponse,
callback,
) )
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
@@ -55,6 +57,7 @@ from homeassistant.util.yaml.loader import parse_yaml
from tests.common import ( from tests.common import (
MockEntity, MockEntity,
MockEntityPlatform,
MockModule, MockModule,
MockUser, MockUser,
RegistryEntryWithDefaults, RegistryEntryWithDefaults,
@@ -2461,3 +2464,330 @@ async def test_deprecated_async_extract_referenced_entity_ids(
assert args[0][2] is False assert args[0][2] is False
assert dataclasses.asdict(result) == dataclasses.asdict(mock_selected) assert dataclasses.asdict(result) == dataclasses.asdict(mock_selected)
async def test_register_platform_entity_service(
hass: HomeAssistant,
) -> None:
"""Test registering a platform entity service."""
entities = []
@callback
def handle_service(entity, *_):
entities.append(entity)
service.async_register_platform_entity_service(
hass,
"mock_platform",
"hello",
entity_domain="mock_integration",
schema={},
func=handle_service,
)
await hass.services.async_call(
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
)
assert entities == []
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(entity_id="mock_integration.entity1")
entity2 = MockEntity(entity_id="mock_integration.entity2")
await entity_platform.async_add_entities([entity1, entity2])
await hass.services.async_call(
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
)
assert entities == unordered([entity1, entity2])
async def test_register_platform_entity_service_response_data(
hass: HomeAssistant,
) -> None:
"""Test an entity service that supports response data."""
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
assert call.return_response
return {"response-key": "response-value"}
service.async_register_platform_entity_service(
hass,
"mock_platform",
"hello",
entity_domain="mock_integration",
schema={"some": str},
func=generate_response,
supports_response=SupportsResponse.ONLY,
)
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity = MockEntity(entity_id="mock_integration.entity")
await entity_platform.async_add_entities([entity])
response_data = await hass.services.async_call(
"mock_platform",
"hello",
service_data={"some": "data"},
target={"entity_id": [entity.entity_id]},
blocking=True,
return_response=True,
)
assert response_data == {
"mock_integration.entity": {"response-key": "response-value"}
}
async def test_register_platform_entity_service_response_data_multiple_matches(
hass: HomeAssistant,
) -> None:
"""Test an entity service with response data and matching many entities."""
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
assert call.return_response
return {"response-key": f"response-value-{target.entity_id}"}
service.async_register_platform_entity_service(
hass,
"mock_platform",
"hello",
entity_domain="mock_integration",
schema={"some": str},
func=generate_response,
supports_response=SupportsResponse.ONLY,
)
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(entity_id="mock_integration.entity1")
entity2 = MockEntity(entity_id="mock_integration.entity2")
await entity_platform.async_add_entities([entity1, entity2])
response_data = await hass.services.async_call(
"mock_platform",
"hello",
service_data={"some": "data"},
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
blocking=True,
return_response=True,
)
assert response_data == {
"mock_integration.entity1": {
"response-key": "response-value-mock_integration.entity1"
},
"mock_integration.entity2": {
"response-key": "response-value-mock_integration.entity2"
},
}
async def test_register_platform_entity_service_response_data_multiple_matches_raises(
hass: HomeAssistant,
) -> None:
"""Test entity service response matching many entities raises."""
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
assert call.return_response
if target.entity_id == "mock_integration.entity1":
raise RuntimeError("Something went wrong")
return {"response-key": f"response-value-{target.entity_id}"}
service.async_register_platform_entity_service(
hass,
"mock_platform",
"hello",
entity_domain="mock_integration",
schema={"some": str},
func=generate_response,
supports_response=SupportsResponse.ONLY,
)
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(entity_id="mock_integration.entity1")
entity2 = MockEntity(entity_id="mock_integration.entity2")
await entity_platform.async_add_entities([entity1, entity2])
with pytest.raises(RuntimeError, match="Something went wrong"):
await hass.services.async_call(
"mock_platform",
"hello",
service_data={"some": "data"},
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
blocking=True,
return_response=True,
)
async def test_register_platform_entity_service_limited_to_matching_platforms(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
area_registry: ar.AreaRegistry,
) -> None:
"""Test entity services only target entities for the platform and domain."""
mock_area = area_registry.async_get_or_create("mock_area")
entity1_entry = entity_registry.async_get_or_create(
"base_platform", "mock_platform", "1234", suggested_object_id="entity1"
)
entity_registry.async_update_entity(entity1_entry.entity_id, area_id=mock_area.id)
entity2_entry = entity_registry.async_get_or_create(
"base_platform", "mock_platform", "5678", suggested_object_id="entity2"
)
entity_registry.async_update_entity(entity2_entry.entity_id, area_id=mock_area.id)
entity3_entry = entity_registry.async_get_or_create(
"base_platform", "other_mock_platform", "7891", suggested_object_id="entity3"
)
entity_registry.async_update_entity(entity3_entry.entity_id, area_id=mock_area.id)
entity4_entry = entity_registry.async_get_or_create(
"base_platform", "other_mock_platform", "1433", suggested_object_id="entity4"
)
entity_registry.async_update_entity(entity4_entry.entity_id, area_id=mock_area.id)
async def generate_response(
target: MockEntity, call: ServiceCall
) -> ServiceResponse:
assert call.return_response
return {"response-key": f"response-value-{target.entity_id}"}
service.async_register_platform_entity_service(
hass,
"mock_platform",
"hello",
entity_domain="base_platform",
schema={"some": str},
func=generate_response,
supports_response=SupportsResponse.ONLY,
)
entity_platform = MockEntityPlatform(
hass, domain="base_platform", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(
entity_id=entity1_entry.entity_id, unique_id=entity1_entry.unique_id
)
entity2 = MockEntity(
entity_id=entity2_entry.entity_id, unique_id=entity2_entry.unique_id
)
await entity_platform.async_add_entities([entity1, entity2])
other_entity_platform = MockEntityPlatform(
hass, domain="base_platform", platform_name="other_mock_platform", platform=None
)
entity3 = MockEntity(
entity_id=entity3_entry.entity_id, unique_id=entity3_entry.unique_id
)
entity4 = MockEntity(
entity_id=entity4_entry.entity_id, unique_id=entity4_entry.unique_id
)
await other_entity_platform.async_add_entities([entity3, entity4])
response_data = await hass.services.async_call(
"mock_platform",
"hello",
service_data={"some": "data"},
target={"area_id": [mock_area.id]},
blocking=True,
return_response=True,
)
# We should not target entity3 and entity4 even though they are in the area
# because they are only part of the domain and not the platform
assert response_data == {
"base_platform.entity1": {
"response-key": "response-value-base_platform.entity1"
},
"base_platform.entity2": {
"response-key": "response-value-base_platform.entity2"
},
}
async def test_register_platform_entity_service_none_schema(
hass: HomeAssistant,
) -> None:
"""Test registering a service with schema set to None."""
entities = []
@callback
def handle_service(entity, *_):
entities.append(entity)
service.async_register_platform_entity_service(
hass,
"mock_platform",
"hello",
entity_domain="mock_integration",
schema=None,
func=handle_service,
)
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(name="entity_1")
entity2 = MockEntity(name="entity_1")
await entity_platform.async_add_entities([entity1, entity2])
await hass.services.async_call(
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
)
assert len(entities) == 2
assert entity1 in entities
assert entity2 in entities
async def test_register_platform_entity_service_non_entity_service_schema(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test attempting to register a service with a non entity service schema."""
expected_message = "registers an entity service with a non entity service schema"
for idx, schema in enumerate(
(
vol.Schema({"some": str}),
vol.All(vol.Schema({"some": str})),
vol.Any(vol.Schema({"some": str})),
)
):
service.async_register_platform_entity_service(
hass,
"mock_platform",
f"hello_{idx}",
entity_domain="mock_integration",
schema=schema,
func=Mock(),
)
assert expected_message in caplog.text
caplog.clear()
for idx, schema in enumerate(
(
cv.make_entity_service_schema({"some": str}),
vol.Schema(cv.make_entity_service_schema({"some": str})),
vol.All(cv.make_entity_service_schema({"some": str})),
)
):
service.async_register_platform_entity_service(
hass,
"mock_platform",
f"test_service_{idx}",
entity_domain="mock_integration",
schema=schema,
func=Mock(),
)
assert expected_message not in caplog.text
assert not any(x.levelno > logging.DEBUG for x in caplog.records)