mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Add target to service call API (#45898)
* Add target to service call API * Fix _async_call_service_step * CONF_SERVICE_ENTITY_ID overrules target * Move merging up before processing schema * Restore services.yaml * Add test
This commit is contained in:
parent
7d2d98fc3c
commit
4b493c5ab9
@ -378,7 +378,7 @@ class APIDomainServicesView(HomeAssistantView):
|
|||||||
with AsyncTrackStates(hass) as changed_states:
|
with AsyncTrackStates(hass) as changed_states:
|
||||||
try:
|
try:
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
domain, service, data, True, self.context(request)
|
domain, service, data, blocking=True, context=self.context(request)
|
||||||
)
|
)
|
||||||
except (vol.Invalid, ServiceNotFound) as ex:
|
except (vol.Invalid, ServiceNotFound) as ex:
|
||||||
raise HTTPBadRequest() from ex
|
raise HTTPBadRequest() from ex
|
||||||
|
@ -121,6 +121,7 @@ def handle_unsubscribe_events(hass, connection, msg):
|
|||||||
vol.Required("type"): "call_service",
|
vol.Required("type"): "call_service",
|
||||||
vol.Required("domain"): str,
|
vol.Required("domain"): str,
|
||||||
vol.Required("service"): str,
|
vol.Required("service"): str,
|
||||||
|
vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS,
|
||||||
vol.Optional("service_data"): dict,
|
vol.Optional("service_data"): dict,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -139,6 +140,7 @@ async def handle_call_service(hass, connection, msg):
|
|||||||
msg.get("service_data"),
|
msg.get("service_data"),
|
||||||
blocking,
|
blocking,
|
||||||
context,
|
context,
|
||||||
|
target=msg.get("target"),
|
||||||
)
|
)
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
messages.result_message(msg["id"], {"context": context})
|
messages.result_message(msg["id"], {"context": context})
|
||||||
|
@ -1358,6 +1358,7 @@ class ServiceRegistry:
|
|||||||
blocking: bool = False,
|
blocking: bool = False,
|
||||||
context: Optional[Context] = None,
|
context: Optional[Context] = None,
|
||||||
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
||||||
|
target: Optional[Dict] = None,
|
||||||
) -> Optional[bool]:
|
) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Call a service.
|
Call a service.
|
||||||
@ -1365,7 +1366,9 @@ class ServiceRegistry:
|
|||||||
See description of async_call for details.
|
See description of async_call for details.
|
||||||
"""
|
"""
|
||||||
return asyncio.run_coroutine_threadsafe(
|
return asyncio.run_coroutine_threadsafe(
|
||||||
self.async_call(domain, service, service_data, blocking, context, limit),
|
self.async_call(
|
||||||
|
domain, service, service_data, blocking, context, limit, target
|
||||||
|
),
|
||||||
self._hass.loop,
|
self._hass.loop,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@ -1377,6 +1380,7 @@ class ServiceRegistry:
|
|||||||
blocking: bool = False,
|
blocking: bool = False,
|
||||||
context: Optional[Context] = None,
|
context: Optional[Context] = None,
|
||||||
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
limit: Optional[float] = SERVICE_CALL_LIMIT,
|
||||||
|
target: Optional[Dict] = None,
|
||||||
) -> Optional[bool]:
|
) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Call a service.
|
Call a service.
|
||||||
@ -1404,6 +1408,9 @@ class ServiceRegistry:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise ServiceNotFound(domain, service) from None
|
raise ServiceNotFound(domain, service) from None
|
||||||
|
|
||||||
|
if target:
|
||||||
|
service_data.update(target)
|
||||||
|
|
||||||
if handler.schema:
|
if handler.schema:
|
||||||
try:
|
try:
|
||||||
processed_data = handler.schema(service_data)
|
processed_data = handler.schema(service_data)
|
||||||
|
@ -433,14 +433,14 @@ class _ScriptRun:
|
|||||||
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
|
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
|
||||||
self._log("Executing step %s", self._script.last_action)
|
self._log("Executing step %s", self._script.last_action)
|
||||||
|
|
||||||
domain, service_name, service_data = service.async_prepare_call_from_config(
|
params = service.async_prepare_call_from_config(
|
||||||
self._hass, self._action, self._variables
|
self._hass, self._action, self._variables
|
||||||
)
|
)
|
||||||
|
|
||||||
running_script = (
|
running_script = (
|
||||||
domain == "automation"
|
params["domain"] == "automation"
|
||||||
and service_name == "trigger"
|
and params["service_name"] == "trigger"
|
||||||
or domain in ("python_script", "script")
|
or params["domain"] in ("python_script", "script")
|
||||||
)
|
)
|
||||||
# If this might start a script then disable the call timeout.
|
# If this might start a script then disable the call timeout.
|
||||||
# Otherwise use the normal service call limit.
|
# Otherwise use the normal service call limit.
|
||||||
@ -451,9 +451,7 @@ class _ScriptRun:
|
|||||||
|
|
||||||
service_task = self._hass.async_create_task(
|
service_task = self._hass.async_create_task(
|
||||||
self._hass.services.async_call(
|
self._hass.services.async_call(
|
||||||
domain,
|
**params,
|
||||||
service_name,
|
|
||||||
service_data,
|
|
||||||
blocking=True,
|
blocking=True,
|
||||||
context=self._context,
|
context=self._context,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
@ -14,6 +14,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -70,6 +71,15 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceParams(TypedDict):
|
||||||
|
"""Type for service call parameters."""
|
||||||
|
|
||||||
|
domain: str
|
||||||
|
service: str
|
||||||
|
service_data: Dict[str, Any]
|
||||||
|
target: Optional[Dict]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SelectedEntities:
|
class SelectedEntities:
|
||||||
"""Class to hold the selected entities."""
|
"""Class to hold the selected entities."""
|
||||||
@ -136,7 +146,7 @@ async def async_call_from_config(
|
|||||||
raise
|
raise
|
||||||
_LOGGER.error(ex)
|
_LOGGER.error(ex)
|
||||||
else:
|
else:
|
||||||
await hass.services.async_call(*params, blocking, context)
|
await hass.services.async_call(**params, blocking=blocking, context=context)
|
||||||
|
|
||||||
|
|
||||||
@ha.callback
|
@ha.callback
|
||||||
@ -146,7 +156,7 @@ def async_prepare_call_from_config(
|
|||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
variables: TemplateVarsType = None,
|
variables: TemplateVarsType = None,
|
||||||
validate_config: bool = False,
|
validate_config: bool = False,
|
||||||
) -> Tuple[str, str, Dict[str, Any]]:
|
) -> ServiceParams:
|
||||||
"""Prepare to call a service based on a config hash."""
|
"""Prepare to call a service based on a config hash."""
|
||||||
if validate_config:
|
if validate_config:
|
||||||
try:
|
try:
|
||||||
@ -177,10 +187,9 @@ def async_prepare_call_from_config(
|
|||||||
|
|
||||||
domain, service = domain_service.split(".", 1)
|
domain, service = domain_service.split(".", 1)
|
||||||
|
|
||||||
service_data = {}
|
target = config.get(CONF_TARGET)
|
||||||
|
|
||||||
if CONF_TARGET in config:
|
service_data = {}
|
||||||
service_data.update(config[CONF_TARGET])
|
|
||||||
|
|
||||||
for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]:
|
for conf in [CONF_SERVICE_DATA, CONF_SERVICE_DATA_TEMPLATE]:
|
||||||
if conf not in config:
|
if conf not in config:
|
||||||
@ -192,9 +201,17 @@ def async_prepare_call_from_config(
|
|||||||
raise HomeAssistantError(f"Error rendering data template: {ex}") from ex
|
raise HomeAssistantError(f"Error rendering data template: {ex}") from ex
|
||||||
|
|
||||||
if CONF_SERVICE_ENTITY_ID in config:
|
if CONF_SERVICE_ENTITY_ID in config:
|
||||||
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
if target:
|
||||||
|
target[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
||||||
|
else:
|
||||||
|
target = {ATTR_ENTITY_ID: config[CONF_SERVICE_ENTITY_ID]}
|
||||||
|
|
||||||
return domain, service, service_data
|
return {
|
||||||
|
"domain": domain,
|
||||||
|
"service": service,
|
||||||
|
"service_data": service_data,
|
||||||
|
"target": target,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
@ -431,6 +448,7 @@ async def async_get_all_descriptions(
|
|||||||
|
|
||||||
description = descriptions_cache[cache_key] = {
|
description = descriptions_cache[cache_key] = {
|
||||||
"description": yaml_description.get("description", ""),
|
"description": yaml_description.get("description", ""),
|
||||||
|
"target": yaml_description.get("target"),
|
||||||
"fields": yaml_description.get("fields", {}),
|
"fields": yaml_description.get("fields", {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,6 +52,47 @@ async def test_call_service(hass, websocket_client):
|
|||||||
assert call.data == {"hello": "world"}
|
assert call.data == {"hello": "world"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_service_target(hass, websocket_client):
|
||||||
|
"""Test call service command with target."""
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def service_call(call):
|
||||||
|
calls.append(call)
|
||||||
|
|
||||||
|
hass.services.async_register("domain_test", "test_service", service_call)
|
||||||
|
|
||||||
|
await websocket_client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "call_service",
|
||||||
|
"domain": "domain_test",
|
||||||
|
"service": "test_service",
|
||||||
|
"service_data": {"hello": "world"},
|
||||||
|
"target": {
|
||||||
|
"entity_id": ["entity.one", "entity.two"],
|
||||||
|
"device_id": "deviceid",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg["id"] == 5
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
|
||||||
|
assert call.domain == "domain_test"
|
||||||
|
assert call.service == "test_service"
|
||||||
|
assert call.data == {
|
||||||
|
"hello": "world",
|
||||||
|
"entity_id": ["entity.one", "entity.two"],
|
||||||
|
"device_id": ["deviceid"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_call_service_not_found(hass, websocket_client):
|
async def test_call_service_not_found(hass, websocket_client):
|
||||||
"""Test call service command."""
|
"""Test call service command."""
|
||||||
await websocket_client.send_json(
|
await websocket_client.send_json(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user