diff --git a/homeassistant/components/rest_command/__init__.py b/homeassistant/components/rest_command/__init__.py index 10f37b6ac4c..7dfbb964167 100644 --- a/homeassistant/components/rest_command/__init__.py +++ b/homeassistant/components/rest_command/__init__.py @@ -37,7 +37,7 @@ COMMAND_SCHEMA = vol.Schema( vol.Optional(CONF_METHOD, default=DEFAULT_METHOD): vol.All( vol.Lower, vol.In(SUPPORT_REST_METHODS) ), - vol.Optional(CONF_HEADERS): vol.Schema({cv.string: cv.string}), + vol.Optional(CONF_HEADERS): vol.Schema({cv.string: cv.template}), vol.Inclusive(CONF_USERNAME, "authentication"): cv.string, vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string, vol.Optional(CONF_PAYLOAD): cv.template, @@ -75,15 +75,15 @@ async def async_setup(hass, config): template_payload = command_config[CONF_PAYLOAD] template_payload.hass = hass - headers = None + template_headers = None if CONF_HEADERS in command_config: - headers = command_config[CONF_HEADERS] + template_headers = command_config[CONF_HEADERS] + for template_header in template_headers.values(): + template_header.hass = hass + content_type = None if CONF_CONTENT_TYPE in command_config: content_type = command_config[CONF_CONTENT_TYPE] - if headers is None: - headers = {} - headers[hdrs.CONTENT_TYPE] = content_type async def async_service_handler(service): """Execute a shell command service.""" @@ -94,6 +94,20 @@ async def async_setup(hass, config): ) request_url = template_url.async_render(variables=service.data) + + headers = None + if template_headers: + headers = {} + for header_name, template_header in template_headers.items(): + headers[header_name] = template_header.async_render( + variables=service.data + ) + + if content_type: + if headers is None: + headers = {} + headers[hdrs.CONTENT_TYPE] = content_type + try: async with getattr(websession, method)( request_url, diff --git a/tests/components/rest_command/test_init.py b/tests/components/rest_command/test_init.py index b7ac5a4be8a..ba63091041d 100644 --- a/tests/components/rest_command/test_init.py +++ b/tests/components/rest_command/test_init.py @@ -236,6 +236,19 @@ class TestRestCommandComponent: }, "content_type": "text/plain", }, + "headers_template_test": { + "headers": { + "Accept": "application/json", + "User-Agent": "Mozilla/{{ 3 + 2 }}.0", + } + }, + "headers_and_content_type_override_template_test": { + "headers": { + "Accept": "application/{{ 1 + 1 }}json", + aiohttp.hdrs.CONTENT_TYPE: "application/pdf", + }, + "content_type": "text/json", + }, } } @@ -245,7 +258,7 @@ class TestRestCommandComponent: {"url": self.url, "method": "post", "payload": "test data"} ) - with assert_setup_component(5): + with assert_setup_component(7): setup_component(self.hass, rc.DOMAIN, header_config_variations) # provide post request data @@ -257,11 +270,13 @@ class TestRestCommandComponent: "headers_test", "headers_and_content_type_test", "headers_and_content_type_override_test", + "headers_template_test", + "headers_and_content_type_override_template_test", ]: self.hass.services.call(rc.DOMAIN, test_service, {}) self.hass.block_till_done() - assert len(aioclient_mock.mock_calls) == 5 + assert len(aioclient_mock.mock_calls) == 7 # no_headers_test assert aioclient_mock.mock_calls[0][3] is None @@ -293,3 +308,16 @@ class TestRestCommandComponent: == "text/plain" ) assert aioclient_mock.mock_calls[4][3].get("Accept") == "application/json" + + # headers_template_test + assert len(aioclient_mock.mock_calls[5][3]) == 2 + assert aioclient_mock.mock_calls[5][3].get("Accept") == "application/json" + assert aioclient_mock.mock_calls[5][3].get("User-Agent") == "Mozilla/5.0" + + # headers_and_content_type_override_template_test + assert len(aioclient_mock.mock_calls[6][3]) == 2 + assert ( + aioclient_mock.mock_calls[6][3].get(aiohttp.hdrs.CONTENT_TYPE) + == "text/json" + ) + assert aioclient_mock.mock_calls[6][3].get("Accept") == "application/2json"