diff --git a/homeassistant/components/switch/rest.py b/homeassistant/components/switch/rest.py index 36674c16d16..cfa11897de9 100644 --- a/homeassistant/components/switch/rest.py +++ b/homeassistant/components/switch/rest.py @@ -4,23 +4,27 @@ Support for RESTful switches. For more details about this platform, please refer to the documentation at https://home-assistant.io/components/switch.rest/ """ +import asyncio import logging -import requests +import aiohttp +import async_timeout import voluptuous as vol from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA) from homeassistant.const import (CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT) +from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv from homeassistant.helpers.template import Template CONF_BODY_OFF = 'body_off' CONF_BODY_ON = 'body_on' +CONF_IS_ON_TEMPLATE = 'is_on_template' + DEFAULT_BODY_OFF = Template('OFF') DEFAULT_BODY_ON = Template('ON') DEFAULT_NAME = 'REST Switch' DEFAULT_TIMEOUT = 10 -CONF_IS_ON_TEMPLATE = 'is_on_template' PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Required(CONF_RESOURCE): cv.url, @@ -35,13 +39,15 @@ _LOGGER = logging.getLogger(__name__) # pylint: disable=unused-argument, -def setup_platform(hass, config, add_devices, discovery_info=None): +@asyncio.coroutine +def async_setup_platform(hass, config, async_add_devices, discovery_info=None): """Set up the RESTful switch.""" name = config.get(CONF_NAME) resource = config.get(CONF_RESOURCE) body_on = config.get(CONF_BODY_ON) body_off = config.get(CONF_BODY_OFF) is_on_template = config.get(CONF_IS_ON_TEMPLATE) + websession = async_get_clientsession(hass) if is_on_template is not None: is_on_template.hass = hass @@ -51,19 +57,24 @@ def setup_platform(hass, config, add_devices, discovery_info=None): body_off.hass = hass timeout = config.get(CONF_TIMEOUT) + req = None try: - requests.get(resource, timeout=10) - except requests.exceptions.MissingSchema: + with async_timeout.timeout(timeout, loop=hass.loop): + req = yield from websession.get(resource) + except (TypeError, ValueError): _LOGGER.error("Missing resource or schema in configuration. " "Add http:// or https:// to your URL") return False - except requests.exceptions.ConnectionError: + except (asyncio.TimeoutError, aiohttp.errors.ClientError): _LOGGER.error("No route to resource/endpoint: %s", resource) return False + finally: + if req is not None: + yield from req.release() - add_devices( - [RestSwitch( - hass, name, resource, body_on, body_off, is_on_template, timeout)]) + yield from async_add_devices( + [RestSwitch(hass, name, resource, body_on, body_off, + is_on_template, timeout)]) class RestSwitch(SwitchDevice): @@ -73,7 +84,7 @@ class RestSwitch(SwitchDevice): is_on_template, timeout): """Initialize the REST switch.""" self._state = None - self._hass = hass + self.hass = hass self._name = name self._resource = resource self._body_on = body_on @@ -91,46 +102,85 @@ class RestSwitch(SwitchDevice): """Return true if device is on.""" return self._state - def turn_on(self, **kwargs): + @asyncio.coroutine + def async_turn_on(self, **kwargs): """Turn the device on.""" - body_on_t = self._body_on.render() - request = requests.post( - self._resource, data=body_on_t, timeout=self._timeout) - if request.status_code == 200: + body_on_t = self._body_on.async_render() + websession = async_get_clientsession(self.hass) + + request = None + try: + with async_timeout.timeout(self._timeout, loop=self.hass.loop): + request = yield from websession.post( + self._resource, data=bytes(body_on_t, 'utf-8')) + except (asyncio.TimeoutError, aiohttp.errors.ClientError): + _LOGGER.error("Error while turn on %s", self._resource) + return + finally: + if request is not None: + yield from request.release() + + if request.status == 200: self._state = True else: _LOGGER.error("Can't turn on %s. Is resource/endpoint offline?", self._resource) - def turn_off(self, **kwargs): + @asyncio.coroutine + def async_turn_off(self, **kwargs): """Turn the device off.""" - body_off_t = self._body_off.render() - request = requests.post( - self._resource, data=body_off_t, timeout=self._timeout) - if request.status_code == 200: + body_off_t = self._body_off.async_render() + websession = async_get_clientsession(self.hass) + + request = None + try: + with async_timeout.timeout(self._timeout, loop=self.hass.loop): + request = yield from websession.post( + self._resource, data=bytes(body_off_t, 'utf-8')) + except (asyncio.TimeoutError, aiohttp.errors.ClientError): + _LOGGER.error("Error while turn off %s", self._resource) + return + finally: + if request is not None: + yield from request.release() + + if request.status == 200: self._state = False else: _LOGGER.error("Can't turn off %s. Is resource/endpoint offline?", self._resource) - def update(self): + @asyncio.coroutine + def async_update(self): """Get the latest data from REST API and update the state.""" - request = requests.get(self._resource, timeout=self._timeout) + websession = async_get_clientsession(self.hass) + + request = None + try: + with async_timeout.timeout(self._timeout, loop=self.hass.loop): + request = yield from websession.get(self._resource) + text = yield from request.text() + except (asyncio.TimeoutError, aiohttp.errors.ClientError): + _LOGGER.exception("Error while fetch data.") + return + finally: + if request is not None: + yield from request.release() if self._is_on_template is not None: - response = self._is_on_template.render_with_possible_json_value( - request.text, 'None') - response = response.lower() - if response == 'true': + text = self._is_on_template.async_render_with_possible_json_value( + text, 'None') + text = text.lower() + if text == 'true': self._state = True - elif response == 'false': + elif text == 'false': self._state = False else: self._state = None else: - if request.text == self._body_on.template: + if text == self._body_on.template: self._state = True - elif request.text == self._body_off.template: + elif text == self._body_off.template: self._state = False else: self._state = None diff --git a/tests/components/switch/test_rest.py b/tests/components/switch/test_rest.py index dc6c58db928..38ddad5e9a2 100644 --- a/tests/components/switch/test_rest.py +++ b/tests/components/switch/test_rest.py @@ -1,76 +1,83 @@ """The tests for the REST switch platform.""" -import unittest -from unittest.mock import patch +import asyncio -import pytest -import requests -from requests.exceptions import Timeout -import requests_mock +import aiohttp import homeassistant.components.switch.rest as rest from homeassistant.bootstrap import setup_component +from homeassistant.util.async import run_coroutine_threadsafe +from homeassistant.helpers.template import Template from tests.common import get_test_home_assistant, assert_setup_component -class TestRestSwitchSetup(unittest.TestCase): +class TestRestSwitchSetup: """Tests for setting up the REST switch platform.""" - def setUp(self): + def setup_method(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - def tearDown(self): + def teardown_method(self): """Stop everything that was started.""" self.hass.stop() def test_setup_missing_config(self): """Test setup with configuration missing required entries.""" - self.assertFalse(rest.setup_platform(self.hass, { - 'platform': 'rest' - }, None)) + assert not run_coroutine_threadsafe( + rest.async_setup_platform(self.hass, { + 'platform': 'rest' + }, None), + self.hass.loop + ).result() def test_setup_missing_schema(self): """Test setup with resource missing schema.""" - self.assertFalse(rest.setup_platform(self.hass, { - 'platform': 'rest', - 'resource': 'localhost' - }, None)) + assert not run_coroutine_threadsafe( + rest.async_setup_platform(self.hass, { + 'platform': 'rest', + 'resource': 'localhost' + }, None), + self.hass.loop + ).result() - @patch('requests.get', side_effect=requests.exceptions.ConnectionError()) - def test_setup_failed_connect(self, mock_req): + def test_setup_failed_connect(self, aioclient_mock): """Test setup when connection error occurs.""" - self.assertFalse(rest.setup_platform(self.hass, { - 'platform': 'rest', - 'resource': 'http://localhost', - }, None)) - - @patch('requests.get', side_effect=Timeout()) - def test_setup_timeout(self, mock_req): - """Test setup when connection timeout occurs.""" - with self.assertRaises(Timeout): - rest.setup_platform(self.hass, { + aioclient_mock.get('http://localhost', exc=aiohttp.errors.ClientError) + assert not run_coroutine_threadsafe( + rest.async_setup_platform(self.hass, { 'platform': 'rest', 'resource': 'http://localhost', - }, None) + }, None), + self.hass.loop + ).result() - @requests_mock.Mocker() - def test_setup_minimum(self, mock_req): - """Test setup with minimum configuration.""" - mock_req.get('http://localhost', status_code=200) - self.assertTrue(setup_component(self.hass, 'switch', { - 'switch': { + def test_setup_timeout(self, aioclient_mock): + """Test setup when connection timeout occurs.""" + aioclient_mock.get('http://localhost', exc=asyncio.TimeoutError()) + assert not run_coroutine_threadsafe( + rest.async_setup_platform(self.hass, { 'platform': 'rest', - 'resource': 'http://localhost' - } - })) - self.assertEqual(1, mock_req.call_count) - assert_setup_component(1, 'switch') + 'resource': 'http://localhost', + }, None), + self.hass.loop + ).result() - @requests_mock.Mocker() - def test_setup(self, mock_req): + def test_setup_minimum(self, aioclient_mock): + """Test setup with minimum configuration.""" + aioclient_mock.get('http://localhost', status=200) + with assert_setup_component(1, 'switch'): + assert setup_component(self.hass, 'switch', { + 'switch': { + 'platform': 'rest', + 'resource': 'http://localhost' + } + }) + assert aioclient_mock.call_count == 1 + + def test_setup(self, aioclient_mock): """Test setup with valid configuration.""" - mock_req.get('localhost', status_code=200) - self.assertTrue(setup_component(self.hass, 'switch', { + aioclient_mock.get('http://localhost', status=200) + assert setup_component(self.hass, 'switch', { 'switch': { 'platform': 'rest', 'name': 'foo', @@ -78,111 +85,120 @@ class TestRestSwitchSetup(unittest.TestCase): 'body_on': 'custom on text', 'body_off': 'custom off text', } - })) - self.assertEqual(1, mock_req.call_count) + }) + assert aioclient_mock.call_count == 1 assert_setup_component(1, 'switch') -@pytest.mark.skip -class TestRestSwitch(unittest.TestCase): +class TestRestSwitch: """Tests for REST switch platform.""" - def setUp(self): + def setup_method(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() self.name = 'foo' self.resource = 'http://localhost/' - self.body_on = 'on' - self.body_off = 'off' + self.body_on = Template('on', self.hass) + self.body_off = Template('off', self.hass) self.switch = rest.RestSwitch(self.hass, self.name, self.resource, - self.body_on, self.body_off) + self.body_on, self.body_off, None, 10) - def tearDown(self): + def teardown_method(self): """Stop everything that was started.""" self.hass.stop() def test_name(self): """Test the name.""" - self.assertEqual(self.name, self.switch.name) + assert self.name == self.switch.name def test_is_on_before_update(self): """Test is_on in initial state.""" - self.assertEqual(None, self.switch.is_on) + assert self.switch.is_on is None - @requests_mock.Mocker() - def test_turn_on_success(self, mock_req): + def test_turn_on_success(self, aioclient_mock): """Test turn_on.""" - mock_req.post(self.resource, status_code=200) - self.switch.turn_on() + aioclient_mock.post(self.resource, status=200) + run_coroutine_threadsafe( + self.switch.async_turn_on(), self.hass.loop).result() - self.assertEqual(self.body_on, mock_req.last_request.text) - self.assertEqual(True, self.switch.is_on) + assert self.body_on.template == \ + aioclient_mock.mock_calls[-1][2].decode() + assert self.switch.is_on - @requests_mock.Mocker() - def test_turn_on_status_not_ok(self, mock_req): + def test_turn_on_status_not_ok(self, aioclient_mock): """Test turn_on when error status returned.""" - mock_req.post(self.resource, status_code=500) - self.switch.turn_on() + aioclient_mock.post(self.resource, status=500) + run_coroutine_threadsafe( + self.switch.async_turn_on(), self.hass.loop).result() - self.assertEqual(self.body_on, mock_req.last_request.text) - self.assertEqual(None, self.switch.is_on) + assert self.body_on.template == \ + aioclient_mock.mock_calls[-1][2].decode() + assert self.switch.is_on is None - @patch('requests.post', side_effect=Timeout()) - def test_turn_on_timeout(self, mock_req): + def test_turn_on_timeout(self, aioclient_mock): """Test turn_on when timeout occurs.""" - with self.assertRaises(Timeout): - self.switch.turn_on() + aioclient_mock.post(self.resource, status=500) + run_coroutine_threadsafe( + self.switch.async_turn_on(), self.hass.loop).result() - @requests_mock.Mocker() - def test_turn_off_success(self, mock_req): + assert self.switch.is_on is None + + def test_turn_off_success(self, aioclient_mock): """Test turn_off.""" - mock_req.post(self.resource, status_code=200) - self.switch.turn_off() + aioclient_mock.post(self.resource, status=200) + run_coroutine_threadsafe( + self.switch.async_turn_off(), self.hass.loop).result() - self.assertEqual(self.body_off, mock_req.last_request.text) - self.assertEqual(False, self.switch.is_on) + assert self.body_off.template == \ + aioclient_mock.mock_calls[-1][2].decode() + assert not self.switch.is_on - @requests_mock.Mocker() - def test_turn_off_status_not_ok(self, mock_req): + def test_turn_off_status_not_ok(self, aioclient_mock): """Test turn_off when error status returned.""" - mock_req.post(self.resource, status_code=500) - self.switch.turn_off() + aioclient_mock.post(self.resource, status=500) + run_coroutine_threadsafe( + self.switch.async_turn_off(), self.hass.loop).result() - self.assertEqual(self.body_off, mock_req.last_request.text) - self.assertEqual(None, self.switch.is_on) + assert self.body_off.template == \ + aioclient_mock.mock_calls[-1][2].decode() + assert self.switch.is_on is None - @patch('requests.post', side_effect=Timeout()) - def test_turn_off_timeout(self, mock_req): + def test_turn_off_timeout(self, aioclient_mock): """Test turn_off when timeout occurs.""" - with self.assertRaises(Timeout): - self.switch.turn_on() + aioclient_mock.post(self.resource, exc=asyncio.TimeoutError()) + run_coroutine_threadsafe( + self.switch.async_turn_on(), self.hass.loop).result() - @requests_mock.Mocker() - def test_update_when_on(self, mock_req): + assert self.switch.is_on is None + + def test_update_when_on(self, aioclient_mock): """Test update when switch is on.""" - mock_req.get(self.resource, text=self.body_on) - self.switch.update() + aioclient_mock.get(self.resource, text=self.body_on.template) + run_coroutine_threadsafe( + self.switch.async_update(), self.hass.loop).result() - self.assertEqual(True, self.switch.is_on) + assert self.switch.is_on - @requests_mock.Mocker() - def test_update_when_off(self, mock_req): + def test_update_when_off(self, aioclient_mock): """Test update when switch is off.""" - mock_req.get(self.resource, text=self.body_off) - self.switch.update() + aioclient_mock.get(self.resource, text=self.body_off.template) + run_coroutine_threadsafe( + self.switch.async_update(), self.hass.loop).result() - self.assertEqual(False, self.switch.is_on) + assert not self.switch.is_on - @requests_mock.Mocker() - def test_update_when_unknown(self, mock_req): + def test_update_when_unknown(self, aioclient_mock): """Test update when unknown status returned.""" - mock_req.get(self.resource, text='unknown status') - self.switch.update() + aioclient_mock.get(self.resource, text='unknown status') + run_coroutine_threadsafe( + self.switch.async_update(), self.hass.loop).result() - self.assertEqual(None, self.switch.is_on) + assert self.switch.is_on is None - @patch('requests.get', side_effect=Timeout()) - def test_update_timeout(self, mock_req): + def test_update_timeout(self, aioclient_mock): """Test update when timeout occurs.""" - with self.assertRaises(Timeout): - self.switch.update() + aioclient_mock.get(self.resource, exc=asyncio.TimeoutError()) + run_coroutine_threadsafe( + self.switch.async_update(), self.hass.loop).result() + + assert self.switch.is_on is None diff --git a/tests/test_util/aiohttp.py b/tests/test_util/aiohttp.py index 4abf43a6e42..c0ed579f197 100644 --- a/tests/test_util/aiohttp.py +++ b/tests/test_util/aiohttp.py @@ -20,6 +20,7 @@ class AiohttpClientMocker: auth=None, status=200, text=None, + data=None, content=None, json=None, params=None, @@ -66,12 +67,12 @@ class AiohttpClientMocker: return len(self.mock_calls) @asyncio.coroutine - def match_request(self, method, url, *, auth=None, params=None, + def match_request(self, method, url, *, data=None, auth=None, params=None, headers=None): # pylint: disable=unused-variable """Match a request against pre-registered requests.""" for response in self._mocks: if response.match_request(method, url, params): - self.mock_calls.append((method, url)) + self.mock_calls.append((method, url, data)) if self.exc: raise self.exc