diff --git a/homeassistant/components/rest/__init__.py b/homeassistant/components/rest/__init__.py index 8186db1c3c2..f2cffebdfcb 100644 --- a/homeassistant/components/rest/__init__.py +++ b/homeassistant/components/rest/__init__.py @@ -37,6 +37,7 @@ from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from .const import COORDINATOR, DOMAIN, PLATFORM_IDX, REST, REST_DATA, REST_IDX from .data import RestData from .schema import CONFIG_SCHEMA # noqa: F401 +from .utils import inject_hass_in_templates_list _LOGGER = logging.getLogger(__name__) @@ -161,6 +162,8 @@ def create_rest_data_from_config(hass, config): resource_template.hass = hass resource = resource_template.async_render(parse_result=False) + inject_hass_in_templates_list(hass, [headers, params]) + if username and password: if config.get(CONF_AUTHENTICATION) == HTTP_DIGEST_AUTHENTICATION: auth = httpx.DigestAuth(username, password) diff --git a/homeassistant/components/rest/data.py b/homeassistant/components/rest/data.py index 8b03bcfb128..513f2393127 100644 --- a/homeassistant/components/rest/data.py +++ b/homeassistant/components/rest/data.py @@ -3,6 +3,7 @@ import logging import httpx +from homeassistant.components.rest.utils import render_templates from homeassistant.helpers.httpx_client import get_async_client DEFAULT_TIMEOUT = 10 @@ -51,13 +52,16 @@ class RestData: self._hass, verify_ssl=self._verify_ssl ) + rendered_headers = render_templates(self._headers) + rendered_params = render_templates(self._params) + _LOGGER.debug("Updating from %s", self._resource) try: response = await self._async_client.request( self._method, self._resource, - headers=self._headers, - params=self._params, + headers=rendered_headers, + params=rendered_params, auth=self._auth, data=self._request_data, timeout=self._timeout, diff --git a/homeassistant/components/rest/schema.py b/homeassistant/components/rest/schema.py index a4b87051c4b..c5b6949bd39 100644 --- a/homeassistant/components/rest/schema.py +++ b/homeassistant/components/rest/schema.py @@ -54,8 +54,8 @@ RESOURCE_SCHEMA = { vol.Optional(CONF_AUTHENTICATION): vol.In( [HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION] ), - vol.Optional(CONF_HEADERS): vol.Schema({cv.string: cv.string}), - vol.Optional(CONF_PARAMS): vol.Schema({cv.string: cv.string}), + vol.Optional(CONF_HEADERS): vol.Schema({cv.string: cv.template}), + vol.Optional(CONF_PARAMS): vol.Schema({cv.string: cv.template}), vol.Optional(CONF_METHOD, default=DEFAULT_METHOD): vol.In(METHODS), vol.Optional(CONF_USERNAME): cv.string, vol.Optional(CONF_PASSWORD): cv.string, diff --git a/homeassistant/components/rest/switch.py b/homeassistant/components/rest/switch.py index e6b16de40aa..83bd5ae27ae 100644 --- a/homeassistant/components/rest/switch.py +++ b/homeassistant/components/rest/switch.py @@ -27,6 +27,8 @@ from homeassistant.const import ( from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv +from .utils import inject_hass_in_templates_list, render_templates + _LOGGER = logging.getLogger(__name__) CONF_BODY_OFF = "body_off" CONF_BODY_ON = "body_on" @@ -46,8 +48,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( { vol.Required(CONF_RESOURCE): cv.url, vol.Optional(CONF_STATE_RESOURCE): cv.url, - vol.Optional(CONF_HEADERS): {cv.string: cv.string}, - vol.Optional(CONF_PARAMS): {cv.string: cv.string}, + vol.Optional(CONF_HEADERS): {cv.string: cv.template}, + vol.Optional(CONF_PARAMS): {cv.string: cv.template}, vol.Optional(CONF_BODY_OFF, default=DEFAULT_BODY_OFF): cv.template, vol.Optional(CONF_BODY_ON, default=DEFAULT_BODY_ON): cv.template, vol.Optional(CONF_IS_ON_TEMPLATE): cv.template, @@ -90,6 +92,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= body_on.hass = hass if body_off is not None: body_off.hass = hass + inject_hass_in_templates_list(hass, [headers, params]) timeout = config.get(CONF_TIMEOUT) try: @@ -204,13 +207,16 @@ class RestSwitch(SwitchEntity): """Send a state update to the device.""" websession = async_get_clientsession(self.hass, self._verify_ssl) + rendered_headers = render_templates(self._headers) + rendered_params = render_templates(self._params) + with async_timeout.timeout(self._timeout): req = await getattr(websession, self._method)( self._resource, auth=self._auth, data=bytes(body, "utf-8"), - headers=self._headers, - params=self._params, + headers=rendered_headers, + params=rendered_params, ) return req @@ -227,12 +233,15 @@ class RestSwitch(SwitchEntity): """Get the latest data from REST API and update the state.""" websession = async_get_clientsession(hass, self._verify_ssl) + rendered_headers = render_templates(self._headers) + rendered_params = render_templates(self._params) + with async_timeout.timeout(self._timeout): req = await websession.get( self._state_resource, auth=self._auth, - headers=self._headers, - params=self._params, + headers=rendered_headers, + params=rendered_params, ) text = await req.text() diff --git a/homeassistant/components/rest/utils.py b/homeassistant/components/rest/utils.py new file mode 100644 index 00000000000..24c58d294e1 --- /dev/null +++ b/homeassistant/components/rest/utils.py @@ -0,0 +1,27 @@ +"""Reusable utilities for the Rest component.""" +from __future__ import annotations + +from homeassistant.core import HomeAssistant +from homeassistant.helpers.template import Template + + +def inject_hass_in_templates_list( + hass: HomeAssistant, tpl_dict_list: list[dict[str, Template] | None] +): + """Inject hass in a list of dict of templates.""" + for tpl_dict in tpl_dict_list: + if tpl_dict is not None: + for tpl in tpl_dict.values(): + tpl.hass = hass + + +def render_templates(tpl_dict: dict[str, Template] | None): + """Render a dict of templates.""" + if tpl_dict is None: + return None + + rendered_items = {} + for item_name, template_header in tpl_dict.items(): + if (value := template_header.async_render()) is not None: + rendered_items[item_name] = value + return rendered_items diff --git a/tests/components/rest/test_binary_sensor.py b/tests/components/rest/test_binary_sensor.py index 8160a5976a7..a0cd7d5108c 100644 --- a/tests/components/rest/test_binary_sensor.py +++ b/tests/components/rest/test_binary_sensor.py @@ -179,6 +179,40 @@ async def test_setup_get(hass): assert state.attributes[ATTR_DEVICE_CLASS] == binary_sensor.DEVICE_CLASS_PLUG +@respx.mock +async def test_setup_get_template_headers_params(hass): + """Test setup with valid configuration.""" + respx.get("http://localhost").respond(status_code=200, json={}) + assert await async_setup_component( + hass, + "sensor", + { + "sensor": { + "platform": "rest", + "resource": "http://localhost", + "method": "GET", + "value_template": "{{ value_json.key }}", + "name": "foo", + "verify_ssl": "true", + "timeout": 30, + "headers": { + "Accept": CONTENT_TYPE_JSON, + "User-Agent": "Mozilla/{{ 3 + 2 }}.0", + }, + "params": { + "start": 0, + "end": "{{ 3 + 2 }}", + }, + } + }, + ) + await async_setup_component(hass, "homeassistant", {}) + + assert respx.calls.last.request.headers["Accept"] == CONTENT_TYPE_JSON + assert respx.calls.last.request.headers["User-Agent"] == "Mozilla/5.0" + assert respx.calls.last.request.url.query == b"start=0&end=5" + + @respx.mock async def test_setup_get_digest_auth(hass): """Test setup with valid configuration.""" diff --git a/tests/components/rest/test_sensor.py b/tests/components/rest/test_sensor.py index 4ff8ca12dad..a576acb2fe3 100644 --- a/tests/components/rest/test_sensor.py +++ b/tests/components/rest/test_sensor.py @@ -217,6 +217,40 @@ async def test_setup_get(hass): assert state.attributes[sensor.ATTR_STATE_CLASS] == sensor.STATE_CLASS_MEASUREMENT +@respx.mock +async def test_setup_get_templated_headers_params(hass): + """Test setup with valid configuration.""" + respx.get("http://localhost").respond(status_code=200, json={}) + assert await async_setup_component( + hass, + "sensor", + { + "sensor": { + "platform": "rest", + "resource": "http://localhost", + "method": "GET", + "value_template": "{{ value_json.key }}", + "name": "foo", + "verify_ssl": "true", + "timeout": 30, + "headers": { + "Accept": CONTENT_TYPE_JSON, + "User-Agent": "Mozilla/{{ 3 + 2 }}.0", + }, + "params": { + "start": 0, + "end": "{{ 3 + 2 }}", + }, + } + }, + ) + await async_setup_component(hass, "homeassistant", {}) + + assert respx.calls.last.request.headers["Accept"] == CONTENT_TYPE_JSON + assert respx.calls.last.request.headers["User-Agent"] == "Mozilla/5.0" + assert respx.calls.last.request.url.query == b"start=0&end=5" + + @respx.mock async def test_setup_get_digest_auth(hass): """Test setup with valid configuration.""" diff --git a/tests/components/rest/test_switch.py b/tests/components/rest/test_switch.py index 4370386dcff..1b724052b1e 100644 --- a/tests/components/rest/test_switch.py +++ b/tests/components/rest/test_switch.py @@ -27,7 +27,6 @@ DEVICE_CLASS = DEVICE_CLASS_SWITCH METHOD = "post" RESOURCE = "http://localhost/" STATE_RESOURCE = RESOURCE -HEADERS = {"Content-type": CONTENT_TYPE_JSON} AUTH = None PARAMS = None @@ -151,19 +150,51 @@ async def test_setup_with_state_resource(hass, aioclient_mock): assert_setup_component(1, SWITCH_DOMAIN) +async def test_setup_with_templated_headers_params(hass, aioclient_mock): + """Test setup with valid configuration.""" + aioclient_mock.get("http://localhost", status=HTTPStatus.OK) + assert await async_setup_component( + hass, + SWITCH_DOMAIN, + { + SWITCH_DOMAIN: { + CONF_PLATFORM: DOMAIN, + CONF_NAME: "foo", + CONF_RESOURCE: "http://localhost", + CONF_HEADERS: { + "Accept": CONTENT_TYPE_JSON, + "User-Agent": "Mozilla/{{ 3 + 2 }}.0", + }, + CONF_PARAMS: { + "start": 0, + "end": "{{ 3 + 2 }}", + }, + } + }, + ) + await hass.async_block_till_done() + assert aioclient_mock.call_count == 1 + assert aioclient_mock.mock_calls[-1][3].get("Accept") == CONTENT_TYPE_JSON + assert aioclient_mock.mock_calls[-1][3].get("User-Agent") == "Mozilla/5.0" + assert aioclient_mock.mock_calls[-1][1].query["start"] == "0" + assert aioclient_mock.mock_calls[-1][1].query["end"] == "5" + assert_setup_component(1, SWITCH_DOMAIN) + + """Tests for REST switch platform.""" def _setup_test_switch(hass): body_on = Template("on", hass) body_off = Template("off", hass) + headers = {"Content-type": Template(CONTENT_TYPE_JSON, hass)} switch = rest.RestSwitch( NAME, DEVICE_CLASS, RESOURCE, STATE_RESOURCE, METHOD, - HEADERS, + headers, PARAMS, AUTH, body_on,