From 4c181bbfe5f4f62dbd467c589c010ebecd844bd4 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 5 Mar 2021 15:34:18 -0800 Subject: [PATCH] Raise error instead of crashing when template passed to call service target (#47467) --- .../components/websocket_api/commands.py | 16 +++++--- .../components/websocket_api/test_commands.py | 40 +++++++++++-------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index ddd7548cd68..53531cf9ba9 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -13,10 +13,9 @@ from homeassistant.exceptions import ( TemplateError, Unauthorized, ) -from homeassistant.helpers import config_validation as cv, entity +from homeassistant.helpers import config_validation as cv, entity, template from homeassistant.helpers.event import TrackTemplate, async_track_template_result from homeassistant.helpers.service import async_get_all_descriptions -from homeassistant.helpers.template import Template from homeassistant.loader import IntegrationNotFound, async_get_integration from . import const, decorators, messages @@ -132,6 +131,11 @@ async def handle_call_service(hass, connection, msg): if msg["domain"] == HASS_DOMAIN and msg["service"] in ["restart", "stop"]: blocking = False + # We do not support templates. + target = msg.get("target") + if template.is_complex(target): + raise vol.Invalid("Templates are not supported here") + try: context = connection.context(msg) await hass.services.async_call( @@ -140,7 +144,7 @@ async def handle_call_service(hass, connection, msg): msg.get("service_data"), blocking, context, - target=msg.get("target"), + target=target, ) connection.send_message( messages.result_message(msg["id"], {"context": context}) @@ -256,14 +260,14 @@ def handle_ping(hass, connection, msg): async def handle_render_template(hass, connection, msg): """Handle render_template command.""" template_str = msg["template"] - template = Template(template_str, hass) + template_obj = template.Template(template_str, hass) variables = msg.get("variables") timeout = msg.get("timeout") info = None if timeout: try: - timed_out = await template.async_render_will_timeout(timeout) + timed_out = await template_obj.async_render_will_timeout(timeout) except TemplateError as ex: connection.send_error(msg["id"], const.ERR_TEMPLATE_ERROR, str(ex)) return @@ -294,7 +298,7 @@ async def handle_render_template(hass, connection, msg): try: info = async_track_template_result( hass, - [TrackTemplate(template, variables)], + [TrackTemplate(template_obj, variables)], _template_listener, raise_on_template_error=True, ) diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 1f7abc42c4e..f596db63c5e 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -21,13 +21,7 @@ from tests.common import MockEntity, MockEntityPlatform, async_mock_service async def test_call_service(hass, websocket_client): """Test call service command.""" - calls = [] - - @callback - def service_call(call): - calls.append(call) - - hass.services.async_register("domain_test", "test_service", service_call) + calls = async_mock_service(hass, "domain_test", "test_service") await websocket_client.send_json( { @@ -54,13 +48,7 @@ async def test_call_service(hass, websocket_client): async def test_call_service_target(hass, websocket_client): """Test call service command with target.""" - calls = [] - - @callback - def service_call(call): - calls.append(call) - - hass.services.async_register("domain_test", "test_service", service_call) + calls = async_mock_service(hass, "domain_test", "test_service") await websocket_client.send_json( { @@ -93,6 +81,28 @@ async def test_call_service_target(hass, websocket_client): } +async def test_call_service_target_template(hass, websocket_client): + """Test call service command with target does not allow template.""" + await websocket_client.send_json( + { + "id": 5, + "type": "call_service", + "domain": "domain_test", + "service": "test_service", + "service_data": {"hello": "world"}, + "target": { + "entity_id": "{{ 1 }}", + }, + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 5 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == const.ERR_INVALID_FORMAT + + async def test_call_service_not_found(hass, websocket_client): """Test call service command.""" await websocket_client.send_json( @@ -232,7 +242,6 @@ async def test_call_service_error(hass, websocket_client): ) msg = await websocket_client.receive_json() - print(msg) assert msg["id"] == 5 assert msg["type"] == const.TYPE_RESULT assert msg["success"] is False @@ -249,7 +258,6 @@ async def test_call_service_error(hass, websocket_client): ) msg = await websocket_client.receive_json() - print(msg) assert msg["id"] == 6 assert msg["type"] == const.TYPE_RESULT assert msg["success"] is False