diff --git a/homeassistant/components/rest_command/__init__.py b/homeassistant/components/rest_command/__init__.py index dcf790748ec..a07ca03a258 100644 --- a/homeassistant/components/rest_command/__init__.py +++ b/homeassistant/components/rest_command/__init__.py @@ -1,6 +1,7 @@ """Support for exposing regular REST commands as services.""" import asyncio from http import HTTPStatus +from json.decoder import JSONDecodeError import logging import aiohttp @@ -18,7 +19,14 @@ from homeassistant.const import ( CONF_VERIFY_SSL, SERVICE_RELOAD, ) -from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.core import ( + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, + callback, +) +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv from homeassistant.helpers.reload import async_integration_yaml_config @@ -98,17 +106,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: template_payload = command_config[CONF_PAYLOAD] template_payload.hass = hass - template_headers = None - if CONF_HEADERS in command_config: - template_headers = command_config[CONF_HEADERS] - for template_header in template_headers.values(): - template_header.hass = hass + template_headers = command_config.get(CONF_HEADERS, {}) + for template_header in template_headers.values(): + template_header.hass = hass - content_type = None - if CONF_CONTENT_TYPE in command_config: - content_type = command_config[CONF_CONTENT_TYPE] + content_type = command_config.get(CONF_CONTENT_TYPE) - async def async_service_handler(service: ServiceCall) -> None: + async def async_service_handler(service: ServiceCall) -> ServiceResponse: """Execute a shell command service.""" payload = None if template_payload: @@ -123,17 +127,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: variables=service.data, parse_result=False ) - headers = None - if template_headers: - headers = {} - for header_name, template_header in template_headers.items(): - headers[header_name] = template_header.async_render( - variables=service.data, parse_result=False - ) + headers = {} + for header_name, template_header in template_headers.items(): + headers[header_name] = template_header.async_render( + variables=service.data, parse_result=False + ) if content_type: - if headers is None: - headers = {} headers[hdrs.CONTENT_TYPE] = content_type try: @@ -141,7 +141,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: request_url, data=payload, auth=auth, - headers=headers, + headers=headers or None, timeout=timeout, ) as response: if response.status < HTTPStatus.BAD_REQUEST: @@ -159,8 +159,30 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: payload, ) - except asyncio.TimeoutError: + if not service.return_response: + return None + + _content = None + try: + if response.content_type == "application/json": + _content = await response.json() + else: + _content = await response.text() + except (JSONDecodeError, AttributeError) as err: + _LOGGER.error("Response of `%s` has invalid JSON", request_url) + raise HomeAssistantError from err + + except UnicodeDecodeError as err: + _LOGGER.error( + "Response of `%s` could not be interpreted as text", + request_url, + ) + raise HomeAssistantError from err + return {"content": _content, "status": response.status} + + except asyncio.TimeoutError as err: _LOGGER.warning("Timeout call %s", request_url) + raise HomeAssistantError from err except aiohttp.ClientError as err: _LOGGER.error( @@ -168,9 +190,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: request_url, err, ) + raise HomeAssistantError from err # register services - hass.services.async_register(DOMAIN, name, async_service_handler) + hass.services.async_register( + DOMAIN, + name, + async_service_handler, + supports_response=SupportsResponse.OPTIONAL, + ) for name, command_config in config[DOMAIN].items(): async_register_rest_command(name, command_config) diff --git a/tests/components/rest_command/test_init.py b/tests/components/rest_command/test_init.py index c43fe84ea8f..ce0359e0fdb 100644 --- a/tests/components/rest_command/test_init.py +++ b/tests/components/rest_command/test_init.py @@ -1,9 +1,11 @@ """The tests for the rest command platform.""" import asyncio +import base64 from http import HTTPStatus from unittest.mock import patch import aiohttp +import pytest import homeassistant.components.rest_command as rc from homeassistant.const import ( @@ -11,6 +13,7 @@ from homeassistant.const import ( CONTENT_TYPE_TEXT_PLAIN, SERVICE_RELOAD, ) +from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import setup_component from tests.common import assert_setup_component, get_test_home_assistant @@ -352,3 +355,94 @@ class TestRestCommandComponent: == "text/json" ) assert aioclient_mock.mock_calls[6][3].get("Accept") == "application/2json" + + def test_rest_command_get_response_plaintext(self, aioclient_mock): + """Get rest_command response, text.""" + with assert_setup_component(5): + setup_component(self.hass, rc.DOMAIN, self.config) + + aioclient_mock.get( + self.url, content=b"success", headers={"content-type": "text/plain"} + ) + + response = self.hass.services.call( + rc.DOMAIN, "get_test", {}, blocking=True, return_response=True + ) + self.hass.block_till_done() + + assert len(aioclient_mock.mock_calls) == 1 + assert response["content"] == "success" + assert response["status"] == 200 + + def test_rest_command_get_response_json(self, aioclient_mock): + """Get rest_command response, json.""" + with assert_setup_component(5): + setup_component(self.hass, rc.DOMAIN, self.config) + + aioclient_mock.get( + self.url, + json={"status": "success", "number": 42}, + headers={"content-type": "application/json"}, + ) + + response = self.hass.services.call( + rc.DOMAIN, "get_test", {}, blocking=True, return_response=True + ) + self.hass.block_till_done() + + assert len(aioclient_mock.mock_calls) == 1 + assert response["content"]["status"] == "success" + assert response["content"]["number"] == 42 + assert response["status"] == 200 + + def test_rest_command_get_response_malformed_json(self, aioclient_mock): + """Get rest_command response, malformed json.""" + with assert_setup_component(5): + setup_component(self.hass, rc.DOMAIN, self.config) + + aioclient_mock.get( + self.url, + content='{"status": "failure", 42', + headers={"content-type": "application/json"}, + ) + + # No problem without 'return_response' + response = self.hass.services.call(rc.DOMAIN, "get_test", {}, blocking=True) + self.hass.block_till_done() + assert not response + + # Throws error when requesting response + with pytest.raises(HomeAssistantError): + response = self.hass.services.call( + rc.DOMAIN, "get_test", {}, blocking=True, return_response=True + ) + self.hass.block_till_done() + + def test_rest_command_get_response_none(self, aioclient_mock): + """Get rest_command response, other.""" + with assert_setup_component(5): + setup_component(self.hass, rc.DOMAIN, self.config) + + png = base64.decodebytes( + b"iVBORw0KGgoAAAANSUhEUgAAAAIAAAABCAIAAAB7QOjdAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAAFiUAABYlAUlSJPAAAAAPSURBVBhXY/h/ku////8AECAE1JZPvDAAAAAASUVORK5CYII=" + ) + + aioclient_mock.get( + self.url, + content=png, + headers={"content-type": "text/plain"}, + ) + + # No problem without 'return_response' + response = self.hass.services.call(rc.DOMAIN, "get_test", {}, blocking=True) + self.hass.block_till_done() + assert not response + + # Throws Decode error when requesting response + with pytest.raises(HomeAssistantError): + response = self.hass.services.call( + rc.DOMAIN, "get_test", {}, blocking=True, return_response=True + ) + self.hass.block_till_done() + + assert not response