diff --git a/homeassistant/components/rest_command.py b/homeassistant/components/rest_command.py index 026f0e9a19b..4632315b757 100644 --- a/homeassistant/components/rest_command.py +++ b/homeassistant/components/rest_command.py @@ -14,7 +14,7 @@ import voluptuous as vol from homeassistant.const import ( CONF_TIMEOUT, CONF_USERNAME, CONF_PASSWORD, CONF_URL, CONF_PAYLOAD, - CONF_METHOD) + CONF_METHOD, CONF_HEADERS) from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv @@ -38,6 +38,7 @@ COMMAND_SCHEMA = vol.Schema({ vol.Required(CONF_URL): cv.template, vol.Optional(CONF_METHOD, default=DEFAULT_METHOD): vol.All(vol.Lower, vol.In(SUPPORT_REST_METHODS)), + vol.Optional(CONF_HEADERS): vol.Schema({cv.string: cv.string}), vol.Inclusive(CONF_USERNAME, 'authentication'): cv.string, vol.Inclusive(CONF_PASSWORD, 'authentication'): cv.string, vol.Optional(CONF_PAYLOAD): cv.template, @@ -77,9 +78,14 @@ def async_setup(hass, config): template_payload.hass = hass headers = None + if CONF_HEADERS in command_config: + headers = command_config[CONF_HEADERS] + if CONF_CONTENT_TYPE in command_config: content_type = command_config[CONF_CONTENT_TYPE] - headers = {hdrs.CONTENT_TYPE: content_type} + if headers is None: + headers = {} + headers[hdrs.CONTENT_TYPE] = content_type @asyncio.coroutine def async_service_handler(service): diff --git a/homeassistant/components/sensor/rest.py b/homeassistant/components/sensor/rest.py index 19f5a1c271e..c295dcf16dc 100644 --- a/homeassistant/components/sensor/rest.py +++ b/homeassistant/components/sensor/rest.py @@ -35,7 +35,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Required(CONF_RESOURCE): cv.url, vol.Optional(CONF_AUTHENTICATION): vol.In([HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION]), - vol.Optional(CONF_HEADERS): {cv.string: cv.string}, + vol.Optional(CONF_HEADERS): vol.Schema({cv.string: cv.string}), vol.Optional(CONF_JSON_ATTRS, default=[]): cv.ensure_list_csv, vol.Optional(CONF_METHOD, default=DEFAULT_METHOD): vol.In(METHODS), vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, diff --git a/tests/components/test_rest_command.py b/tests/components/test_rest_command.py index 9dbea53cd64..3ddcfae8c01 100644 --- a/tests/components/test_rest_command.py +++ b/tests/components/test_rest_command.py @@ -222,21 +222,82 @@ class TestRestCommandComponent(object): assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == b'data' - def test_rest_command_content_type(self, aioclient_mock): - """Call a rest command with a content type.""" - data = { - 'payload': 'item', - 'content_type': 'text/plain' + def test_rest_command_headers(self, aioclient_mock): + """Call a rest command with custom headers and content types.""" + header_config_variations = { + rc.DOMAIN: { + 'no_headers_test': {}, + 'content_type_test': { + 'content_type': 'text/plain' + }, + 'headers_test': { + 'headers': { + 'Accept': 'application/json', + 'User-Agent': 'Mozilla/5.0' + } + }, + 'headers_and_content_type_test': { + 'headers': { + 'Accept': 'application/json' + }, + 'content_type': 'text/plain' + }, + 'headers_and_content_type_override_test': { + 'headers': { + 'Accept': 'application/json', + aiohttp.hdrs.CONTENT_TYPE: 'application/pdf' + }, + 'content_type': 'text/plain' + } + } } - self.config[rc.DOMAIN]['post_test'].update(data) - with assert_setup_component(4): - setup_component(self.hass, rc.DOMAIN, self.config) + # add common parameters + for variation in header_config_variations[rc.DOMAIN].values(): + variation.update({'url': self.url, 'method': 'post', + 'payload': 'test data'}) + with assert_setup_component(5): + setup_component(self.hass, rc.DOMAIN, header_config_variations) + + # provide post request data aioclient_mock.post(self.url, content=b'success') - self.hass.services.call(rc.DOMAIN, 'post_test', {}) - self.hass.block_till_done() + for test_service in ['no_headers_test', + 'content_type_test', + 'headers_test', + 'headers_and_content_type_test', + 'headers_and_content_type_override_test']: + self.hass.services.call(rc.DOMAIN, test_service, {}) - assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[0][2] == b'item' + self.hass.block_till_done() + assert len(aioclient_mock.mock_calls) == 5 + + # no_headers_test + assert aioclient_mock.mock_calls[0][3] is None + + # content_type_test + assert len(aioclient_mock.mock_calls[1][3]) == 1 + assert aioclient_mock.mock_calls[1][3].get( + aiohttp.hdrs.CONTENT_TYPE) == 'text/plain' + + # headers_test + assert len(aioclient_mock.mock_calls[2][3]) == 2 + assert aioclient_mock.mock_calls[2][3].get( + 'Accept') == 'application/json' + assert aioclient_mock.mock_calls[2][3].get( + 'User-Agent') == 'Mozilla/5.0' + + # headers_and_content_type_test + assert len(aioclient_mock.mock_calls[3][3]) == 2 + assert aioclient_mock.mock_calls[3][3].get( + aiohttp.hdrs.CONTENT_TYPE) == 'text/plain' + assert aioclient_mock.mock_calls[3][3].get( + 'Accept') == 'application/json' + + # headers_and_content_type_override_test + assert len(aioclient_mock.mock_calls[4][3]) == 2 + assert aioclient_mock.mock_calls[4][3].get( + aiohttp.hdrs.CONTENT_TYPE) == 'text/plain' + assert aioclient_mock.mock_calls[4][3].get( + 'Accept') == 'application/json'