mirror of
https://github.com/home-assistant/core.git
synced 2025-11-10 11:29:46 +00:00
Add service helper for registering platform services (#151965)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
@@ -760,7 +760,7 @@ def _get_permissible_entity_candidates(
|
|||||||
@bind_hass
|
@bind_hass
|
||||||
async def entity_service_call(
|
async def entity_service_call(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
registered_entities: dict[str, Entity],
|
registered_entities: dict[str, Entity] | Callable[[], dict[str, Entity]],
|
||||||
func: str | HassJob,
|
func: str | HassJob,
|
||||||
call: ServiceCall,
|
call: ServiceCall,
|
||||||
required_features: Iterable[int] | None = None,
|
required_features: Iterable[int] | None = None,
|
||||||
@@ -799,10 +799,15 @@ async def entity_service_call(
|
|||||||
else:
|
else:
|
||||||
data = call
|
data = call
|
||||||
|
|
||||||
|
if callable(registered_entities):
|
||||||
|
_registered_entities = registered_entities()
|
||||||
|
else:
|
||||||
|
_registered_entities = registered_entities
|
||||||
|
|
||||||
# A list with entities to call the service on.
|
# A list with entities to call the service on.
|
||||||
entity_candidates = _get_permissible_entity_candidates(
|
entity_candidates = _get_permissible_entity_candidates(
|
||||||
call,
|
call,
|
||||||
registered_entities,
|
_registered_entities,
|
||||||
entity_perms,
|
entity_perms,
|
||||||
target_all_entities,
|
target_all_entities,
|
||||||
all_referenced,
|
all_referenced,
|
||||||
@@ -1112,6 +1117,23 @@ class ReloadServiceHelper[_T]:
|
|||||||
self._service_condition.notify_all()
|
self._service_condition.notify_all()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_entity_service_schema(
|
||||||
|
schema: VolDictType | VolSchemaType | None,
|
||||||
|
) -> VolSchemaType:
|
||||||
|
"""Validate that a schema is an entity service schema."""
|
||||||
|
if schema is None or isinstance(schema, dict):
|
||||||
|
return cv.make_entity_service_schema(schema)
|
||||||
|
if not cv.is_entity_service_schema(schema):
|
||||||
|
from .frame import ReportBehavior, report_usage # noqa: PLC0415
|
||||||
|
|
||||||
|
report_usage(
|
||||||
|
"registers an entity service with a non entity service schema",
|
||||||
|
core_behavior=ReportBehavior.LOG,
|
||||||
|
breaks_in_ha_version="2025.9",
|
||||||
|
)
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_entity_service(
|
def async_register_entity_service(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@@ -1131,16 +1153,7 @@ def async_register_entity_service(
|
|||||||
EntityPlatform.async_register_entity_service and should not be called
|
EntityPlatform.async_register_entity_service and should not be called
|
||||||
directly by integrations.
|
directly by integrations.
|
||||||
"""
|
"""
|
||||||
if schema is None or isinstance(schema, dict):
|
schema = _validate_entity_service_schema(schema)
|
||||||
schema = cv.make_entity_service_schema(schema)
|
|
||||||
elif not cv.is_entity_service_schema(schema):
|
|
||||||
from .frame import ReportBehavior, report_usage # noqa: PLC0415
|
|
||||||
|
|
||||||
report_usage(
|
|
||||||
"registers an entity service with a non entity service schema",
|
|
||||||
core_behavior=ReportBehavior.LOG,
|
|
||||||
breaks_in_ha_version="2025.9",
|
|
||||||
)
|
|
||||||
|
|
||||||
service_func: str | HassJob[..., Any]
|
service_func: str | HassJob[..., Any]
|
||||||
service_func = func if isinstance(func, str) else HassJob(func)
|
service_func = func if isinstance(func, str) else HassJob(func)
|
||||||
@@ -1159,3 +1172,47 @@ def async_register_entity_service(
|
|||||||
supports_response,
|
supports_response,
|
||||||
job_type=job_type,
|
job_type=job_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_register_platform_entity_service(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
service_domain: str,
|
||||||
|
service_name: str,
|
||||||
|
*,
|
||||||
|
entity_domain: str,
|
||||||
|
func: str | Callable[..., Any],
|
||||||
|
required_features: Iterable[int] | None = None,
|
||||||
|
schema: VolDictType | VolSchemaType | None,
|
||||||
|
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||||
|
) -> None:
|
||||||
|
"""Help registering a platform entity service."""
|
||||||
|
from .entity_platform import DATA_DOMAIN_PLATFORM_ENTITIES # noqa: PLC0415
|
||||||
|
|
||||||
|
schema = _validate_entity_service_schema(schema)
|
||||||
|
|
||||||
|
service_func: str | HassJob[..., Any]
|
||||||
|
service_func = func if isinstance(func, str) else HassJob(func)
|
||||||
|
|
||||||
|
def get_entities() -> dict[str, Entity]:
|
||||||
|
entities = hass.data.get(DATA_DOMAIN_PLATFORM_ENTITIES, {}).get(
|
||||||
|
(entity_domain, service_domain)
|
||||||
|
)
|
||||||
|
if entities is None:
|
||||||
|
return {}
|
||||||
|
return entities
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
service_domain,
|
||||||
|
service_name,
|
||||||
|
partial(
|
||||||
|
entity_service_call,
|
||||||
|
hass,
|
||||||
|
get_entities,
|
||||||
|
service_func,
|
||||||
|
required_features=required_features,
|
||||||
|
),
|
||||||
|
schema,
|
||||||
|
supports_response,
|
||||||
|
job_type=HassJobType.Coroutinefunction,
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from collections.abc import Iterable
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
@@ -36,6 +37,7 @@ from homeassistant.core import (
|
|||||||
ServiceCall,
|
ServiceCall,
|
||||||
ServiceResponse,
|
ServiceResponse,
|
||||||
SupportsResponse,
|
SupportsResponse,
|
||||||
|
callback,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
@@ -55,6 +57,7 @@ from homeassistant.util.yaml.loader import parse_yaml
|
|||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockEntity,
|
MockEntity,
|
||||||
|
MockEntityPlatform,
|
||||||
MockModule,
|
MockModule,
|
||||||
MockUser,
|
MockUser,
|
||||||
RegistryEntryWithDefaults,
|
RegistryEntryWithDefaults,
|
||||||
@@ -2461,3 +2464,330 @@ async def test_deprecated_async_extract_referenced_entity_ids(
|
|||||||
assert args[0][2] is False
|
assert args[0][2] is False
|
||||||
|
|
||||||
assert dataclasses.asdict(result) == dataclasses.asdict(mock_selected)
|
assert dataclasses.asdict(result) == dataclasses.asdict(mock_selected)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_platform_entity_service(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test registering a platform entity service."""
|
||||||
|
entities = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def handle_service(entity, *_):
|
||||||
|
entities.append(entity)
|
||||||
|
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
entity_domain="mock_integration",
|
||||||
|
schema={},
|
||||||
|
func=handle_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
|
||||||
|
)
|
||||||
|
assert entities == []
|
||||||
|
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = MockEntity(entity_id="mock_integration.entity1")
|
||||||
|
entity2 = MockEntity(entity_id="mock_integration.entity2")
|
||||||
|
await entity_platform.async_add_entities([entity1, entity2])
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entities == unordered([entity1, entity2])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_platform_entity_service_response_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test an entity service that supports response data."""
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
target: MockEntity, call: ServiceCall
|
||||||
|
) -> ServiceResponse:
|
||||||
|
assert call.return_response
|
||||||
|
return {"response-key": "response-value"}
|
||||||
|
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
entity_domain="mock_integration",
|
||||||
|
schema={"some": str},
|
||||||
|
func=generate_response,
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity = MockEntity(entity_id="mock_integration.entity")
|
||||||
|
await entity_platform.async_add_entities([entity])
|
||||||
|
|
||||||
|
response_data = await hass.services.async_call(
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
service_data={"some": "data"},
|
||||||
|
target={"entity_id": [entity.entity_id]},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
assert response_data == {
|
||||||
|
"mock_integration.entity": {"response-key": "response-value"}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_platform_entity_service_response_data_multiple_matches(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test an entity service with response data and matching many entities."""
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
target: MockEntity, call: ServiceCall
|
||||||
|
) -> ServiceResponse:
|
||||||
|
assert call.return_response
|
||||||
|
return {"response-key": f"response-value-{target.entity_id}"}
|
||||||
|
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
entity_domain="mock_integration",
|
||||||
|
schema={"some": str},
|
||||||
|
func=generate_response,
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = MockEntity(entity_id="mock_integration.entity1")
|
||||||
|
entity2 = MockEntity(entity_id="mock_integration.entity2")
|
||||||
|
await entity_platform.async_add_entities([entity1, entity2])
|
||||||
|
|
||||||
|
response_data = await hass.services.async_call(
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
service_data={"some": "data"},
|
||||||
|
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
assert response_data == {
|
||||||
|
"mock_integration.entity1": {
|
||||||
|
"response-key": "response-value-mock_integration.entity1"
|
||||||
|
},
|
||||||
|
"mock_integration.entity2": {
|
||||||
|
"response-key": "response-value-mock_integration.entity2"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_platform_entity_service_response_data_multiple_matches_raises(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test entity service response matching many entities raises."""
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
target: MockEntity, call: ServiceCall
|
||||||
|
) -> ServiceResponse:
|
||||||
|
assert call.return_response
|
||||||
|
if target.entity_id == "mock_integration.entity1":
|
||||||
|
raise RuntimeError("Something went wrong")
|
||||||
|
return {"response-key": f"response-value-{target.entity_id}"}
|
||||||
|
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
entity_domain="mock_integration",
|
||||||
|
schema={"some": str},
|
||||||
|
func=generate_response,
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = MockEntity(entity_id="mock_integration.entity1")
|
||||||
|
entity2 = MockEntity(entity_id="mock_integration.entity2")
|
||||||
|
await entity_platform.async_add_entities([entity1, entity2])
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Something went wrong"):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
service_data={"some": "data"},
|
||||||
|
target={"entity_id": [entity1.entity_id, entity2.entity_id]},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_platform_entity_service_limited_to_matching_platforms(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test entity services only target entities for the platform and domain."""
|
||||||
|
|
||||||
|
mock_area = area_registry.async_get_or_create("mock_area")
|
||||||
|
|
||||||
|
entity1_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "mock_platform", "1234", suggested_object_id="entity1"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity1_entry.entity_id, area_id=mock_area.id)
|
||||||
|
entity2_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "mock_platform", "5678", suggested_object_id="entity2"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity2_entry.entity_id, area_id=mock_area.id)
|
||||||
|
entity3_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "other_mock_platform", "7891", suggested_object_id="entity3"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity3_entry.entity_id, area_id=mock_area.id)
|
||||||
|
entity4_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "other_mock_platform", "1433", suggested_object_id="entity4"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity4_entry.entity_id, area_id=mock_area.id)
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
target: MockEntity, call: ServiceCall
|
||||||
|
) -> ServiceResponse:
|
||||||
|
assert call.return_response
|
||||||
|
return {"response-key": f"response-value-{target.entity_id}"}
|
||||||
|
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
entity_domain="base_platform",
|
||||||
|
schema={"some": str},
|
||||||
|
func=generate_response,
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="base_platform", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = MockEntity(
|
||||||
|
entity_id=entity1_entry.entity_id, unique_id=entity1_entry.unique_id
|
||||||
|
)
|
||||||
|
entity2 = MockEntity(
|
||||||
|
entity_id=entity2_entry.entity_id, unique_id=entity2_entry.unique_id
|
||||||
|
)
|
||||||
|
await entity_platform.async_add_entities([entity1, entity2])
|
||||||
|
|
||||||
|
other_entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="base_platform", platform_name="other_mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity3 = MockEntity(
|
||||||
|
entity_id=entity3_entry.entity_id, unique_id=entity3_entry.unique_id
|
||||||
|
)
|
||||||
|
entity4 = MockEntity(
|
||||||
|
entity_id=entity4_entry.entity_id, unique_id=entity4_entry.unique_id
|
||||||
|
)
|
||||||
|
await other_entity_platform.async_add_entities([entity3, entity4])
|
||||||
|
|
||||||
|
response_data = await hass.services.async_call(
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
service_data={"some": "data"},
|
||||||
|
target={"area_id": [mock_area.id]},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
# We should not target entity3 and entity4 even though they are in the area
|
||||||
|
# because they are only part of the domain and not the platform
|
||||||
|
assert response_data == {
|
||||||
|
"base_platform.entity1": {
|
||||||
|
"response-key": "response-value-base_platform.entity1"
|
||||||
|
},
|
||||||
|
"base_platform.entity2": {
|
||||||
|
"response-key": "response-value-base_platform.entity2"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_platform_entity_service_none_schema(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test registering a service with schema set to None."""
|
||||||
|
entities = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def handle_service(entity, *_):
|
||||||
|
entities.append(entity)
|
||||||
|
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
entity_domain="mock_integration",
|
||||||
|
schema=None,
|
||||||
|
func=handle_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = MockEntity(name="entity_1")
|
||||||
|
entity2 = MockEntity(name="entity_1")
|
||||||
|
await entity_platform.async_add_entities([entity1, entity2])
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(entities) == 2
|
||||||
|
assert entity1 in entities
|
||||||
|
assert entity2 in entities
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_platform_entity_service_non_entity_service_schema(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test attempting to register a service with a non entity service schema."""
|
||||||
|
expected_message = "registers an entity service with a non entity service schema"
|
||||||
|
|
||||||
|
for idx, schema in enumerate(
|
||||||
|
(
|
||||||
|
vol.Schema({"some": str}),
|
||||||
|
vol.All(vol.Schema({"some": str})),
|
||||||
|
vol.Any(vol.Schema({"some": str})),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
f"hello_{idx}",
|
||||||
|
entity_domain="mock_integration",
|
||||||
|
schema=schema,
|
||||||
|
func=Mock(),
|
||||||
|
)
|
||||||
|
assert expected_message in caplog.text
|
||||||
|
caplog.clear()
|
||||||
|
|
||||||
|
for idx, schema in enumerate(
|
||||||
|
(
|
||||||
|
cv.make_entity_service_schema({"some": str}),
|
||||||
|
vol.Schema(cv.make_entity_service_schema({"some": str})),
|
||||||
|
vol.All(cv.make_entity_service_schema({"some": str})),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
service.async_register_platform_entity_service(
|
||||||
|
hass,
|
||||||
|
"mock_platform",
|
||||||
|
f"test_service_{idx}",
|
||||||
|
entity_domain="mock_integration",
|
||||||
|
schema=schema,
|
||||||
|
func=Mock(),
|
||||||
|
)
|
||||||
|
assert expected_message not in caplog.text
|
||||||
|
assert not any(x.levelno > logging.DEBUG for x in caplog.records)
|
||||||
|
|||||||
Reference in New Issue
Block a user