diff --git a/homeassistant/components/switch/rest.py b/homeassistant/components/switch/rest.py index b68cc038e89..9c589d1d95b 100644 --- a/homeassistant/components/switch/rest.py +++ b/homeassistant/components/switch/rest.py @@ -13,8 +13,8 @@ import voluptuous as vol from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA) from homeassistant.const import ( - CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT, CONF_METHOD, CONF_USERNAME, - CONF_PASSWORD) + CONF_HEADERS, CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT, CONF_METHOD, + CONF_USERNAME, CONF_PASSWORD) from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv @@ -34,6 +34,7 @@ SUPPORT_REST_METHODS = ['post', 'put'] PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Required(CONF_RESOURCE): cv.url, + vol.Optional(CONF_HEADERS): {cv.string: cv.string}, vol.Optional(CONF_BODY_OFF, default=DEFAULT_BODY_OFF): cv.template, vol.Optional(CONF_BODY_ON, default=DEFAULT_BODY_ON): cv.template, vol.Optional(CONF_IS_ON_TEMPLATE): cv.template, @@ -54,6 +55,7 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None): body_on = config.get(CONF_BODY_ON) is_on_template = config.get(CONF_IS_ON_TEMPLATE) method = config.get(CONF_METHOD) + headers = config.get(CONF_HEADERS) name = config.get(CONF_NAME) username = config.get(CONF_USERNAME) password = config.get(CONF_PASSWORD) @@ -72,8 +74,8 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None): timeout = config.get(CONF_TIMEOUT) try: - switch = RestSwitch(name, resource, method, auth, body_on, body_off, - is_on_template, timeout) + switch = RestSwitch(name, resource, method, headers, auth, body_on, + body_off, is_on_template, timeout) req = yield from switch.get_device_state(hass) if req.status >= 400: @@ -90,13 +92,14 @@ def async_setup_platform(hass, config, async_add_devices, discovery_info=None): class RestSwitch(SwitchDevice): """Representation of a switch that can be toggled using REST.""" - def __init__(self, name, resource, method, auth, body_on, body_off, - is_on_template, timeout): + def __init__(self, name, resource, method, headers, auth, body_on, + body_off, is_on_template, timeout): """Initialize the REST switch.""" self._state = None self._name = name self._resource = resource self._method = method + self._headers = headers self._auth = auth self._body_on = body_on self._body_off = body_off @@ -153,7 +156,8 @@ class RestSwitch(SwitchDevice): with async_timeout.timeout(self._timeout, loop=self.hass.loop): req = yield from getattr(websession, self._method)( - self._resource, auth=self._auth, data=bytes(body, 'utf-8')) + self._resource, auth=self._auth, data=bytes(body, 'utf-8'), + headers=self._headers) return req @asyncio.coroutine diff --git a/tests/components/switch/test_rest.py b/tests/components/switch/test_rest.py index 1b8215660bd..064d0b1825b 100644 --- a/tests/components/switch/test_rest.py +++ b/tests/components/switch/test_rest.py @@ -82,6 +82,7 @@ class TestRestSwitchSetup: 'platform': 'rest', 'name': 'foo', 'resource': 'http://localhost', + 'headers': {'Content-type': 'application/json'}, 'body_on': 'custom on text', 'body_off': 'custom off text', } @@ -99,12 +100,13 @@ class TestRestSwitch: self.name = 'foo' self.method = 'post' self.resource = 'http://localhost/' + self.headers = {'Content-type': 'application/json'} self.auth = None self.body_on = Template('on', self.hass) self.body_off = Template('off', self.hass) self.switch = rest.RestSwitch( - self.name, self.resource, self.method, self.auth, self.body_on, - self.body_off, None, 10) + self.name, self.resource, self.method, self.headers, self.auth, + self.body_on, self.body_off, None, 10) self.switch.hass = self.hass def teardown_method(self):